Skip to content

Commit 849a8f1

Browse files
Copilotjason810496
andauthored
Address code review feedback: fix SQL injection, QueueMetrics bug, reduce boilerplate (#22)
* Initial plan * Fix SQL injection vulnerability and QueueMetrics bug in operation.py Co-authored-by: jason810496 <[email protected]> * Refactor queue.py to reduce boilerplate with _execute_operation helper Co-authored-by: jason810496 <[email protected]> * Add comprehensive tests for PGMQOperation class Co-authored-by: jason810496 <[email protected]> * Fix json import and send_batch parameter handling Co-authored-by: jason810496 <[email protected]> * Fix test exception handling to be more specific Co-authored-by: jason810496 <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: jason810496 <[email protected]>
1 parent f01f52c commit 849a8f1

3 files changed

Lines changed: 508 additions & 190 deletions

File tree

pgmq_sqlalchemy/operation.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from typing import List, Optional, Tuple, Dict, Any, Union
22
import re
3+
import json
34

45
from sqlalchemy import text
56
from sqlalchemy.orm import Session
67
from sqlalchemy.ext.asyncio import AsyncSession
78

89
from .schema import Message, QueueMetrics
9-
from ._utils import encode_dict_to_psql, encode_list_to_psql
1010

1111

1212
class PGMQOperation:
@@ -106,19 +106,31 @@ def _get_list_queues_statement() -> Tuple[str, Dict[str, Any]]:
106106
return "select queue_name from pgmq.list_queues();", {}
107107

108108
@staticmethod
109-
def _get_send_statement(queue_name: str, message: dict, delay: int) -> str:
110-
"""Get statement for send (no params, using f-string)."""
111-
encoded_message = encode_dict_to_psql(message)
112-
return f"select * from pgmq.send('{queue_name}',{encoded_message},{delay});"
109+
def _get_send_statement(
110+
queue_name: str, message: dict, delay: int
111+
) -> Tuple[str, Dict[str, Any]]:
112+
"""Get statement and params for send."""
113+
return (
114+
"select * from pgmq.send(:queue_name, :message::jsonb, :delay);",
115+
{
116+
"queue_name": queue_name,
117+
"message": json.dumps(message),
118+
"delay": delay,
119+
},
120+
)
113121

114122
@staticmethod
115123
def _get_send_batch_statement(
116124
queue_name: str, messages: List[dict], delay: int
117-
) -> str:
118-
"""Get statement for send_batch (no params, using f-string)."""
119-
encoded_messages = encode_list_to_psql(messages)
125+
) -> Tuple[str, Dict[str, Any]]:
126+
"""Get statement and params for send_batch."""
120127
return (
121-
f"select * from pgmq.send_batch('{queue_name}',{encoded_messages},{delay});"
128+
"select * from pgmq.send_batch(:queue_name, :messages::jsonb, :delay);",
129+
{
130+
"queue_name": queue_name,
131+
"messages": json.dumps(messages),
132+
"delay": delay,
133+
},
122134
)
123135

124136
@staticmethod
@@ -546,8 +558,8 @@ def send(
546558
Returns:
547559
The message ID.
548560
"""
549-
stmt = PGMQOperation._get_send_statement(queue_name, message, delay)
550-
row = session.execute(text(stmt)).fetchone()
561+
stmt, params = PGMQOperation._get_send_statement(queue_name, message, delay)
562+
row = session.execute(text(stmt), params).fetchone()
551563
if commit:
552564
session.commit()
553565
return row[0]
@@ -573,8 +585,8 @@ async def send_async(
573585
Returns:
574586
The message ID.
575587
"""
576-
stmt = PGMQOperation._get_send_statement(queue_name, message, delay)
577-
row = (await session.execute(text(stmt))).fetchone()
588+
stmt, params = PGMQOperation._get_send_statement(queue_name, message, delay)
589+
row = (await session.execute(text(stmt), params)).fetchone()
578590
if commit:
579591
await session.commit()
580592
return row[0]
@@ -600,8 +612,8 @@ def send_batch(
600612
Returns:
601613
List of message IDs.
602614
"""
603-
stmt = PGMQOperation._get_send_batch_statement(queue_name, messages, delay)
604-
rows = session.execute(text(stmt)).fetchall()
615+
stmt, params = PGMQOperation._get_send_batch_statement(queue_name, messages, delay)
616+
rows = session.execute(text(stmt), params).fetchall()
605617
if commit:
606618
session.commit()
607619
return [row[0] for row in rows]
@@ -627,8 +639,8 @@ async def send_batch_async(
627639
Returns:
628640
List of message IDs.
629641
"""
630-
stmt = PGMQOperation._get_send_batch_statement(queue_name, messages, delay)
631-
rows = (await session.execute(text(stmt))).fetchall()
642+
stmt, params = PGMQOperation._get_send_batch_statement(queue_name, messages, delay)
643+
rows = (await session.execute(text(stmt), params)).fetchall()
632644
if commit:
633645
await session.commit()
634646
return [row[0] for row in rows]
@@ -1250,7 +1262,6 @@ def metrics(
12501262
newest_msg_age_sec=row[2],
12511263
oldest_msg_age_sec=row[3],
12521264
total_messages=row[4],
1253-
scrape_time=row[5],
12541265
)
12551266

12561267
@staticmethod
@@ -1282,7 +1293,6 @@ async def metrics_async(
12821293
newest_msg_age_sec=row[2],
12831294
oldest_msg_age_sec=row[3],
12841295
total_messages=row[4],
1285-
scrape_time=row[5],
12861296
)
12871297

12881298
@staticmethod
@@ -1313,7 +1323,6 @@ def metrics_all(
13131323
newest_msg_age_sec=row[2],
13141324
oldest_msg_age_sec=row[3],
13151325
total_messages=row[4],
1316-
scrape_time=row[5],
13171326
)
13181327
for row in rows
13191328
]
@@ -1346,7 +1355,6 @@ async def metrics_all_async(
13461355
newest_msg_age_sec=row[2],
13471356
oldest_msg_age_sec=row[3],
13481357
total_messages=row[4],
1349-
scrape_time=row[5],
13501358
)
13511359
for row in rows
13521360
]

0 commit comments

Comments
 (0)