-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathdriver.py
More file actions
212 lines (189 loc) · 6.84 KB
/
driver.py
File metadata and controls
212 lines (189 loc) · 6.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""Wherobots DB driver.
A PEP-0249 compatible driver for interfacing with Wherobots DB.
"""
import ssl
from importlib import metadata
from importlib.metadata import PackageNotFoundError
import logging
from packaging.version import Version
import platform
import requests
import tenacity
from typing import Final, Union, Dict
import urllib.parse
import websockets.sync.client
import certifi
from .connection import Connection
from .constants import (
DEFAULT_ENDPOINT,
DEFAULT_REGION,
DEFAULT_RUNTIME,
DEFAULT_READ_TIMEOUT_SECONDS,
DEFAULT_SESSION_TYPE,
DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
DEFAULT_VERSION,
MAX_MESSAGE_SIZE,
PARAM_STYLE,
PROTOCOL_VERSION,
AppStatus,
DataCompression,
GeometryRepresentation,
ResultsFormat,
SessionType,
)
from .errors import (
InterfaceError,
OperationalError,
)
from .region import Region
from .runtime import Runtime
# PEP 249 module-level globals that describe this DB-API implementation.
apilevel: Final[str] = "2.0"  # supported DB-API specification level
threadsafety: Final[int] = 1  # PEP 249 level 1: threads may share the module, but not connections
paramstyle: Final[str] = PARAM_STYLE  # query placeholder style, taken from .constants
def gen_user_agent_header() -> Dict[str, str]:
    """Return a fresh headers dict holding this client's ``User-Agent``.

    The value has the form
    ``wherobots-python-dbapi/<version> os/<system> python/<python version>``.
    The package version is ``unknown`` when the ``wherobots-python-dbapi``
    distribution metadata cannot be found.
    """
    try:
        pkg_version = metadata.version("wherobots-python-dbapi")
    except PackageNotFoundError:
        pkg_version = "unknown"

    agent_parts = (
        f"wherobots-python-dbapi/{pkg_version}",
        f"os/{platform.system().lower()}",
        f"python/{platform.python_version()}",
    )
    return {"User-Agent": " ".join(agent_parts)}
def _with_default_scheme(host: str) -> str:
    """Return ``host`` unchanged if it already has an HTTP(S) scheme, else prefix ``https://``."""
    # Both schemes must be checked: testing only for "http:" turned an explicit
    # "https://..." host into "https://https://...".
    if host.lower().startswith(("http:", "https:")):
        return host
    return f"https://{host}"


def _http_error_details(error: requests.HTTPError) -> str:
    """Build a readable description of a failed session-creation response.

    Uses ``"<message>: <details>"`` from the first entry of the response body's
    ``errors`` list when available. Falls back to ``str(error)`` when the body
    is not JSON or does not have that shape, so a malformed error payload never
    masks the original HTTP failure.
    """
    fallback = str(error)
    response = error.response
    if response is None:
        return fallback
    try:
        info = response.json()
    except requests.JSONDecodeError:
        return fallback
    errors = info.get("errors") if isinstance(info, dict) else None
    if not errors or not isinstance(errors, list):
        return fallback
    first = errors[0]
    if not isinstance(first, dict) or "message" not in first:
        return fallback
    if "details" in first:
        return f"{first['message']}: {first['details']}"
    return str(first["message"])


