Skip to content

Commit a2e6b24

Browse files
committed
Add IRCv3 cap support
1 parent 5d74e48 commit a2e6b24

5 files changed

Lines changed: 190 additions & 7 deletions

File tree

cloudbot/clients/irc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ def connect(self):
122122
self._transport, self._protocol = yield from self.loop.create_connection(
123123
lambda: _IrcProtocol(self), host=self.server, port=self.port, ssl=self.ssl_context, **optional_params)
124124

125-
# send the password, nick, and user
125+
# send the cap ls, password, nick, and user
126+
self.send("CAP LS 302")
126127
self.set_pass(self.config["connection"].get("password"))
127128
self.set_nick(self.nick)
128129
self.cmd("USER", self.config.get('user', 'cloudbot'), "3", "*",

cloudbot/event.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
2+
import concurrent.futures
23
import enum
34
import logging
4-
import concurrent.futures
55

66
logger = logging.getLogger("cloudbot")
77

@@ -383,3 +383,10 @@ def __init__(self, *, bot=None, hook, match, conn=None, base_event=None, event_t
383383
content_raw=content_raw, target=target, channel=channel, nick=nick, user=user, host=host, mask=mask,
384384
irc_raw=irc_raw, irc_prefix=irc_prefix, irc_command=irc_command, irc_paramlist=irc_paramlist)
385385
self.match = match
386+
387+
388+
class CapEvent(Event):
389+
def __init__(self, *args, cap, cap_param=None, **kwargs):
390+
super().__init__(*args, **kwargs)
391+
self.cap = cap
392+
self.cap_param = cap_param

cloudbot/hook.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import collections
12
import inspect
23
import re
3-
import collections
44

55
from cloudbot.event import EventType
66

@@ -177,6 +177,16 @@ def add_hook(self, trigger_param, kwargs):
177177
self.types.update(trigger_param)
178178

179179

180+
class _CapHook(_Hook):
181+
def __init__(self, func, _type):
182+
super().__init__(func, "on_cap_{}".format(_type))
183+
self.caps = set()
184+
185+
def add_hook(self, caps, kwargs):
186+
self._add_hook(kwargs)
187+
self.caps.update(caps)
188+
189+
180190
def _add_hook(func, hook):
181191
if not hasattr(func, "_cloudbot_hook"):
182192
func._cloudbot_hook = {}
@@ -357,3 +367,37 @@ def _on_stop_hook(func):
357367
return lambda func: _on_stop_hook(func)
358368

359369
on_unload = on_stop
370+
371+
372+
def on_cap_available(*caps, **kwargs):
373+
"""External on_cap_available decorator. Must be used as a function that returns a decorator
374+
375+
This hook will fire for each capability in a `CAP LS` response from the server
376+
"""
377+
378+
def _on_cap_available_hook(func):
379+
hook = _get_hook(func, "on_cap_available")
380+
if hook is None:
381+
hook = _CapHook(func, "available")
382+
_add_hook(func, hook)
383+
hook.add_hook(caps, kwargs)
384+
return func
385+
386+
return _on_cap_available_hook
387+
388+
389+
def on_cap_ack(*caps, **kwargs):
390+
"""External on_cap_ack decorator. Must be used as a function that returns a decorator
391+
392+
This hook will fire for each capability that is acknowledged from the server with `CAP ACK`
393+
"""
394+
395+
def _on_cap_ack_hook(func):
396+
hook = _get_hook(func, "on_cap_ack")
397+
if hook is None:
398+
hook = _CapHook(func, "ack")
399+
_add_hook(func, hook)
400+
hook.add_hook(caps, kwargs)
401+
return func
402+
403+
return _on_cap_ack_hook

cloudbot/plugin.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import os
77
import re
8+
from collections import defaultdict
89

910
import sqlalchemy
1011

@@ -18,7 +19,7 @@ def find_hooks(parent, module):
1819
"""
1920
:type parent: Plugin
2021
:type module: object
21-
:rtype: (list[CommandHook], list[RegexHook], list[RawHook], list[SieveHook], List[EventHook], List[PeriodicHook], list[OnStartHook], List[OnStopHook])
22+
:rtype: (list[CommandHook], list[RegexHook], list[RawHook], list[SieveHook], List[EventHook], List[PeriodicHook], list[OnStartHook], List[OnStopHook], list[OnCapAckHook], list[OnCapAvailableHook])
2223
"""
2324
# set the loaded flag
2425
module._cloudbot_loaded = True
@@ -30,8 +31,11 @@ def find_hooks(parent, module):
3031
periodic = []
3132
on_start = []
3233
on_stop = []
34+
on_cap_ack = []
35+
on_cap_available = []
3336
type_lists = {"command": command, "regex": regex, "irc_raw": raw, "sieve": sieve, "event": event,
34-
"periodic": periodic, "on_start": on_start, "on_stop": on_stop}
37+
"periodic": periodic, "on_start": on_start, "on_stop": on_stop, "on_cap_ack": on_cap_ack,
38+
"on_cap_available": on_cap_available}
3539
for name, func in module.__dict__.items():
3640
if hasattr(func, "_cloudbot_hook"):
3741
# if it has cloudbot hook
@@ -43,7 +47,7 @@ def find_hooks(parent, module):
4347
# delete the hook to free memory
4448
del func._cloudbot_hook
4549

