Skip to content

Commit 8d3e6ab

Browse files
committed
feat: add cancel_event param and HTTP timeouts to connect()
Make connect() and connect_direct() interruptible from another thread via a threading.Event. Add 30s timeouts to all HTTP requests and the WebSocket handshake to prevent indefinite hangs. This enables callers (e.g. MCP server) to cleanly abort connection attempts when clients disconnect, preventing zombie thread accumulation and thread pool exhaustion.
1 parent 8630494 commit 8d3e6ab

7 files changed

Lines changed: 226 additions & 2 deletions

File tree

.github/workflows/test.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ jobs:
3838
cd tests
3939
python -c "import wherobots.db"
4040
41+
- name: Run tests
42+
run: pytest tests/
43+
4144
- name: Check build
4245
run: |
4346
uv build

CONTRIBUTING.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,50 @@
33
This project uses `uv`. Run `uv sync` after checking out the repository
44
to initialize your virtualenv with the project's dependencies.
55

6+
## Running tests
7+
8+
Unit tests live in `tests/` and run with pytest:
9+
10+
```bash
11+
uv run pytest tests/
12+
```
13+
14+
The `scripts/` directory contains integration scripts that require a live
15+
Wherobots environment and are not part of the automated test suite.
16+
17+
### Smoke test
18+
19+
`scripts/smoke.py` runs queries against a live Wherobots SQL session.
20+
It requires an API key (or token) and supports most `connect()` options
21+
via CLI flags.
22+
23+
```bash
24+
# Basic query with an API key
25+
uv run python scripts/smoke.py \
26+
--api-key-file ~/.wherobots/api-key \
27+
"SELECT 1"
28+
29+
# Specify runtime, region, and version
30+
uv run python scripts/smoke.py \
31+
--api-key-file ~/.wherobots/api-key \
32+
--runtime tiny --region aws-us-west-2 --version latest \
33+
"SELECT ST_AsText(ST_Point(1, 2))"
34+
35+
# Connect directly to an existing session via WebSocket URL
36+
uv run python scripts/smoke.py \
37+
--api-key-file ~/.wherobots/api-key \
38+
--ws-url wss://compute.example.com/sql/org/session-id \
39+
"SHOW TABLES"
40+
41+
# Enable debug logging and execution progress
42+
uv run python scripts/smoke.py \
43+
--api-key-file ~/.wherobots/api-key \
44+
--debug --progress \
45+
"SELECT * FROM wherobots_open_data.overture.places LIMIT 10"
46+
```
47+
48+
Run `uv run python scripts/smoke.py --help` for all available options.
49+
650
## Publish package to PyPI
751

852
When we are ready to release a new version `vx.y.z`, one of the maintainers should:

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,9 @@ users may find useful:
261261
your expected time between queries and effectively get a continuously
262262
running SQL session runtime without any complex connection management
263263
in your application.
264+
* `cancel_event`: a `threading.Event` that, when set, causes the
265+
connection attempt to abort promptly with an `InterfaceError`. This
266+
is useful when `connect()` is running in a background thread and the
267+
caller needs to interrupt it (e.g. on client disconnect or timeout).
268+
The event is checked before each HTTP request, between retry
269+
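attempts, and before the WebSocket handshake; a request or handshake that is already in flight is not interrupted and runs until it completes or times out.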

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "wherobots-python-dbapi"
3-
version = "0.26.1"
3+
version = "0.27.0"
44
description = "Python DB-API driver for Wherobots DB"
55
authors = [{ name = "Maxime Petazzoni", email = "[email protected]" }]
66
requires-python = ">=3.10, <4"
File renamed without changes.

tests/test_driver.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Tests for the connect() and connect_direct() driver functions."""
2+
3+
import threading
4+
import time
5+
from unittest.mock import MagicMock, patch
6+
7+
import pytest
8+
import requests
9+
10+
from wherobots.db.driver import (
11+
DEFAULT_HTTP_TIMEOUT,
12+
_check_cancelled,
13+
connect,
14+
connect_direct,
15+
)
16+
from wherobots.db.errors import InterfaceError
17+
18+
19+
class TestCheckCancelled:
    """Unit tests for the ``_check_cancelled`` guard helper."""

    def test_none_event_is_noop(self):
        """With no event supplied, the guard returns without raising."""
        _check_cancelled(None)

    def test_unset_event_is_noop(self):
        """An event that has not been set does not trigger cancellation."""
        _check_cancelled(threading.Event())

    def test_set_event_raises(self):
        """A set event surfaces as InterfaceError carrying the cancel message."""
        cancelled = threading.Event()
        cancelled.set()
        with pytest.raises(InterfaceError, match="cancelled by caller"):
            _check_cancelled(cancelled)
32+
33+
34+
class TestConnectCancelEvent:
    """Tests for ``connect()``: cancel_event handling and HTTP request timeouts."""

    @staticmethod
    def _session_created_response():
        """Build a mocked successful POST /sql/session response."""
        resp = MagicMock()
        resp.status_code = 200
        resp.url = "https://api.example.com/sql/session/test-id"
        resp.raise_for_status = MagicMock()
        return resp

    @staticmethod
    def _unauthorized_response():
        """Build a mocked 401 response whose body is not valid JSON."""
        resp = MagicMock()
        resp.status_code = 401
        resp.raise_for_status.side_effect = requests.HTTPError(response=resp)
        resp.json.side_effect = requests.JSONDecodeError("", "", 0)
        return resp

    @staticmethod
    def _session_status_response(payload):
        """Build a mocked successful session-status GET response returning *payload*."""
        resp = MagicMock()
        resp.status_code = 200
        resp.raise_for_status = MagicMock()
        resp.json.return_value = payload
        return resp

    @patch("wherobots.db.driver.requests.post")
    def test_cancel_before_post(self, mock_post):
        """cancel_event set before connect() should raise immediately without making HTTP calls."""
        already_cancelled = threading.Event()
        already_cancelled.set()

        with pytest.raises(InterfaceError, match="cancelled by caller"):
            connect(api_key="test-key", cancel_event=already_cancelled)

        mock_post.assert_not_called()

    @patch("wherobots.db.driver.requests.get")
    @patch("wherobots.db.driver.requests.post")
    def test_cancel_during_polling(self, mock_post, mock_get):
        """cancel_event set during session polling should abort the retry loop."""
        # The session is created, but its status never leaves INITIALIZING,
        # so connect() keeps polling until the cancel event fires.
        mock_post.return_value = self._session_created_response()
        mock_get.return_value = self._session_status_response(
            {"status": "INITIALIZING"}
        )

        cancel = threading.Event()

        def cancel_after_delay():
            # Fire the event shortly after polling has started.
            time.sleep(0.1)
            cancel.set()

        canceller = threading.Thread(target=cancel_after_delay)
        canceller.start()

        with pytest.raises(InterfaceError, match="cancelled by caller"):
            connect(api_key="test-key", cancel_event=cancel, wait_timeout=10)

        canceller.join()

    @patch("wherobots.db.driver.requests.post")
    def test_http_timeout_on_post(self, mock_post):
        """requests.post should be called with a timeout."""
        mock_post.return_value = self._unauthorized_response()

        with pytest.raises(InterfaceError, match="Failed to create SQL session"):
            connect(api_key="test-key")

        assert mock_post.call_args.kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT

    @patch("wherobots.db.driver.requests.get")
    @patch("wherobots.db.driver.requests.post")
    def test_http_timeout_on_get(self, mock_post, mock_get):
        """requests.get in the polling loop should be called with a timeout."""
        mock_post.return_value = self._session_created_response()
        mock_get.return_value = self._session_status_response(
            {
                "status": "READY",
                "appMeta": {"url": "https://compute.example.com/sql/org/session-id"},
            }
        )

        # Stub out connect_direct so no real WebSocket connection is attempted.
        with patch("wherobots.db.driver.connect_direct", return_value=MagicMock()):
            connect(api_key="test-key")

        assert mock_get.call_args.kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT

    @patch("wherobots.db.driver.requests.post")
    def test_connect_without_cancel_event(self, mock_post):
        """connect() without cancel_event should work as before (backward compat)."""
        mock_post.return_value = self._unauthorized_response()

        with pytest.raises(InterfaceError):
            connect(api_key="test-key")
132+
133+
134+
class TestConnectDirectCancelEvent:
    """Tests for cancel_event handling in ``connect_direct()``."""

    @patch("wherobots.db.driver.websockets.sync.client.connect")
    def test_cancel_before_ws_connect(self, mock_ws):
        """A cancel_event that is already set aborts before any WebSocket connect call."""
        already_cancelled = threading.Event()
        already_cancelled.set()
        session_uri = "wss://compute.example.com/sql/org/session-id"

        with pytest.raises(InterfaceError, match="cancelled by caller"):
            connect_direct(uri=session_uri, cancel_event=already_cancelled)

        mock_ws.assert_not_called()

wherobots/db/driver.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import platform
1212
import requests
1313
import tenacity
14+
import threading
1415
from typing import Final, Union, Dict
1516
import urllib.parse
1617
import websockets.exceptions
@@ -51,6 +52,9 @@
5152
# This follows the industry-standard set used by urllib3.util.Retry's status_forcelist.
5253
TRANSIENT_HTTP_STATUS_CODES = {429, 502, 503, 504}
5354

55+
# Default timeout, in seconds, for individual HTTP requests; `requests` applies it
# to the connect and read phases separately. Also used as the WebSocket open timeout.
56+
DEFAULT_HTTP_TIMEOUT = 30
57+
5458

5559
def gen_user_agent_header():
5660
try:
@@ -79,6 +83,7 @@ def connect(
7983
results_format: Union[ResultsFormat, None] = None,
8084
data_compression: Union[DataCompression, None] = None,
8185
geometry_representation: Union[GeometryRepresentation, None] = None,
86+
cancel_event: Union[threading.Event, None] = None,
8287
) -> Connection:
8388
if not token and not api_key:
8489
raise ValueError("At least one of `token` or `api_key` is required")
@@ -109,6 +114,8 @@ def connect(
109114
if not host.startswith("http:"):
110115
host = f"https://{host}"
111116

117+
_check_cancelled(cancel_event)
118+
112119
try:
113120
resp = requests.post(
114121
url=f"{host}/sql/session",
@@ -120,6 +127,7 @@ def connect(
120127
"sessionType": session_type.value,
121128
},
122129
headers=headers,
130+
timeout=DEFAULT_HTTP_TIMEOUT,
123131
)
124132
resp.raise_for_status()
125133
except requests.HTTPError as e:
@@ -149,10 +157,12 @@ def connect(
149157
)
150158
| tenacity.retry_if_exception_type(tenacity.TryAgain)
151159
),
160+
before_sleep=lambda _: _check_cancelled(cancel_event),
152161
reraise=True,
153162
)
154163
def get_session_uri() -> str:
155-
r = requests.get(session_id_url, headers=headers)
164+
_check_cancelled(cancel_event)
165+
r = requests.get(session_id_url, headers=headers, timeout=DEFAULT_HTTP_TIMEOUT)
156166
r.raise_for_status()
157167
payload = r.json()
158168
status = AppStatus(payload.get("status"))
@@ -169,6 +179,8 @@ def get_session_uri() -> str:
169179
logging.info("Getting SQL session status from %s ...", session_id_url)
170180
session_uri = get_session_uri()
171181
logging.debug("SQL session URI from app status: %s", session_uri)
182+
except InterfaceError:
183+
raise
172184
except Exception as e:
173185
raise InterfaceError("Could not acquire SQL session!", e)
174186

@@ -179,9 +191,16 @@ def get_session_uri() -> str:
179191
results_format=results_format,
180192
data_compression=data_compression,
181193
geometry_representation=geometry_representation,
194+
cancel_event=cancel_event,
182195
)
183196

184197

198+
def _check_cancelled(cancel_event: Union[threading.Event, None]) -> None:
199+
"""Raise InterfaceError if the cancel event is set."""
200+
if cancel_event is not None and cancel_event.is_set():
201+
raise InterfaceError("Connection cancelled by caller")
202+
203+
185204
def http_to_ws(uri: str) -> str:
186205
"""Converts an HTTP URI to a WebSocket URI."""
187206
parsed = urllib.parse.urlparse(uri)
@@ -199,6 +218,7 @@ def connect_direct(
199218
results_format: Union[ResultsFormat, None] = None,
200219
data_compression: Union[DataCompression, None] = None,
201220
geometry_representation: Union[GeometryRepresentation, None] = None,
221+
cancel_event: Union[threading.Event, None] = None,
202222
) -> Connection:
203223
uri_with_protocol = f"{uri}/{protocol}"
204224
ssl_context = ssl.create_default_context()
@@ -215,19 +235,24 @@ def connect_direct(
215235
websockets.exceptions.InvalidHandshake,
216236
)
217237
),
238+
before_sleep=lambda _: _check_cancelled(cancel_event),
218239
reraise=True,
219240
)
220241
def ws_connect() -> websockets.sync.client.ClientConnection:
242+
_check_cancelled(cancel_event)
221243
logging.info("Connecting to SQL session at %s ...", uri_with_protocol)
222244
return websockets.sync.client.connect(
223245
uri=uri_with_protocol,
224246
additional_headers=headers,
225247
max_size=MAX_MESSAGE_SIZE,
248+
open_timeout=DEFAULT_HTTP_TIMEOUT,
226249
ssl=ssl_context,
227250
)
228251

229252
try:
230253
ws = ws_connect()
254+
except InterfaceError:
255+
raise
231256
except Exception as e:
232257
raise InterfaceError("Failed to connect to SQL session!") from e
233258

0 commit comments

Comments
 (0)