Skip to content

Commit ac23547

Browse files
Copilotjason810496
andauthored
Fix JSONB parameter binding to prevent psycopg2 dict adaptation errors (#25)
* Initial plan * Fix JSONB handling with bindparams for send and send_batch operations Co-authored-by: jason810496 <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: jason810496 <[email protected]>
1 parent 698e8ec commit ac23547

1 file changed

Lines changed: 22 additions & 18 deletions

File tree

pgmq_sqlalchemy/operation.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import List, Optional, Tuple, Dict, Any, Union
22
import re
33

4-
from sqlalchemy import text
4+
from sqlalchemy import text, bindparam, ARRAY
5+
from sqlalchemy.dialects.postgresql import JSONB
56
from sqlalchemy.orm import Session
67
from sqlalchemy.ext.asyncio import AsyncSession
78

@@ -114,10 +115,15 @@ def _get_list_queues_statement() -> Tuple[str, Dict[str, Any]]:
114115
@staticmethod
115116
def _get_send_statement(
116117
queue_name: str, message: dict, delay: int
117-
) -> Tuple[str, Dict[str, Any]]:
118+
) -> Tuple[text, Dict[str, Any]]:
118119
"""Get statement and params for send."""
120+
stmt = text(
121+
"select * from pgmq.send(:queue_name, :message, :delay);"
122+
).bindparams(
123+
bindparam("message", type_=JSONB)
124+
)
119125
return (
120-
"select * from pgmq.send(:queue_name, CAST(:message AS jsonb), :delay);",
126+
stmt,
121127
{
122128
"queue_name": queue_name,
123129
"message": message,
@@ -128,21 +134,19 @@ def _get_send_statement(
128134
@staticmethod
129135
def _get_send_batch_statement(
130136
queue_name: str, messages: List[dict], delay: int
131-
) -> Tuple[str, Dict[str, Any]]:
137+
) -> Tuple[text, Dict[str, Any]]:
132138
"""Get statement and params for send_batch.
133139
134-
Note: This uses PostgreSQL array literal format with escaped quotes.
135-
While not ideal, this approach balances SQL injection protection with
136-
cross-driver compatibility. The escaping is safe as long as:
137-
1. Input is a List[dict] (enforced by type hints)
138-
2. json.dumps produces valid JSON (guaranteed for dict inputs)
139-
3. Users do not pass pre-serialized JSON strings as dict values
140-
141-
A more robust solution would use SQLAlchemy's array types or driver-specific
142-
array adaptation, but that would sacrifice cross-driver compatibility.
140+
Note: This uses SQLAlchemy's bindparam with JSONB array type for proper
141+
cross-driver compatibility and type adaptation.
143142
"""
143+
stmt = text(
144+
"select * from pgmq.send_batch(:queue_name, :messages, :delay);"
145+
).bindparams(
146+
bindparam("messages", type_=ARRAY(JSONB))
147+
)
144148
return (
145-
"select * from pgmq.send_batch(:queue_name, CAST(:messages AS jsonb[]), :delay);",
149+
stmt,
146150
{
147151
"queue_name": queue_name,
148152
"messages": messages,
@@ -584,7 +588,7 @@ def send(
584588
The message ID.
585589
"""
586590
stmt, params = PGMQOperation._get_send_statement(queue_name, message, delay)
587-
row = session.execute(text(stmt), params).fetchone()
591+
row = session.execute(stmt, params).fetchone()
588592
if commit:
589593
session.commit()
590594
return row[0]
@@ -611,7 +615,7 @@ async def send_async(
611615
The message ID.
612616
"""
613617
stmt, params = PGMQOperation._get_send_statement(queue_name, message, delay)
614-
row = (await session.execute(text(stmt), params)).fetchone()
618+
row = (await session.execute(stmt, params)).fetchone()
615619
if commit:
616620
await session.commit()
617621
return row[0]
@@ -640,7 +644,7 @@ def send_batch(
640644
stmt, params = PGMQOperation._get_send_batch_statement(
641645
queue_name, messages, delay
642646
)
643-
rows = session.execute(text(stmt), params).fetchall()
647+
rows = session.execute(stmt, params).fetchall()
644648
if commit:
645649
session.commit()
646650
return [row[0] for row in rows]
@@ -669,7 +673,7 @@ async def send_batch_async(
669673
stmt, params = PGMQOperation._get_send_batch_statement(
670674
queue_name, messages, delay
671675
)
672-
rows = (await session.execute(text(stmt), params)).fetchall()
676+
rows = (await session.execute(stmt, params)).fetchall()
673677
if commit:
674678
await session.commit()
675679
return [row[0] for row in rows]

0 commit comments

Comments
 (0)