46-
return command, regex, raw, sieve, event, periodic, on_start, on_stop
50+
return command, regex, raw, sieve, event, periodic, on_start, on_stop, on_cap_ack, on_cap_available
4751

4852

4953
def find_tables(code):
@@ -98,6 +102,7 @@ def __init__(self, bot):
98102
self.event_type_hooks = {}
99103
self.regex_hooks = []
100104
self.sieves = []
105+
self.cap_hooks = {"on_available": defaultdict(list), "on_ack": defaultdict(list)}
101106
self._hook_waiting_queues = {}
102107

103108
@asyncio.coroutine
@@ -179,6 +184,16 @@ def load_plugin(self, path):
179184

180185
self.plugins[plugin.file_name] = plugin
181186

187+
for on_cap_available_hook in plugin.on_cap_available:
188+
for cap in on_cap_available_hook.caps:
189+
self.cap_hooks["on_available"][cap.casefold()].append(on_cap_available_hook)
190+
self._log_hook(on_cap_available_hook)
191+
192+
for on_cap_ack_hook in plugin.on_cap_ack:
193+
for cap in on_cap_ack_hook.caps:
194+
self.cap_hooks["on_ack"][cap.casefold()].append(on_cap_ack_hook)
195+
self._log_hook(on_cap_ack_hook)
196+
182197
for periodic_hook in plugin.periodic:
183198
task = asyncio.async(self._start_periodic(periodic_hook))
184199
plugin.tasks.append(task)
@@ -259,6 +274,22 @@ def unload_plugin(self, path):
259274
for task in plugin.tasks:
260275
task.cancel()
261276

277+
for on_cap_available_hook in plugin.on_cap_available:
278+
available_hooks = self.cap_hooks["on_available"]
279+
for cap in on_cap_available_hook.caps:
280+
cap_cf = cap.casefold()
281+
available_hooks[cap_cf].remove(on_cap_available_hook)
282+
if not available_hooks[cap_cf]:
283+
del available_hooks[cap_cf]
284+
285+
for on_cap_ack in plugin.on_cap_ack:
286+
ack_hooks = self.cap_hooks["on_ack"]
287+
for cap in on_cap_ack.caps:
288+
cap_cf = cap.casefold()
289+
ack_hooks[cap_cf].remove(on_cap_ack)
290+
if not ack_hooks[cap_cf]:
291+
del ack_hooks[cap_cf]
292+
262293
# unregister commands
263294
for command_hook in plugin.commands:
264295
for alias in command_hook.aliases:
@@ -521,7 +552,8 @@ def __init__(self, filepath, filename, title, code):
521552
self.file_path = filepath
522553
self.file_name = filename
523554
self.title = title
524-
self.commands, self.regexes, self.raw_hooks, self.sieves, self.events, self.periodic, self.run_on_start, self.run_on_stop = find_hooks(self, code)
555+
self.commands, self.regexes, self.raw_hooks, self.sieves, self.events, self.periodic, self.run_on_start, self.run_on_stop, self.on_cap_ack, self.on_cap_available = find_hooks(
556+
self, code)
525557
# we need to find tables for each plugin so that they can be unloaded from the global metadata when the
526558
# plugin is reloaded
527559
self.tables = find_tables(code)
@@ -776,6 +808,28 @@ def __str__(self):
776808
return "on_stop {} from {}".format(self.function_name, self.plugin.file_name)
777809

778810

