|
| 1 | +"""Tests for the connect() and connect_direct() driver functions.""" |
| 2 | + |
| 3 | +import threading |
| 4 | +import time |
| 5 | +from unittest.mock import MagicMock, patch |
| 6 | + |
| 7 | +import pytest |
| 8 | +import requests |
| 9 | + |
| 10 | +from wherobots.db.driver import ( |
| 11 | + DEFAULT_HTTP_TIMEOUT, |
| 12 | + _check_cancelled, |
| 13 | + connect, |
| 14 | + connect_direct, |
| 15 | +) |
| 16 | +from wherobots.db.errors import InterfaceError |
| 17 | + |
| 18 | + |
| 19 | +class TestCheckCancelled: |
| 20 | + def test_none_event_is_noop(self): |
| 21 | + _check_cancelled(None) |
| 22 | + |
| 23 | + def test_unset_event_is_noop(self): |
| 24 | + event = threading.Event() |
| 25 | + _check_cancelled(event) |
| 26 | + |
| 27 | + def test_set_event_raises(self): |
| 28 | + event = threading.Event() |
| 29 | + event.set() |
| 30 | + with pytest.raises(InterfaceError, match="cancelled by caller"): |
| 31 | + _check_cancelled(event) |
| 32 | + |
| 33 | + |
| 34 | +class TestConnectCancelEvent: |
| 35 | + @patch("wherobots.db.driver.requests.post") |
| 36 | + def test_cancel_before_post(self, mock_post): |
| 37 | + """cancel_event set before connect() should raise immediately without making HTTP calls.""" |
| 38 | + cancel = threading.Event() |
| 39 | + cancel.set() |
| 40 | + |
| 41 | + with pytest.raises(InterfaceError, match="cancelled by caller"): |
| 42 | + connect(api_key="test-key", cancel_event=cancel) |
| 43 | + |
| 44 | + mock_post.assert_not_called() |
| 45 | + |
| 46 | + @patch("wherobots.db.driver.requests.get") |
| 47 | + @patch("wherobots.db.driver.requests.post") |
| 48 | + def test_cancel_during_polling(self, mock_post, mock_get): |
| 49 | + """cancel_event set during session polling should abort the retry loop.""" |
| 50 | + # POST succeeds with redirect |
| 51 | + post_resp = MagicMock() |
| 52 | + post_resp.status_code = 200 |
| 53 | + post_resp.url = "https://api.example.com/sql/session/test-id" |
| 54 | + post_resp.raise_for_status = MagicMock() |
| 55 | + mock_post.return_value = post_resp |
| 56 | + |
| 57 | + # GET returns INITIALIZING (triggers TryAgain) |
| 58 | + get_resp = MagicMock() |
| 59 | + get_resp.status_code = 200 |
| 60 | + get_resp.raise_for_status = MagicMock() |
| 61 | + get_resp.json.return_value = {"status": "INITIALIZING"} |
| 62 | + mock_get.return_value = get_resp |
| 63 | + |
| 64 | + cancel = threading.Event() |
| 65 | + |
| 66 | + # Set cancel after a short delay (during polling) |
| 67 | + def set_cancel(): |
| 68 | + time.sleep(0.1) |
| 69 | + cancel.set() |
| 70 | + |
| 71 | + t = threading.Thread(target=set_cancel) |
| 72 | + t.start() |
| 73 | + |
| 74 | + with pytest.raises(InterfaceError, match="cancelled by caller"): |
| 75 | + connect(api_key="test-key", cancel_event=cancel, wait_timeout=10) |
| 76 | + |
| 77 | + t.join() |
| 78 | + |
| 79 | + @patch("wherobots.db.driver.requests.post") |
| 80 | + def test_http_timeout_on_post(self, mock_post): |
| 81 | + """requests.post should be called with a timeout.""" |
| 82 | + post_resp = MagicMock() |
| 83 | + post_resp.status_code = 401 |
| 84 | + post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp) |
| 85 | + post_resp.json.side_effect = requests.JSONDecodeError("", "", 0) |
| 86 | + mock_post.return_value = post_resp |
| 87 | + |
| 88 | + with pytest.raises(InterfaceError, match="Failed to create SQL session"): |
| 89 | + connect(api_key="test-key") |
| 90 | + |
| 91 | + _, kwargs = mock_post.call_args |
| 92 | + assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT |
| 93 | + |
| 94 | + @patch("wherobots.db.driver.requests.get") |
| 95 | + @patch("wherobots.db.driver.requests.post") |
| 96 | + def test_http_timeout_on_get(self, mock_post, mock_get): |
| 97 | + """requests.get in the polling loop should be called with a timeout.""" |
| 98 | + post_resp = MagicMock() |
| 99 | + post_resp.status_code = 200 |
| 100 | + post_resp.url = "https://api.example.com/sql/session/test-id" |
| 101 | + post_resp.raise_for_status = MagicMock() |
| 102 | + mock_post.return_value = post_resp |
| 103 | + |
| 104 | + get_resp = MagicMock() |
| 105 | + get_resp.status_code = 200 |
| 106 | + get_resp.raise_for_status = MagicMock() |
| 107 | + get_resp.json.return_value = { |
| 108 | + "status": "READY", |
| 109 | + "appMeta": {"url": "https://compute.example.com/sql/org/session-id"}, |
| 110 | + } |
| 111 | + mock_get.return_value = get_resp |
| 112 | + |
| 113 | + # Patch connect_direct to avoid actual WebSocket connection |
| 114 | + with patch("wherobots.db.driver.connect_direct") as mock_cd: |
| 115 | + mock_cd.return_value = MagicMock() |
| 116 | + connect(api_key="test-key") |
| 117 | + |
| 118 | + _, kwargs = mock_get.call_args |
| 119 | + assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT |
| 120 | + |
| 121 | + @patch("wherobots.db.driver.requests.post") |
| 122 | + def test_connect_without_cancel_event(self, mock_post): |
| 123 | + """connect() without cancel_event should work as before (backward compat).""" |
| 124 | + post_resp = MagicMock() |
| 125 | + post_resp.status_code = 401 |
| 126 | + post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp) |
| 127 | + post_resp.json.side_effect = requests.JSONDecodeError("", "", 0) |
| 128 | + mock_post.return_value = post_resp |
| 129 | + |
| 130 | + with pytest.raises(InterfaceError): |
| 131 | + connect(api_key="test-key") |
| 132 | + |
| 133 | + |
| 134 | +class TestConnectDirectCancelEvent: |
| 135 | + @patch("wherobots.db.driver.websockets.sync.client.connect") |
| 136 | + def test_cancel_before_ws_connect(self, mock_ws): |
| 137 | + cancel = threading.Event() |
| 138 | + cancel.set() |
| 139 | + |
| 140 | + with pytest.raises(InterfaceError, match="cancelled by caller"): |
| 141 | + connect_direct( |
| 142 | + uri="wss://compute.example.com/sql/org/session-id", |
| 143 | + cancel_event=cancel, |
| 144 | + ) |
| 145 | + |
| 146 | + mock_ws.assert_not_called() |
0 commit comments