Skip to content

Commit bf3410c

Browse files
authored
Merge pull request #66 from wherobots/peter/add-cancel-event-and-http-timeouts
feat: add cancel_event param and HTTP timeouts to connect()
2 parents 8630494 + 8d3e6ab commit bf3410c

7 files changed

Lines changed: 226 additions & 2 deletions

File tree

.github/workflows/test.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ jobs:
3838
cd tests
3939
python -c "import wherobots.db"
4040
41+
- name: Run tests
42+
run: pytest tests/
43+
4144
- name: Check build
4245
run: |
4346
uv build

CONTRIBUTING.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,50 @@
33
This project uses `uv`. Run `uv sync` after checking out the repository
44
to initialize your virtualenv with the project's dependencies.
55

6+
## Running tests
7+
8+
Unit tests live in `tests/` and run with pytest:
9+
10+
```bash
11+
uv run pytest tests/
12+
```
13+
14+
The `scripts/` directory contains integration scripts that require a live
15+
Wherobots environment and are not part of the automated test suite.
16+
17+
### Smoke test
18+
19+
`scripts/smoke.py` runs queries against a live Wherobots SQL session.
20+
It requires an API key (or token) and supports most `connect()` options
21+
via CLI flags.
22+
23+
```bash
24+
# Basic query with an API key
25+
uv run python scripts/smoke.py \
26+
--api-key-file ~/.wherobots/api-key \
27+
"SELECT 1"
28+
29+
# Specify runtime, region, and version
30+
uv run python scripts/smoke.py \
31+
--api-key-file ~/.wherobots/api-key \
32+
--runtime tiny --region aws-us-west-2 --version latest \
33+
"SELECT ST_AsText(ST_Point(1, 2))"
34+
35+
# Connect directly to an existing session via WebSocket URL
36+
uv run python scripts/smoke.py \
37+
--api-key-file ~/.wherobots/api-key \
38+
--ws-url wss://compute.example.com/sql/org/session-id \
39+
"SHOW TABLES"
40+
41+
# Enable debug logging and execution progress
42+
uv run python scripts/smoke.py \
43+
--api-key-file ~/.wherobots/api-key \
44+
--debug --progress \
45+
"SELECT * FROM wherobots_open_data.overture.places LIMIT 10"
46+
```
47+
48+
Run `uv run python scripts/smoke.py --help` for all available options.
49+
650
## Publish package to PyPI
751

852
When we are ready to release a new version `vx.y.z`, one of the maintainers should:

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,9 @@ users may find useful:
261261
your expected time between queries and effectively get a continuously
262262
running SQL session runtime without any complex connection management
263263
in your application.
264+
* `cancel_event`: a `threading.Event` that, when set, causes the
265+
connection attempt to abort promptly with an `InterfaceError`. This
266+
is useful when `connect()` is running in a background thread and the
267+
caller needs to interrupt it (e.g. on client disconnect or timeout).
268+
The event is checked before each HTTP request, between retry
269+
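attempts, and before the WebSocket handshake. A request or handshake
that is already in flight is not interrupted; it is bounded by the
30-second per-request timeout instead.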

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "wherobots-python-dbapi"
3-
version = "0.26.1"
3+
version = "0.27.0"
44
description = "Python DB-API driver for Wherobots DB"
55
authors = [{ name = "Maxime Petazzoni", email = "[email protected]" }]
66
requires-python = ">=3.10, <4"
File renamed without changes.

tests/test_driver.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Tests for the connect() and connect_direct() driver functions."""
2+
3+
import threading
4+
import time
5+
from unittest.mock import MagicMock, patch
6+
7+
import pytest
8+
import requests
9+
10+
from wherobots.db.driver import (
11+
DEFAULT_HTTP_TIMEOUT,
12+
_check_cancelled,
13+
connect,
14+
connect_direct,
15+
)
16+
from wherobots.db.errors import InterfaceError
17+
18+
19+
class TestCheckCancelled:
20+
def test_none_event_is_noop(self):
21+
_check_cancelled(None)
22+
23+
def test_unset_event_is_noop(self):
24+
event = threading.Event()
25+
_check_cancelled(event)
26+
27+
def test_set_event_raises(self):
28+
event = threading.Event()
29+
event.set()
30+
with pytest.raises(InterfaceError, match="cancelled by caller"):
31+
_check_cancelled(event)
32+
33+
34+
class TestConnectCancelEvent:
35+
@patch("wherobots.db.driver.requests.post")
36+
def test_cancel_before_post(self, mock_post):
37+
"""cancel_event set before connect() should raise immediately without making HTTP calls."""
38+
cancel = threading.Event()
39+
cancel.set()
40+
41+
with pytest.raises(InterfaceError, match="cancelled by caller"):
42+
connect(api_key="test-key", cancel_event=cancel)
43+
44+
mock_post.assert_not_called()
45+
46+
@patch("wherobots.db.driver.requests.get")
47+
@patch("wherobots.db.driver.requests.post")
48+
def test_cancel_during_polling(self, mock_post, mock_get):
49+
"""cancel_event set during session polling should abort the retry loop."""
50+
# POST succeeds with redirect
51+
post_resp = MagicMock()
52+
post_resp.status_code = 200
53+
post_resp.url = "https://api.example.com/sql/session/test-id"
54+
post_resp.raise_for_status = MagicMock()
55+
mock_post.return_value = post_resp
56+
57+
# GET returns INITIALIZING (triggers TryAgain)
58+
get_resp = MagicMock()
59+
get_resp.status_code = 200
60+
get_resp.raise_for_status = MagicMock()
61+
get_resp.json.return_value = {"status": "INITIALIZING"}
62+
mock_get.return_value = get_resp
63+
64+
cancel = threading.Event()
65+
66+
# Set cancel after a short delay (during polling)
67+
def set_cancel():
68+
time.sleep(0.1)
69+
cancel.set()
70+
71+
t = threading.Thread(target=set_cancel)
72+
t.start()
73+
74+
with pytest.raises(InterfaceError, match="cancelled by caller"):
75+
connect(api_key="test-key", cancel_event=cancel, wait_timeout=10)
76+
77+
t.join()
78+
79+
@patch("wherobots.db.driver.requests.post")
80+
def test_http_timeout_on_post(self, mock_post):
81+
"""requests.post should be called with a timeout."""
82+
post_resp = MagicMock()
83+
post_resp.status_code = 401
84+
post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp)
85+
post_resp.json.side_effect = requests.JSONDecodeError("", "", 0)
86+
mock_post.return_value = post_resp
87+
88+
with pytest.raises(InterfaceError, match="Failed to create SQL session"):
89+
connect(api_key="test-key")
90+
91+
_, kwargs = mock_post.call_args
92+
assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT
93+
94+
@patch("wherobots.db.driver.requests.get")
95+
@patch("wherobots.db.driver.requests.post")
96+
def test_http_timeout_on_get(self, mock_post, mock_get):
97+
"""requests.get in the polling loop should be called with a timeout."""
98+
post_resp = MagicMock()
99+
post_resp.status_code = 200
100+
post_resp.url = "https://api.example.com/sql/session/test-id"
101+
post_resp.raise_for_status = MagicMock()
102+
mock_post.return_value = post_resp
103+
104+
get_resp = MagicMock()
105+
get_resp.status_code = 200
106+
get_resp.raise_for_status = MagicMock()
107+
get_resp.json.return_value = {
108+
"status": "READY",
109+
"appMeta": {"url": "https://compute.example.com/sql/org/session-id"},
110+
}
111+
mock_get.return_value = get_resp
112+
113+
# Patch connect_direct to avoid actual WebSocket connection
114+
with patch("wherobots.db.driver.connect_direct") as mock_cd:
115+
mock_cd.return_value = MagicMock()
116+
connect(api_key="test-key")
117+
118+
_, kwargs = mock_get.call_args
119+
assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT
120+
121+
@patch("wherobots.db.driver.requests.post")
122+
def test_connect_without_cancel_event(self, mock_post):
123+
"""connect() without cancel_event should work as before (backward compat)."""
124+
post_resp = MagicMock()
125+
post_resp.status_code = 401
126+
post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp)
127+
post_resp.json.side_effect = requests.JSONDecodeError("", "", 0)
128+
mock_post.return_value = post_resp
129+
130+
with pytest.raises(InterfaceError):
131+
connect(api_key="test-key")
132+
133+
134+
class TestConnectDirectCancelEvent:
135+
@patch("wherobots.db.driver.websockets.sync.client.connect")
136+
def test_cancel_before_ws_connect(self, mock_ws):
137+
cancel = threading.Event()
138+
cancel.set()
139+
140+
with pytest.raises(InterfaceError, match="cancelled by caller"):
141+
connect_direct(
142+
uri="wss://compute.example.com/sql/org/session-id",
143+
cancel_event=cancel,
144+
)
145+
146+
mock_ws.assert_not_called()

wherobots/db/driver.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import platform
1212
import requests
1313
import tenacity
14+
import threading
1415
from typing import Final, Union, Dict
1516
import urllib.parse
1617
import websockets.exceptions
@@ -51,6 +52,9 @@
5152
# This follows the industry-standard set used by urllib3.util.Retry's status_forcelist.
5253
TRANSIENT_HTTP_STATUS_CODES = {429, 502, 503, 504}
5354

55+
# Default timeout, in seconds, for each HTTP request (applied separately to the
# connect and read phases) and for the WebSocket opening handshake.
56+
DEFAULT_HTTP_TIMEOUT = 30
57+
5458

5559
def gen_user_agent_header():
5660
try:
@@ -79,6 +83,7 @@ def connect(
7983
results_format: Union[ResultsFormat, None] = None,
8084
data_compression: Union[DataCompression, None] = None,
8185
geometry_representation: Union[GeometryRepresentation, None] = None,
86+
cancel_event: Union[threading.Event, None] = None,
8287
) -> Connection:
8388
if not token and not api_key:
8489
raise ValueError("At least one of `token` or `api_key` is required")
@@ -109,6 +114,8 @@ def connect(
109114
if not host.startswith("http:"):
110115
host = f"https://{host}"
111116

117+
_check_cancelled(cancel_event)
118+
112119
try:
113120
resp = requests.post(
114121
url=f"{host}/sql/session",
@@ -120,6 +127,7 @@ def connect(
120127
"sessionType": session_type.value,
121128
},
122129
headers=headers,
130+
timeout=DEFAULT_HTTP_TIMEOUT,
123131
)
124132
resp.raise_for_status()
125133
except requests.HTTPError as e:
@@ -149,10 +157,12 @@ def connect(
149157
)
150158
| tenacity.retry_if_exception_type(tenacity.TryAgain)
151159
),
160+
before_sleep=lambda _: _check_cancelled(cancel_event),
152161
reraise=True,
153162
)
154163
def get_session_uri() -> str:
155-
r = requests.get(session_id_url, headers=headers)
164+
_check_cancelled(cancel_event)
165+
r = requests.get(session_id_url, headers=headers, timeout=DEFAULT_HTTP_TIMEOUT)
156166
r.raise_for_status()
157167
payload = r.json()
158168
status = AppStatus(payload.get("status"))
@@ -169,6 +179,8 @@ def get_session_uri() -> str:
169179
logging.info("Getting SQL session status from %s ...", session_id_url)
170180
session_uri = get_session_uri()
171181
logging.debug("SQL session URI from app status: %s", session_uri)
182+
except InterfaceError:
183+
raise
172184
except Exception as e:
173185
raise InterfaceError("Could not acquire SQL session!", e)
174186

@@ -179,9 +191,16 @@ def get_session_uri() -> str:
179191
results_format=results_format,
180192
data_compression=data_compression,
181193
geometry_representation=geometry_representation,
194+
cancel_event=cancel_event,
182195
)
183196

184197

198+
def _check_cancelled(cancel_event: Union[threading.Event, None]) -> None:
199+
"""Raise InterfaceError if the cancel event is set."""
200+
if cancel_event is not None and cancel_event.is_set():
201+
raise InterfaceError("Connection cancelled by caller")
202+
203+
185204
def http_to_ws(uri: str) -> str:
186205
"""Converts an HTTP URI to a WebSocket URI."""
187206
parsed = urllib.parse.urlparse(uri)
@@ -199,6 +218,7 @@ def connect_direct(
199218
results_format: Union[ResultsFormat, None] = None,
200219
data_compression: Union[DataCompression, None] = None,
201220
geometry_representation: Union[GeometryRepresentation, None] = None,
221+
cancel_event: Union[threading.Event, None] = None,
202222
) -> Connection:
203223
uri_with_protocol = f"{uri}/{protocol}"
204224
ssl_context = ssl.create_default_context()
@@ -215,19 +235,24 @@ def connect_direct(
215235
websockets.exceptions.InvalidHandshake,
216236
)
217237
),
238+
before_sleep=lambda _: _check_cancelled(cancel_event),
218239
reraise=True,
219240
)
220241
def ws_connect() -> websockets.sync.client.ClientConnection:
242+
_check_cancelled(cancel_event)
221243
logging.info("Connecting to SQL session at %s ...", uri_with_protocol)
222244
return websockets.sync.client.connect(
223245
uri=uri_with_protocol,
224246
additional_headers=headers,
225247
max_size=MAX_MESSAGE_SIZE,
248+
open_timeout=DEFAULT_HTTP_TIMEOUT,
226249
ssl=ssl_context,
227250
)
228251

229252
try:
230253
ws = ws_connect()
254+
except InterfaceError:
255+
raise
231256
except Exception as e:
232257
raise InterfaceError("Failed to connect to SQL session!") from e
233258

0 commit comments

Comments
 (0)