def connect(
    host: str = DEFAULT_ENDPOINT,
    token: Union[str, None] = None,
    api_key: Union[str, None] = None,
    runtime: Union[Runtime, None] = None,
    region: Union[Region, None] = None,
    version: Union[str, None] = None,
    wait_timeout: float = DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
    read_timeout: float = DEFAULT_READ_TIMEOUT_SECONDS,
    session_type: Union[SessionType, None] = None,
    force_new: bool = False,
    shutdown_after_inactive_seconds: Union[int, None] = None,
    results_format: Union[ResultsFormat, None] = None,
    data_compression: Union[DataCompression, None] = None,
    geometry_representation: Union[GeometryRepresentation, None] = None,
) -> Connection:
    """Request a Wherobots SQL session and open a DB-API connection to it.

    POSTs to ``{host}/sql/session``, then polls the URL the response ends up at
    until the session reports READY (for at most ``wait_timeout`` seconds).
    Finally it opens a WebSocket connection to the session with
    :func:`connect_direct`.

    Args:
        host: API endpoint. Falsy values fall back to ``DEFAULT_ENDPOINT``;
            ``https://`` is prepended when no ``http:``/``https:`` scheme is given.
        token: Bearer token, sent as ``Authorization: Bearer <token>``.
        api_key: API key, sent in the ``X-API-Key`` header.
        runtime: Requested runtime; defaults to ``DEFAULT_RUNTIME``.
        region: Requested region; defaults to ``DEFAULT_REGION``.
        version: Requested version; defaults to ``DEFAULT_VERSION``.
        wait_timeout: Maximum seconds to keep polling for the session to be ready.
        read_timeout: Forwarded to :func:`connect_direct`.
        session_type: Requested session type; defaults to ``DEFAULT_SESSION_TYPE``.
        force_new: Sent as the ``force_new`` query parameter of the session request.
        shutdown_after_inactive_seconds: Sent as ``shutdownAfterInactiveSeconds``.
        results_format: Forwarded to :func:`connect_direct`.
        data_compression: Forwarded to :func:`connect_direct`.
        geometry_representation: Forwarded to :func:`connect_direct`.

    Returns:
        The ``Connection`` built by :func:`connect_direct`.

    Raises:
        ValueError: If neither or both of ``token`` and ``api_key`` are given.
        InterfaceError: If the session request fails, the session cannot be
            acquired, or the WebSocket connection cannot be established.
    """
    if not token and not api_key:
        raise ValueError("At least one of `token` or `api_key` is required")
    if token and api_key:
        raise ValueError("`token` and `api_key` can't be both provided")

    headers = gen_user_agent_header()
    if token:
        headers["Authorization"] = f"Bearer {token}"
    elif api_key:
        headers["X-API-Key"] = api_key

    host = host or DEFAULT_ENDPOINT
    runtime = runtime or DEFAULT_RUNTIME
    region = region or DEFAULT_REGION
    version = version or DEFAULT_VERSION
    session_type = session_type or DEFAULT_SESSION_TYPE

    logging.info(
        "Requesting %s%s runtime running %s in %s from %s ...",
        "new " if force_new else "",
        runtime.value,
        version,
        region.value,
        host,
    )

    # Default to HTTPS if the hostname doesn't explicitly specify a scheme.
    host = _with_default_scheme(host)

    try:
        resp = requests.post(
            url=f"{host}/sql/session",
            params={"region": region.value, "force_new": force_new},
            json={
                "runtimeId": runtime.value,
                "shutdownAfterInactiveSeconds": shutdown_after_inactive_seconds,
                "version": version,
                "sessionType": session_type.value,
            },
            headers=headers,
        )
        resp.raise_for_status()
    except requests.HTTPError as e:
        raise InterfaceError(
            f"Failed to create SQL session: {_http_error_details(e)}"
        ) from e

    # At this point we've been redirected to /sql/session/{session_id}, which we'll need to keep polling until the
    # session is in READY state.
    session_id_url = resp.url

    @tenacity.retry(
        stop=tenacity.stop_after_delay(wait_timeout),
        wait=tenacity.wait_exponential(multiplier=1, min=1, max=5),
        # HTTP errors and terminal session states are final; anything else
        # (including TryAgain while starting) is retried until the deadline.
        retry=tenacity.retry_if_not_exception_type(
            (requests.HTTPError, OperationalError)
        ),
    )
    def get_session_uri() -> str:
        """Poll the session once; return its URL when READY, retry while starting."""
        r = requests.get(session_id_url, headers=headers)
        r.raise_for_status()
        payload = r.json()
        status = AppStatus(payload.get("status"))
        logging.info(" ... %s", status)
        if status.is_starting():
            raise tenacity.TryAgain("SQL Session is not ready yet")
        elif status == AppStatus.READY:
            return payload["appMeta"]["url"]
        else:
            logging.error("SQL session creation failed: %s; should not retry.", status)
            raise OperationalError(f"Failed to create SQL session: {status}")

    try:
        logging.info("Getting SQL session status from %s ...", session_id_url)
        session_uri = get_session_uri()
        logging.debug("SQL session URI from app status: %s", session_uri)
    except Exception as e:
        # Chain the cause explicitly (like the other raises in this module)
        # instead of stuffing it into the exception's args.
        raise InterfaceError("Could not acquire SQL session!") from e

    return connect_direct(
        uri=http_to_ws(session_uri),
        headers=headers,
        read_timeout=read_timeout,
        results_format=results_format,
        data_compression=data_compression,
        geometry_representation=geometry_representation,
    )
def http_to_ws(uri: str) -> str:
    """Map an ``http``/``https`` URI onto the matching ``ws``/``wss`` scheme.

    URIs with any other scheme keep their scheme. The result is always
    re-serialized through ``urllib.parse.urlunparse``.
    """
    parts = urllib.parse.urlparse(uri)
    websocket_scheme = {"http": "ws", "https": "wss"}.get(parts.scheme)
    if websocket_scheme is not None:
        parts = parts._replace(scheme=websocket_scheme)
    return str(urllib.parse.urlunparse(parts))
def connect_direct(
    uri: str,
    protocol: Version = PROTOCOL_VERSION,
    headers: Union[Dict[str, str], None] = None,
    read_timeout: float = DEFAULT_READ_TIMEOUT_SECONDS,
    results_format: Union[ResultsFormat, None] = None,
    data_compression: Union[DataCompression, None] = None,
    geometry_representation: Union[GeometryRepresentation, None] = None,
) -> Connection:
    """Open a WebSocket connection to a SQL session at a known URI.

    Args:
        uri: WebSocket URI of the session (``ws://`` or ``wss://``); the
            protocol version is appended as a final path segment.
        protocol: Protocol version appended to ``uri``.
        headers: Extra HTTP headers sent with the WebSocket handshake.
        read_timeout: Forwarded to ``Connection``.
        results_format: Forwarded to ``Connection``.
        data_compression: Forwarded to ``Connection``.
        geometry_representation: Forwarded to ``Connection``.

    Returns:
        A ``Connection`` wrapping the opened WebSocket.

    Raises:
        InterfaceError: If the WebSocket connection cannot be established.
    """
    uri_with_protocol = f"{uri}/{protocol}"
    # The websockets client rejects an ``ssl`` argument for plain ``ws://``
    # URIs, so only build a TLS context for secure ``wss://`` endpoints.
    is_secure = urllib.parse.urlparse(uri).scheme == "wss"
    try:
        logging.info("Connecting to SQL session at %s ...", uri_with_protocol)
        ssl_context = None
        if is_secure:
            # Start from the system defaults and additionally trust certifi's CA bundle.
            ssl_context = ssl.create_default_context()
            ssl_context.load_verify_locations(certifi.where())
        ws = websockets.sync.client.connect(
            uri=uri_with_protocol,
            additional_headers=headers,
            max_size=MAX_MESSAGE_SIZE,
            ssl=ssl_context,
        )
    except Exception as e:
        raise InterfaceError("Failed to connect to SQL session!") from e
    return Connection(
        ws,
        read_timeout=read_timeout,
        results_format=results_format,
        data_compression=data_compression,
        geometry_representation=geometry_representation,
    )