Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions pgmq_sqlalchemy/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def _validate_partition_interval(interval: Union[int, str]) -> str:
raise ValueError("Numeric partition interval must be positive")
return str(interval)

# Check if it's a numeric string
if interval.strip().isdigit():
numeric_value = int(interval.strip())
if numeric_value <= 0:
raise ValueError("Numeric partition interval must be positive")
return str(numeric_value)

# Validate time-based interval format
# Valid PostgreSQL interval formats: '1 day', '7 days', '1 hour', '1 month', etc.
time_pattern = r"^\d+\s+(microsecond|millisecond|second|minute|hour|day|week|month|year)s?$"
Expand Down Expand Up @@ -111,7 +118,7 @@ def _get_send_statement(
) -> Tuple[str, Dict[str, Any]]:
"""Get statement and params for send."""
return (
"select * from pgmq.send(:queue_name, :message::jsonb, :delay);",
"select * from pgmq.send(:queue_name, CAST(:message AS jsonb), :delay);",
{
"queue_name": queue_name,
"message": json.dumps(message),
Expand All @@ -123,12 +130,27 @@ def _get_send_statement(
def _get_send_batch_statement(
queue_name: str, messages: List[dict], delay: int
) -> Tuple[str, Dict[str, Any]]:
"""Get statement and params for send_batch."""
"""Get statement and params for send_batch.

Note: This uses PostgreSQL array literal format with escaped quotes.
While not ideal, this approach balances SQL injection protection with
cross-driver compatibility. The escaping is safe as long as:
1. Input is a List[dict] (enforced by type hints)
2. json.dumps produces valid JSON (guaranteed for dict inputs)
3. Users do not pass pre-serialized JSON strings as dict values

A more robust solution would use SQLAlchemy's array types or driver-specific
array adaptation, but that would sacrifice cross-driver compatibility.
"""
# Convert list of dicts to array of jsonb strings
# Need to escape quotes for PostgreSQL array literal format
jsonb_strings = [json.dumps(msg).replace('"', '\\"') for msg in messages]
array_literal = "{" + ",".join(f'"{js}"' for js in jsonb_strings) + "}"
return (
"select * from pgmq.send_batch(:queue_name, :messages::jsonb, :delay);",
"select * from pgmq.send_batch(:queue_name, CAST(:messages AS jsonb[]), :delay);",
{
"queue_name": queue_name,
"messages": json.dumps(messages),
"messages": array_literal,
"delay": delay,
},
)
Expand Down Expand Up @@ -198,7 +220,7 @@ def _get_delete_batch_statement(
) -> Tuple[str, Dict[str, Any]]:
"""Get statement and params for delete_batch."""
return (
"select pgmq.delete(:queue_name, msg_id) from unnest(:msg_ids::bigint[]) as msg_id;",
"select msg_id from unnest(CAST(:msg_ids AS bigint[])) as msg_id where pgmq.delete(:queue_name, msg_id);",
{"queue_name": queue_name, "msg_ids": msg_ids},
)

Expand All @@ -218,7 +240,7 @@ def _get_archive_batch_statement(
) -> Tuple[str, Dict[str, Any]]:
"""Get statement and params for archive_batch."""
return (
"select pgmq.archive(:queue_name, msg_id) from unnest(:msg_ids::bigint[]) as msg_id;",
"select msg_id from unnest(CAST(:msg_ids AS bigint[])) as msg_id where pgmq.archive(:queue_name, msg_id);",
{"queue_name": queue_name, "msg_ids": msg_ids},
)

Expand Down Expand Up @@ -367,6 +389,14 @@ def create_partitioned_queue(
session: SQLAlchemy session.
commit: Whether to commit the transaction.
"""
# Validate partition intervals
partition_interval = PGMQOperation._validate_partition_interval(
partition_interval
)
retention_interval = PGMQOperation._validate_partition_interval(
retention_interval
)

stmt, params = PGMQOperation._get_create_partitioned_queue_statement(
queue_name, partition_interval, retention_interval
)
Expand Down
Loading