Skip to content

Commit 6b2b845

Browse files
committed
Fix bindparams usage
Fix typ hint
1 parent acfae2f commit 6b2b845

1 file changed

Lines changed: 34 additions & 15 deletions

File tree

pgmq_sqlalchemy/operation.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import re
33

44
from sqlalchemy import text, bindparam
5-
from sqlalchemy.dialects.postgresql import JSONB, ARRAY, BIGINT
5+
from sqlalchemy.dialects.postgresql import JSONB, ARRAY, BIGINT, TEXT
66
from sqlalchemy.orm import Session
77
from sqlalchemy.ext.asyncio import AsyncSession
88

@@ -131,7 +131,9 @@ def _get_send_statement(
131131
"""Get statement and params for send."""
132132
stmt = text(
133133
"select * from pgmq.send(:queue_name, :message, :delay);"
134-
).bindparams(bindparam("message", type_=JSONB))
134+
).bindparams(
135+
bindparam("queue_name", type_=TEXT), bindparam("message", type_=JSONB)
136+
)
135137
return (
136138
stmt,
137139
{
@@ -152,7 +154,10 @@ def _get_send_batch_statement(
152154
"""
153155
stmt = text(
154156
"select * from pgmq.send_batch(:queue_name, :messages, :delay);"
155-
).bindparams(bindparam("messages", type_=ARRAY(JSONB)))
157+
).bindparams(
158+
bindparam("queue_name", type_=TEXT),
159+
bindparam("messages", type_=ARRAY(JSONB)),
160+
)
156161

157162
return (
158163
stmt,
@@ -218,7 +223,7 @@ def _get_delete_statement(
218223
) -> Tuple["TextClause", Dict[str, Any]]:
219224
"""Get statement and params for delete."""
220225
stmt = text("select pgmq.delete(:queue_name, :msg_id) as deleted;").bindparams(
221-
bindparam("msg_id", type_=BIGINT)
226+
bindparam("queue_name", type_=TEXT), bindparam("msg_id", type_=BIGINT)
222227
)
223228

224229
return stmt, {
@@ -233,7 +238,10 @@ def _get_delete_batch_statement(
233238
"""Get statement and params for delete_batch."""
234239
stmt = text(
235240
"select * from pgmq.delete_batch(:queue_name, :msg_ids);"
236-
).bindparams(bindparam("msg_ids", type_=ARRAY(BIGINT)))
241+
).bindparams(
242+
bindparam("queue_name", type_=TEXT),
243+
bindparam("msg_ids", type_=ARRAY(BIGINT)),
244+
)
237245

238246
return (
239247
stmt,
@@ -243,22 +251,33 @@ def _get_delete_batch_statement(
243251
@staticmethod
244252
def _get_archive_statement(
245253
queue_name: str, msg_id: int
246-
) -> Tuple[str, Dict[str, Any]]:
254+
) -> Tuple["TextClause", Dict[str, Any]]:
247255
"""Get statement and params for archive."""
248-
return "select pgmq.archive(:queue_name, :msg_id) as archived;", {
256+
stmt = text(
257+
"select pgmq.archive(:queue_name, :msg_id) as archived;"
258+
).bindparams(
259+
bindparam("queue_name", type_=TEXT), bindparam("msg_id", type_=BIGINT)
260+
)
261+
return stmt, {
249262
"queue_name": queue_name,
250263
"msg_id": msg_id,
251264
}
252265

253266
@staticmethod
254267
def _get_archive_batch_statement(
255268
queue_name: str, msg_ids: List[int]
256-
) -> Tuple[str, Dict[str, Any]]:
269+
) -> Tuple["TextClause", Dict[str, Any]]:
257270
"""Get statement and params for archive_batch."""
258-
return (
259-
"select t.msg_id from unnest(CAST(:msg_ids AS bigint[])) as t(msg_id) where pgmq.archive(:queue_name, t.msg_id);",
260-
{"queue_name": queue_name, "msg_ids": msg_ids},
271+
stmt = text(
272+
"select t.msg_id from unnest(CAST(:msg_ids AS bigint[])) as t(msg_id) where pgmq.archive(:queue_name, t.msg_id);"
273+
).bindparams(
274+
bindparam("queue_name", type_=TEXT),
275+
bindparam("msg_ids", type_=ARRAY(BIGINT)),
261276
)
277+
return stmt, {
278+
"queue_name": queue_name,
279+
"msg_ids": msg_ids,
280+
}
262281

263282
@staticmethod
264283
def _get_purge_statement(queue_name: str) -> Tuple[str, Dict[str, Any]]:
@@ -1157,7 +1176,7 @@ def archive(
11571176
True if the message was archived successfully.
11581177
"""
11591178
stmt, params = PGMQOperation._get_archive_statement(queue_name, msg_id)
1160-
row = session.execute(text(stmt), params).fetchone()
1179+
row = session.execute(stmt, params).fetchone()
11611180
if commit:
11621181
session.commit()
11631182
return row[0]
@@ -1182,7 +1201,7 @@ async def archive_async(
11821201
True if the message was archived successfully.
11831202
"""
11841203
stmt, params = PGMQOperation._get_archive_statement(queue_name, msg_id)
1185-
row = (await session.execute(text(stmt), params)).fetchone()
1204+
row = (await session.execute(stmt, params)).fetchone()
11861205
if commit:
11871206
await session.commit()
11881207
return row[0]
@@ -1207,7 +1226,7 @@ def archive_batch(
12071226
List of message IDs that were successfully archived.
12081227
"""
12091228
stmt, params = PGMQOperation._get_archive_batch_statement(queue_name, msg_ids)
1210-
rows = session.execute(text(stmt), params).fetchall()
1229+
rows = session.execute(stmt, params).fetchall()
12111230
if commit:
12121231
session.commit()
12131232
return [row[0] for row in rows]
@@ -1232,7 +1251,7 @@ async def archive_batch_async(
12321251
List of message IDs that were successfully archived.
12331252
"""
12341253
stmt, params = PGMQOperation._get_archive_batch_statement(queue_name, msg_ids)
1235-
rows = (await session.execute(text(stmt), params)).fetchall()
1254+
rows = (await session.execute(stmt, params)).fetchall()
12361255
if commit:
12371256
await session.commit()
12381257
return [row[0] for row in rows]

0 commit comments

Comments
 (0)