|
| 1 | +import logging |
1 | 2 | import random |
2 | | - |
3 | 3 | from collections import defaultdict |
| 4 | +from threading import RLock |
| 5 | + |
4 | 6 | from sqlalchemy import Table, Column, String |
| 7 | +from sqlalchemy.exc import SQLAlchemyError |
| 8 | + |
5 | 9 | from cloudbot import hook |
6 | 10 | from cloudbot.util import database |
7 | 11 |
|
|
18 | 22 | ) |
19 | 23 |
|
20 | 24 | grab_cache = {} |
| 25 | +grab_locks = defaultdict(dict) |
| 26 | +grab_locks_lock = RLock() |
| 27 | +cache_lock = RLock() |
| 28 | + |
| 29 | +logger = logging.getLogger("cloudbot") |
21 | 30 |
|
22 | 31 |
|
23 | 32 | @hook.on_start() |
24 | 33 | def load_cache(db): |
25 | 34 | """ |
26 | 35 | :type db: sqlalchemy.orm.Session |
27 | 36 | """ |
28 | | - grab_cache.clear() |
29 | | - for row in db.execute(table.select().order_by(table.c.time)): |
30 | | - name = row["name"].lower() |
31 | | - quote = row["quote"] |
32 | | - chan = row["chan"] |
33 | | - grab_cache.setdefault(chan, {}).setdefault(name, []).append(quote) |
| 37 | + with cache_lock: |
| 38 | + grab_cache.clear() |
| 39 | + for row in db.execute(table.select().order_by(table.c.time)): |
| 40 | + name = row["name"].lower() |
| 41 | + quote = row["quote"] |
| 42 | + chan = row["chan"] |
| 43 | + grab_cache.setdefault(chan, {}).setdefault(name, []).append(quote) |
34 | 44 |
|
35 | 45 |
|
36 | 46 | def two_lines(bigstring, chan): |
@@ -90,27 +100,42 @@ def grab_add(nick, time, msg, chan, db, conn): |
90 | 100 | load_cache(db) |
91 | 101 |
|
92 | 102 |
|
| 103 | +def get_latest_line(conn, chan, nick): |
| 104 | + for name, timestamp, msg in reversed(conn.history[chan]): |
| 105 | + if nick.casefold() == name.casefold(): |
| 106 | + return name, timestamp, msg |
| 107 | + |
| 108 | + return None, None, None |
| 109 | + |
| 110 | + |
93 | 111 | @hook.command() |
94 | 112 | def grab(text, nick, chan, db, conn): |
95 | 113 | """grab <nick> grabs the last message from the |
96 | 114 | specified nick and adds it to the quote database""" |
97 | 115 | if text.lower() == nick.lower(): |
98 | 116 | return "Didn't your mother teach you not to grab yourself?" |
99 | 117 |
|
100 | | - for item in conn.history[chan].__reversed__(): |
101 | | - name, timestamp, msg = item |
102 | | - if text.lower() == name.lower(): |
103 | | - # check to see if the quote has been added |
104 | | - if check_grabs(name.lower(), msg, chan): |
105 | | - return "I already have that quote from {} in the database".format(text) |
106 | | - else: |
107 | | - # the quote is new so add it to the db. |
108 | | - grab_add(name.lower(),timestamp, msg, chan, db, conn) |
109 | | - if check_grabs(name.lower(), msg, chan): |
110 | | - return "the operation succeeded." |
111 | | - else: |
112 | | - return "the operation failed" |
113 | | - return "I couldn't find anything from {} in recent history.".format(text) |
| 118 | + with grab_locks_lock: |
| 119 | + grab_lock = grab_locks[conn.name.casefold()].setdefault(chan.casefold(), RLock()) |
| 120 | + |
| 121 | + with grab_lock: |
| 122 | + name, timestamp, msg = get_latest_line(conn, chan, text) |
| 123 | + if not msg: |
| 124 | + return "I couldn't find anything from {} in recent history.".format(text) |
| 125 | + |
| 126 | + if check_grabs(text.casefold(), msg, chan): |
| 127 | + return "I already have that quote from {} in the database".format(text) |
| 128 | + |
| 129 | + try: |
| 130 | + grab_add(name.casefold(), timestamp, msg, chan, db, conn) |
| 131 | + except SQLAlchemyError: |
| 132 | + logger.exception("Error occurred when grabbing %s in %s", name, chan) |
| 133 | + return "Error occurred." |
| 134 | + |
| 135 | + if check_grabs(name.casefold(), msg, chan): |
| 136 | + return "the operation succeeded." |
| 137 | + else: |
| 138 | + return "the operation failed" |
114 | 139 |
|
115 | 140 |
|
116 | 141 | def format_grab(name, quote): |
|
0 commit comments