Skip to content

Commit f0f7fb6

Browse files
committed
Add execution progress handler following sqlite3 set_progress_handler pattern
Implement connection-level progress reporting as a PEP 249 vendor extension: - ProgressInfo NamedTuple and ProgressHandler type alias in connection.py - set_progress_handler(handler) method on Connection - Automatically sends enable_progress_events: true when handler is set - Handles execution_progress WebSocket events independently of the query state machine, with exception isolation - Export ProgressInfo from wherobots.db - Add --progress flag to tests/smoke.py for manual testing with rich output
1 parent a5fd480 commit f0f7fb6

4 files changed

Lines changed: 83 additions & 3 deletions

File tree

tests/smoke.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from rich.table import Table
1212

1313
from wherobots.db import connect, connect_direct, errors
14+
from wherobots.db.connection import Connection, ProgressInfo
1415
from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE
15-
from wherobots.db.connection import Connection
1616
from wherobots.db.region import Region
1717
from wherobots.db.runtime import Runtime
1818
from wherobots.db.session_type import SessionType
@@ -54,6 +54,11 @@
5454
parser.add_argument(
5555
"--wide", help="Enable wide output", action="store_const", const=80, default=30
5656
)
57+
parser.add_argument(
58+
"--progress",
59+
help="Enable execution progress reporting",
60+
action="store_true",
61+
)
5762
parser.add_argument("sql", nargs="+", help="SQL query to execute")
5863
args = parser.parse_args()
5964

@@ -134,6 +139,26 @@ def execute(conn: Connection, sql: str) -> pandas.DataFrame | StoreResult:
134139

135140
try:
136141
with conn_func() as conn:
142+
if args.progress:
143+
console = Console(stderr=True)
144+
145+
def _on_progress(info: ProgressInfo) -> None:
146+
pct = (
147+
f"{info.tasks_completed / info.tasks_total * 100:.0f}%"
148+
if info.tasks_total
149+
else "?"
150+
)
151+
console.print(
152+
f" [dim]\\[progress][/dim] "
153+
f"[bold]{pct}[/bold] "
154+
f"{info.tasks_completed}/{info.tasks_total} tasks "
155+
f"[dim]({info.tasks_active} active)[/dim] "
156+
f"[dim]{info.execution_id[:8]}[/dim]",
157+
highlight=False,
158+
)
159+
160+
conn.set_progress_handler(_on_progress)
161+
137162
with concurrent.futures.ThreadPoolExecutor() as pool:
138163
futures = [pool.submit(execute, conn, s) for s in args.sql]
139164
for future in concurrent.futures.as_completed(futures):

wherobots/db/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .connection import Connection
1+
from .connection import Connection, ProgressInfo
22
from .cursor import Cursor
33
from .driver import connect, connect_direct
44
from .errors import (
@@ -18,6 +18,7 @@
1818
__all__ = [
1919
"Connection",
2020
"Cursor",
21+
"ProgressInfo",
2122
"connect",
2223
"connect_direct",
2324
"Error",

wherobots/db/connection.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import threading
55
import uuid
66
from dataclasses import dataclass
7-
from typing import Any, Callable, Dict
7+
from typing import Any, Callable, Dict, NamedTuple
88

99
import pandas
1010
import pyarrow
@@ -27,6 +27,22 @@
2727
)
2828

2929

30+
class ProgressInfo(NamedTuple):
31+
"""Progress information for a running query.
32+
33+
Mirrors the ``execution_progress`` event sent by the SQL session.
34+
"""
35+
36+
execution_id: str
37+
tasks_total: int
38+
tasks_completed: int
39+
tasks_active: int
40+
41+
42+
ProgressHandler = Callable[[ProgressInfo], None]
43+
"""A callable invoked with a :class:`ProgressInfo` on every progress event."""
44+
45+
3046
@dataclass
3147
class Query:
3248
sql: str
@@ -64,6 +80,7 @@ def __init__(
6480
self.__results_format = results_format
6581
self.__data_compression = data_compression
6682
self.__geometry_representation = geometry_representation
83+
self.__progress_handler: ProgressHandler | None = None
6784

6885
self.__queries: dict[str, Query] = {}
6986
self.__thread = threading.Thread(
@@ -89,6 +106,21 @@ def rollback(self) -> None:
89106
def cursor(self) -> Cursor:
90107
return Cursor(self.__execute_sql, self.__cancel_query)
91108

109+
def set_progress_handler(self, handler: ProgressHandler | None) -> None:
110+
"""Register a callback invoked for execution progress events.
111+
112+
When a handler is set, every ``execute_sql`` request automatically
113+
includes ``enable_progress_events: true`` so the SQL session streams
114+
progress updates for running queries.
115+
116+
Pass ``None`` to disable progress reporting.
117+
118+
This follows the `sqlite3 Connection.set_progress_handler()
119+
<https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.set_progress_handler>`_
120+
pattern (PEP 249 vendor extension).
121+
"""
122+
self.__progress_handler = handler
123+
92124
def __main_loop(self) -> None:
93125
"""Main background loop listening for messages from the SQL session."""
94126
logging.info("Starting background connection handling loop...")
@@ -116,6 +148,24 @@ def __listen(self) -> None:
116148
# Invalid event.
117149
return
118150

151+
# Progress events are independent of the query state machine and don't
152+
# require a tracked query — the handler is connection-level.
153+
if kind == EventKind.EXECUTION_PROGRESS:
154+
handler = self.__progress_handler
155+
if handler is not None:
156+
try:
157+
handler(
158+
ProgressInfo(
159+
execution_id=execution_id,
160+
tasks_total=message.get("tasks_total", 0),
161+
tasks_completed=message.get("tasks_completed", 0),
162+
tasks_active=message.get("tasks_active", 0),
163+
)
164+
)
165+
except Exception:
166+
logging.exception("Progress handler raised an exception")
167+
return
168+
119169
query = self.__queries.get(execution_id)
120170
if not query:
121171
logging.warning(
@@ -236,6 +286,9 @@ def __execute_sql(
236286
"statement": sql,
237287
}
238288

289+
if self.__progress_handler is not None:
290+
request["enable_progress_events"] = True
291+
239292
if store:
240293
request["store"] = {
241294
"format": store.format.value,

wherobots/db/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class EventKind(LowercaseStrEnum):
4545
STATE_UPDATED = auto()
4646
EXECUTION_RESULT = auto()
4747
ERROR = auto()
48+
EXECUTION_PROGRESS = auto()
4849

4950

5051
class ResultsFormat(LowercaseStrEnum):

0 commit comments

Comments
 (0)