Skip to content

Commit 990353c

Browse files
committed
Add execution progress handler following sqlite3 set_progress_handler pattern
Implement connection-level progress reporting as a PEP 249 vendor extension: - ProgressInfo NamedTuple and ProgressHandler type alias in connection.py - set_progress_handler(handler) method on Connection - Automatically sends enable_progress_events: true when handler is set - Handles execution_progress WebSocket events independently of the query state machine, with exception isolation - Export ProgressInfo from wherobots.db - Add --progress flag to tests/smoke.py for manual testing with rich output
1 parent a5fd480 commit 990353c

4 files changed

Lines changed: 83 additions & 2 deletions

File tree

tests/smoke.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from rich.table import Table
1212

1313
from wherobots.db import connect, connect_direct, errors
14+
from wherobots.db.connection import Connection, ProgressInfo
1415
from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE
15-
from wherobots.db.connection import Connection
1616
from wherobots.db.region import Region
1717
from wherobots.db.runtime import Runtime
1818
from wherobots.db.session_type import SessionType
@@ -54,6 +54,11 @@
5454
parser.add_argument(
5555
"--wide", help="Enable wide output", action="store_const", const=80, default=30
5656
)
57+
parser.add_argument(
58+
"--progress",
59+
help="Enable execution progress reporting",
60+
action="store_true",
61+
)
5762
parser.add_argument("sql", nargs="+", help="SQL query to execute")
5863
args = parser.parse_args()
5964

@@ -134,6 +139,26 @@ def execute(conn: Connection, sql: str) -> pandas.DataFrame | StoreResult:
134139

135140
try:
136141
with conn_func() as conn:
142+
if args.progress:
143+
console = Console(stderr=True)
144+
145+
def _on_progress(info: ProgressInfo) -> None:
146+
pct = (
147+
f"{info.tasks_completed / info.tasks_total * 100:.0f}%"
148+
if info.tasks_total
149+
else "?"
150+
)
151+
console.print(
152+
f" [dim]\\[progress][/dim] "
153+
f"[bold]{pct}[/bold] "
154+
f"{info.tasks_completed}/{info.tasks_total} tasks "
155+
f"[dim]({info.tasks_active} active)[/dim] "
156+
f"[dim]{info.execution_id[:8]}[/dim]",
157+
highlight=False,
158+
)
159+
160+
conn.set_progress_handler(_on_progress)
161+
137162
with concurrent.futures.ThreadPoolExecutor() as pool:
138163
futures = [pool.submit(execute, conn, s) for s in args.sql]
139164
for future in concurrent.futures.as_completed(futures):

wherobots/db/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .connection import Connection
1+
from .connection import Connection, ProgressInfo
22
from .cursor import Cursor
33
from .driver import connect, connect_direct
44
from .errors import (
@@ -18,6 +18,7 @@
1818
__all__ = [
1919
"Connection",
2020
"Cursor",
21+
"ProgressInfo",
2122
"connect",
2223
"connect_direct",
2324
"Error",

wherobots/db/connection.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,23 @@
2727
)
2828

2929

30+
@dataclass(frozen=True)
31+
class ProgressInfo:
32+
"""Progress information for a running query.
33+
34+
Mirrors the ``execution_progress`` event sent by the SQL session.
35+
"""
36+
37+
execution_id: str
38+
tasks_total: int
39+
tasks_completed: int
40+
tasks_active: int
41+
42+
43+
ProgressHandler = Callable[[ProgressInfo], None]
44+
"""A callable invoked with a :class:`ProgressInfo` on every progress event."""
45+
46+
3047
@dataclass
3148
class Query:
3249
sql: str
@@ -64,6 +81,7 @@ def __init__(
6481
self.__results_format = results_format
6582
self.__data_compression = data_compression
6683
self.__geometry_representation = geometry_representation
84+
self.__progress_handler: ProgressHandler | None = None
6785

6886
self.__queries: dict[str, Query] = {}
6987
self.__thread = threading.Thread(
@@ -89,6 +107,21 @@ def rollback(self) -> None:
89107
def cursor(self) -> Cursor:
90108
return Cursor(self.__execute_sql, self.__cancel_query)
91109

110+
def set_progress_handler(self, handler: ProgressHandler | None) -> None:
111+
"""Register a callback invoked for execution progress events.
112+
113+
When a handler is set, every ``execute_sql`` request automatically
114+
includes ``enable_progress_events: true`` so the SQL session streams
115+
progress updates for running queries.
116+
117+
Pass ``None`` to disable progress reporting.
118+
119+
This follows the `sqlite3 Connection.set_progress_handler()
120+
<https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.set_progress_handler>`_
121+
pattern (PEP 249 vendor extension).
122+
"""
123+
self.__progress_handler = handler
124+
92125
def __main_loop(self) -> None:
93126
"""Main background loop listening for messages from the SQL session."""
94127
logging.info("Starting background connection handling loop...")
@@ -116,6 +149,24 @@ def __listen(self) -> None:
116149
# Invalid event.
117150
return
118151

152+
# Progress events are independent of the query state machine and don't
153+
# require a tracked query — the handler is connection-level.
154+
if kind == EventKind.EXECUTION_PROGRESS:
155+
handler = self.__progress_handler
156+
if handler is not None:
157+
try:
158+
handler(
159+
ProgressInfo(
160+
execution_id=execution_id,
161+
tasks_total=message.get("tasks_total", 0),
162+
tasks_completed=message.get("tasks_completed", 0),
163+
tasks_active=message.get("tasks_active", 0),
164+
)
165+
)
166+
except Exception:
167+
logging.exception("Progress handler raised an exception")
168+
return
169+
119170
query = self.__queries.get(execution_id)
120171
if not query:
121172
logging.warning(
@@ -236,6 +287,9 @@ def __execute_sql(
236287
"statement": sql,
237288
}
238289

290+
if self.__progress_handler is not None:
291+
request["enable_progress_events"] = True
292+
239293
if store:
240294
request["store"] = {
241295
"format": store.format.value,

wherobots/db/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class EventKind(LowercaseStrEnum):
4545
STATE_UPDATED = auto()
4646
EXECUTION_RESULT = auto()
4747
ERROR = auto()
48+
EXECUTION_PROGRESS = auto()
4849

4950

5051
class ResultsFormat(LowercaseStrEnum):

0 commit comments

Comments
 (0)