Skip to content

Commit 93a8895

Browse files
committed
Add locks to avoid race conditions when adding the same to grab to the DB
1 parent 1b2ff61 commit 93a8895

1 file changed

Lines changed: 47 additions & 21 deletions

File tree

plugins/grab.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import logging
12
import random
2-
33
from collections import defaultdict
4+
from threading import RLock
5+
46
from sqlalchemy import Table, Column, String
7+
from sqlalchemy.exc import SQLAlchemyError
8+
59
from cloudbot import hook
610
from cloudbot.util import database
711

@@ -18,19 +22,25 @@
1822
)
1923

2024
grab_cache = {}
25+
grab_locks = defaultdict(dict)
26+
grab_locks_lock = RLock()
27+
cache_lock = RLock()
28+
29+
logger = logging.getLogger("cloudbot")
2130

2231

2332
@hook.on_start()
2433
def load_cache(db):
2534
"""
2635
:type db: sqlalchemy.orm.Session
2736
"""
28-
grab_cache.clear()
29-
for row in db.execute(table.select().order_by(table.c.time)):
30-
name = row["name"].lower()
31-
quote = row["quote"]
32-
chan = row["chan"]
33-
grab_cache.setdefault(chan, {}).setdefault(name, []).append(quote)
37+
with cache_lock:
38+
grab_cache.clear()
39+
for row in db.execute(table.select().order_by(table.c.time)):
40+
name = row["name"].lower()
41+
quote = row["quote"]
42+
chan = row["chan"]
43+
grab_cache.setdefault(chan, {}).setdefault(name, []).append(quote)
3444

3545

3646
def two_lines(bigstring, chan):
@@ -90,27 +100,43 @@ def grab_add(nick, time, msg, chan, db, conn):
90100
load_cache(db)
91101

92102

103+
def get_latest_line(conn, chan, nick):
104+
for item in reversed(conn.history[chan]):
105+
name, timestamp, msg = item
106+
if nick.casefold() == name.casefold():
107+
return item
108+
109+
return None, None, None
110+
111+
93112
@hook.command()
94113
def grab(text, nick, chan, db, conn):
95114
"""grab <nick> grabs the last message from the
96115
specified nick and adds it to the quote database"""
97116
if text.lower() == nick.lower():
98117
return "Didn't your mother teach you not to grab yourself?"
99118

100-
for item in conn.history[chan].__reversed__():
101-
name, timestamp, msg = item
102-
if text.lower() == name.lower():
103-
# check to see if the quote has been added
104-
if check_grabs(name.lower(), msg, chan):
105-
return "I already have that quote from {} in the database".format(text)
106-
else:
107-
# the quote is new so add it to the db.
108-
grab_add(name.lower(),timestamp, msg, chan, db, conn)
109-
if check_grabs(name.lower(), msg, chan):
110-
return "the operation succeeded."
111-
else:
112-
return "the operation failed"
113-
return "I couldn't find anything from {} in recent history.".format(text)
119+
with grab_locks_lock:
120+
grab_lock = grab_locks[conn.name.casefold()].setdefault(chan.casefold(), RLock())
121+
122+
with grab_lock:
123+
name, timestamp, msg = get_latest_line(conn, chan, text)
124+
if not msg:
125+
return "I couldn't find anything from {} in recent history.".format(text)
126+
127+
if check_grabs(text.casefold(), msg, chan):
128+
return "I already have that quote from {} in the database".format(text)
129+
130+
try:
131+
grab_add(name.casefold(), timestamp, msg, chan, db, conn)
132+
except SQLAlchemyError:
133+
logger.exception("Error occurred when grabbing %s in %s", name, chan)
134+
return "Error occurred."
135+
136+
if check_grabs(name.casefold(), msg, chan):
137+
return "the operation succeeded."
138+
else:
139+
return "the operation failed"
114140

115141

116142
def format_grab(name, quote):

0 commit comments

Comments
 (0)