811+
class CapHook(Hook):
812+
def __init__(self, _type, plugin, base_hook):
813+
self.caps = base_hook.caps
814+
super().__init__("on_cap_{}".format(_type), plugin, base_hook)
815+
816+
def __repr__(self):
817+
return "{name}[{caps} {base!r}]".format(name=self.type, caps=self.caps, base=super())
818+
819+
def __str__(self):
820+
return "{name} {func} from {file}".format(name=self.type, func=self.function_name, file=self.plugin.file_name)
821+
822+
823+
class OnCapAvaliableHook(CapHook):
824+
def __init__(self, plugin, base_hook):
825+
super().__init__("available", plugin, base_hook)
826+
827+
828+
class OnCapAckHook(CapHook):
829+
def __init__(self, plugin, base_hook):
830+
super().__init__("ack", plugin, base_hook)
831+
832+
779833
_hook_name_to_plugin = {
780834
"command": CommandHook,
781835
"regex": RegexHook,
@@ -785,4 +839,6 @@ def __str__(self):
785839
"periodic": PeriodicHook,
786840
"on_start": OnStartHook,
787841
"on_stop": OnStopHook,
842+
"on_cap_available": OnCapAvaliableHook,
843+
"on_cap_ack": OnCapAckHook,
788844
}

plugins/cap.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import asyncio
2+
import logging
3+
from functools import partial
4+
5+
from cloudbot import hook
6+
from cloudbot.event import CapEvent
7+
8+
logger = logging.getLogger("cloudbot")
9+
10+
11+
@asyncio.coroutine
12+
def handle_available_caps(conn, caplist, event, irc_paramlist, bot):
13+
available_caps = conn.memory.setdefault("available_caps", set())
14+
caps = [tuple(cap.split('=', 1)) for cap in caplist]
15+
available_caps.update(caps)
16+
cap_queue = conn.memory.setdefault("cap_queue", {})
17+
for cap, *param in caps:
18+
cap_event = partial(CapEvent, base_event=event, cap=cap, cap_param=param[0] if param else None)
19+
tasks = [
20+
bot.plugin_manager.launch(_hook, cap_event(hook=_hook))
21+
for _hook in bot.plugin_manager.cap_hooks["on_available"][cap.casefold()]
22+
]
23+
results = yield from asyncio.gather(*tasks)
24+
if any(results):
25+
cap_queue[cap.casefold()] = conn.loop.create_future()
26+
conn.cmd("CAP", "REQ", cap)
27+
28+
if irc_paramlist[2] != '+':
29+
yield from asyncio.gather(*cap_queue.values())
30+
cap_queue.clear()
31+
conn.send("CAP END")
32+
33+
34+
@asyncio.coroutine
35+
@hook.irc_raw("CAP")
36+
def on_cap(irc_paramlist, conn, bot, event):
37+
caplist = []
38+
if len(irc_paramlist) > 2:
39+
capstr = irc_paramlist[-1].strip()
40+
if capstr[0] == ':':
41+
capstr = capstr[1:]
42+
43+
caplist = capstr.split()
44+
subcmd = irc_paramlist[1].upper()
45+
if subcmd == "LS":
46+
yield from handle_available_caps(conn, caplist, event, irc_paramlist, bot)
47+
48+
elif subcmd in ('ACK', 'NAK'):
49+
enabled = subcmd == 'ACK'
50+
server_caps = conn.memory.setdefault('server_caps', {})
51+
cap_queue = conn.memory.get("cap_queue", {})
52+
caps = [cap.casefold() for cap in caplist]
53+
for cap in caps:
54+
server_caps[cap] = enabled
55+
if enabled:
56+
cap_event = partial(CapEvent, base_event=event, cap=cap)
57+
tasks = [
58+
bot.plugin_manager.launch(_hook, cap_event(hook=_hook))
59+
for _hook in bot.plugin_manager.cap_hooks["on_ack"][cap]
60+
]
61+
yield from asyncio.gather(*tasks)
62+
63+
if cap in cap_queue:
64+
cap_queue[cap].set_result(enabled)
65+
66+
elif subcmd == 'LIST':
67+
logger.info("Enabled Capabilities: %s", irc_paramlist[-1])
68+
elif subcmd == 'NEW':
69+
logger.info("New capabilities advertised: %s", irc_paramlist[-1])
70+
yield from handle_available_caps(conn, caplist, event, irc_paramlist, bot)
71+
elif subcmd == 'DEL':
72+
logger.info("Capabilities removed by server: %s", irc_paramlist[-1])
73+
server_caps = conn.memory.setdefault('server_caps', {})
74+
for cap in caplist:
75+
server_caps[cap] = False

0 commit comments

Comments
 (0)