1- from typing import List , Optional , Tuple , Dict , Any , Union
1+ from typing import List , Optional , Tuple , Dict , Any , Union , TYPE_CHECKING
22import re
33
4- from sqlalchemy import text , bindparam , ARRAY
5- from sqlalchemy .dialects .postgresql import JSONB
4+ from sqlalchemy import text , bindparam
5+ from sqlalchemy .dialects .postgresql import JSONB , ARRAY , BIGINT
66from sqlalchemy .orm import Session
77from sqlalchemy .ext .asyncio import AsyncSession
88
99from .schema import Message , QueueMetrics
1010
11+ if TYPE_CHECKING :
12+ from sqlalchemy import TextClause
13+
1114
1215class PGMQOperation :
1316 """
@@ -124,7 +127,7 @@ def _get_list_queues_statement() -> Tuple[str, Dict[str, Any]]:
124127 @staticmethod
125128 def _get_send_statement (
126129 queue_name : str , message : dict , delay : int
127- ) -> Tuple [text , Dict [str , Any ]]:
130+ ) -> Tuple [TextClause , Dict [str , Any ]]:
128131 """Get statement and params for send."""
129132 stmt = text (
130133 "select * from pgmq.send(:queue_name, :message, :delay);"
@@ -141,7 +144,7 @@ def _get_send_statement(
141144 @staticmethod
142145 def _get_send_batch_statement (
143146 queue_name : str , messages : List [dict ], delay : int
144- ) -> Tuple [text , Dict [str , Any ]]:
147+ ) -> Tuple [TextClause , Dict [str , Any ]]:
145148 """Get statement and params for send_batch.
146149
147150 Note: This uses SQLAlchemy's bindparam with JSONB array type for proper
@@ -150,6 +153,7 @@ def _get_send_batch_statement(
150153 stmt = text (
151154 "select * from pgmq.send_batch(:queue_name, :messages, :delay);"
152155 ).bindparams (bindparam ("messages" , type_ = ARRAY (JSONB )))
156+
153157 return (
154158 stmt ,
155159 {
@@ -211,20 +215,28 @@ def _get_pop_statement(queue_name: str) -> Tuple[str, Dict[str, Any]]:
211215 @staticmethod
212216 def _get_delete_statement (
213217 queue_name : str , msg_id : int
214- ) -> Tuple [str , Dict [str , Any ]]:
218+ ) -> Tuple [TextClause , Dict [str , Any ]]:
215219 """Get statement and params for delete."""
216- return "select pgmq.delete(:queue_name, :msg_id) as deleted;" , {
220+ stmt = text ("select pgmq.delete(:queue_name, :msg_id) as deleted;" ).bindparams (
221+ bindparam ("msg_id" , type_ = BIGINT )
222+ )
223+
224+ return stmt , {
217225 "queue_name" : queue_name ,
218226 "msg_id" : msg_id ,
219227 }
220228
221229 @staticmethod
222230 def _get_delete_batch_statement (
223231 queue_name : str , msg_ids : List [int ]
224- ) -> Tuple [str , Dict [str , Any ]]:
232+ ) -> Tuple [TextClause , Dict [str , Any ]]:
225233 """Get statement and params for delete_batch."""
234+ stmt = text (
235+ "select * from pgmq.delete_batch(:queue_name, :msg_ids);"
236+ ).bindparams (bindparam ("msg_ids" , type_ = ARRAY (BIGINT )))
237+
226238 return (
227- "select t.msg_id from unnest(CAST(:msg_ids AS bigint[])) as t(msg_id) where pgmq.delete(:queue_name, t.msg_id);" ,
239+ stmt ,
228240 {"queue_name" : queue_name , "msg_ids" : msg_ids },
229241 )
230242
@@ -1045,7 +1057,7 @@ def delete(
10451057 True if the message was deleted successfully.
10461058 """
10471059 stmt , params = PGMQOperation ._get_delete_statement (queue_name , msg_id )
1048- row = session .execute (text ( stmt ) , params ).fetchone ()
1060+ row = session .execute (stmt , params ).fetchone ()
10491061 if commit :
10501062 session .commit ()
10511063 return row [0 ]
@@ -1070,7 +1082,7 @@ async def delete_async(
10701082 True if the message was deleted successfully.
10711083 """
10721084 stmt , params = PGMQOperation ._get_delete_statement (queue_name , msg_id )
1073- row = (await session .execute (text ( stmt ) , params )).fetchone ()
1085+ row = (await session .execute (stmt , params )).fetchone ()
10741086 if commit :
10751087 await session .commit ()
10761088 return row [0 ]
@@ -1095,7 +1107,7 @@ def delete_batch(
10951107 List of message IDs that were successfully deleted.
10961108 """
10971109 stmt , params = PGMQOperation ._get_delete_batch_statement (queue_name , msg_ids )
1098- rows = session .execute (text ( stmt ) , params ).fetchall ()
1110+ rows = session .execute (stmt , params ).fetchall ()
10991111 if commit :
11001112 session .commit ()
11011113 return [row [0 ] for row in rows ]
@@ -1120,7 +1132,7 @@ async def delete_batch_async(
11201132 List of message IDs that were successfully deleted.
11211133 """
11221134 stmt , params = PGMQOperation ._get_delete_batch_statement (queue_name , msg_ids )
1123- rows = (await session .execute (text ( stmt ) , params )).fetchall ()
1135+ rows = (await session .execute (stmt , params )).fetchall ()
11241136 if commit :
11251137 await session .commit ()
11261138 return [row [0 ] for row in rows ]
0 commit comments