|
1 | 1 | import asyncio |
| 2 | +from collections import defaultdict |
| 3 | +from threading import RLock |
2 | 4 |
|
3 | 5 | from sqlalchemy import PrimaryKeyConstraint, Column, String, Table, and_ |
4 | | -from sqlalchemy.exc import IntegrityError |
5 | 6 |
|
6 | 7 | from cloudbot import hook |
7 | 8 | from cloudbot.util import database |
|
14 | 15 | PrimaryKeyConstraint('conn', 'chan') |
15 | 16 | ) |
16 | 17 |
|
| 18 | +chan_cache = defaultdict(set) |
| 19 | +db_lock = RLock() |
| 20 | + |
17 | 21 |
|
18 | 22 | def get_channels(db, conn): |
19 | 23 | return db.execute(table.select().where(table.c.conn == conn.name.casefold())).fetchall() |
20 | 24 |
|
21 | 25 |
|
| 26 | +@hook.on_start |
| 27 | +def load_cache(db): |
| 28 | + with db_lock: |
| 29 | + chan_cache.clear() |
| 30 | + for row in db.execute(table.select()): |
| 31 | + chan_cache[row['conn']].add(row['chan']) |
| 32 | + |
| 33 | + |
22 | 34 | @asyncio.coroutine |
23 | 35 | @hook.irc_raw('376') |
24 | | -def do_joins(db, conn, async_call): |
25 | | - chans = yield from async_call(get_channels, db, conn) |
| 36 | +def do_joins(conn): |
26 | 37 | join_throttle = conn.config.get("join_throttle", 0.4) |
27 | | - for chan in chans: |
28 | | - conn.join(chan[1]) |
| 38 | + for chan in chan_cache[conn.name]: |
| 39 | + conn.join(chan) |
29 | 40 | yield from asyncio.sleep(join_throttle) |
30 | 41 |
|
31 | 42 |
|
32 | 43 | @hook.irc_raw('JOIN', singlethread=True) |
33 | 44 | def add_chan(db, conn, chan, nick): |
34 | | - if nick.casefold() == conn.nick.casefold(): |
35 | | - try: |
| 45 | + chans = chan_cache[conn.name] |
| 46 | + chan = chan.casefold() |
| 47 | + if nick.casefold() == conn.nick.casefold() and chan not in chans: |
| 48 | + with db_lock: |
36 | 49 | db.execute(table.insert().values(conn=conn.name.casefold(), chan=chan.casefold())) |
37 | 50 | db.commit() |
38 | | - except IntegrityError: |
39 | | - pass |
| 51 | + |
| 52 | + load_cache(db) |
40 | 53 |
|
41 | 54 |
|
42 | 55 | @hook.irc_raw('PART', singlethread=True) |
43 | 56 | def on_part(db, conn, chan, nick): |
44 | 57 | if nick.casefold() == conn.nick.casefold(): |
45 | | - db.execute(table.delete().where(and_(table.c.conn == conn.name.casefold(), table.c.chan == chan.casefold()))) |
46 | | - db.commit() |
| 58 | + with db_lock: |
| 59 | + db.execute( |
| 60 | + table.delete().where(and_(table.c.conn == conn.name.casefold(), table.c.chan == chan.casefold()))) |
| 61 | + db.commit() |
| 62 | + |
| 63 | + load_cache(db) |
47 | 64 |
|
48 | 65 |
|
49 | 66 | @hook.irc_raw('KICK', singlethread=True) |
50 | 67 | def on_kick(db, conn, chan, target): |
51 | | - if target.casefold() == conn.nick.casefold(): |
52 | | - db.execute(table.delete().where(and_(table.c.conn == conn.name.casefold(), table.c.chan == chan.casefold()))) |
53 | | - db.commit() |
| 68 | + on_part(db, conn, chan, target) |
0 commit comments