Skip to content

Commit df20827

Browse files
committed
Fix ambiguous error manually
1 parent 44eb58d commit df20827

1 file changed

Lines changed: 25 additions & 13 deletions

File tree

pgmq_sqlalchemy/operation.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from typing import List, Optional, Tuple, Dict, Any, Union
1+
from typing import List, Optional, Tuple, Dict, Any, Union, TYPE_CHECKING
22
import re
33

4-
from sqlalchemy import text, bindparam, ARRAY
5-
from sqlalchemy.dialects.postgresql import JSONB
4+
from sqlalchemy import text, bindparam
5+
from sqlalchemy.dialects.postgresql import JSONB, ARRAY, BIGINT
66
from sqlalchemy.orm import Session
77
from sqlalchemy.ext.asyncio import AsyncSession
88

99
from .schema import Message, QueueMetrics
1010

11+
if TYPE_CHECKING:
12+
from sqlalchemy import TextClause
13+
1114

1215
class PGMQOperation:
1316
"""
@@ -124,7 +127,7 @@ def _get_list_queues_statement() -> Tuple[str, Dict[str, Any]]:
124127
@staticmethod
125128
def _get_send_statement(
126129
queue_name: str, message: dict, delay: int
127-
) -> Tuple[text, Dict[str, Any]]:
130+
) -> Tuple[TextClause, Dict[str, Any]]:
128131
"""Get statement and params for send."""
129132
stmt = text(
130133
"select * from pgmq.send(:queue_name, :message, :delay);"
@@ -141,7 +144,7 @@ def _get_send_statement(
141144
@staticmethod
142145
def _get_send_batch_statement(
143146
queue_name: str, messages: List[dict], delay: int
144-
) -> Tuple[text, Dict[str, Any]]:
147+
) -> Tuple[TextClause, Dict[str, Any]]:
145148
"""Get statement and params for send_batch.
146149
147150
Note: This uses SQLAlchemy's bindparam with JSONB array type for proper
@@ -150,6 +153,7 @@ def _get_send_batch_statement(
150153
stmt = text(
151154
"select * from pgmq.send_batch(:queue_name, :messages, :delay);"
152155
).bindparams(bindparam("messages", type_=ARRAY(JSONB)))
156+
153157
return (
154158
stmt,
155159
{
@@ -211,20 +215,28 @@ def _get_pop_statement(queue_name: str) -> Tuple[str, Dict[str, Any]]:
211215
@staticmethod
212216
def _get_delete_statement(
213217
queue_name: str, msg_id: int
214-
) -> Tuple[str, Dict[str, Any]]:
218+
) -> Tuple[TextClause, Dict[str, Any]]:
215219
"""Get statement and params for delete."""
216-
return "select pgmq.delete(:queue_name, :msg_id) as deleted;", {
220+
stmt = text("select pgmq.delete(:queue_name, :msg_id) as deleted;").bindparams(
221+
bindparam("msg_id", type_=BIGINT)
222+
)
223+
224+
return stmt, {
217225
"queue_name": queue_name,
218226
"msg_id": msg_id,
219227
}
220228

221229
@staticmethod
222230
def _get_delete_batch_statement(
223231
queue_name: str, msg_ids: List[int]
224-
) -> Tuple[str, Dict[str, Any]]:
232+
) -> Tuple[TextClause, Dict[str, Any]]:
225233
"""Get statement and params for delete_batch."""
234+
stmt = text(
235+
"select * from pgmq.delete_batch(:queue_name, :msg_ids);"
236+
).bindparams(bindparam("msg_ids", type_=ARRAY(BIGINT)))
237+
226238
return (
227-
"select t.msg_id from unnest(CAST(:msg_ids AS bigint[])) as t(msg_id) where pgmq.delete(:queue_name, t.msg_id);",
239+
stmt,
228240
{"queue_name": queue_name, "msg_ids": msg_ids},
229241
)
230242

@@ -1045,7 +1057,7 @@ def delete(
10451057
True if the message was deleted successfully.
10461058
"""
10471059
stmt, params = PGMQOperation._get_delete_statement(queue_name, msg_id)
1048-
row = session.execute(text(stmt), params).fetchone()
1060+
row = session.execute(stmt, params).fetchone()
10491061
if commit:
10501062
session.commit()
10511063
return row[0]
@@ -1070,7 +1082,7 @@ async def delete_async(
10701082
True if the message was deleted successfully.
10711083
"""
10721084
stmt, params = PGMQOperation._get_delete_statement(queue_name, msg_id)
1073-
row = (await session.execute(text(stmt), params)).fetchone()
1085+
row = (await session.execute(stmt, params)).fetchone()
10741086
if commit:
10751087
await session.commit()
10761088
return row[0]
@@ -1095,7 +1107,7 @@ def delete_batch(
10951107
List of message IDs that were successfully deleted.
10961108
"""
10971109
stmt, params = PGMQOperation._get_delete_batch_statement(queue_name, msg_ids)
1098-
rows = session.execute(text(stmt), params).fetchall()
1110+
rows = session.execute(stmt, params).fetchall()
10991111
if commit:
11001112
session.commit()
11011113
return [row[0] for row in rows]
@@ -1120,7 +1132,7 @@ async def delete_batch_async(
11201132
List of message IDs that were successfully deleted.
11211133
"""
11221134
stmt, params = PGMQOperation._get_delete_batch_statement(queue_name, msg_ids)
1123-
rows = (await session.execute(text(stmt), params)).fetchall()
1135+
rows = (await session.execute(stmt, params)).fetchall()
11241136
if commit:
11251137
await session.commit()
11261138
return [row[0] for row in rows]

0 commit comments

Comments
 (0)