-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathsmoke.py
More file actions
110 lines (98 loc) · 3.73 KB
/
smoke.py
File metadata and controls
110 lines (98 loc) · 3.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# A simple smoke test for the DB driver.
import argparse
import concurrent.futures
import functools
import logging
import sys
import pandas
from rich.console import Console
from rich.table import Table
from wherobots.db import connect, connect_direct
from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE
from wherobots.db.connection import Connection
from wherobots.db.region import Region
from wherobots.db.session_type import SessionType
if __name__ == "__main__":
    # Script entry point, part 1: parse the command line, configure logging,
    # load credentials from files, and build `conn_func`, a zero-argument
    # callable that opens the connection with everything pre-bound.
    parser = argparse.ArgumentParser()
    parser.add_argument("--api-key-file", help="File containing the API key")
    parser.add_argument("--token-file", help="File containing the token")
    parser.add_argument("--region", help="Region to connect to (ie. aws-us-west-2)")
    parser.add_argument("--version", help="Runtime version (ie. latest)")
    parser.add_argument(
        "--session-type",
        help="Type of session to create",
        default=DEFAULT_SESSION_TYPE,
        choices=[st.value for st in SessionType],
    )
    # --debug stores a logging level rather than a boolean, so `args.debug`
    # can be handed directly to the logging configuration below.
    parser.add_argument(
        "--debug",
        help="Enable debug logging",
        action="store_const",
        const=logging.DEBUG,
        default=logging.INFO,
    )
    parser.add_argument(
        "--api-endpoint",
        help="Wherobots API endpoint to request a SQL session from",
        default=DEFAULT_ENDPOINT,
    )
    parser.add_argument("--ws-url", help="Direct URL to connect to")
    parser.add_argument(
        "--shutdown-after-inactive-seconds",
        help="Request a specific SQL Session expiration timeout (in seconds)",
        # Parse as an integer so a number (not the raw command-line string)
        # is forwarded to connect(), and so non-numeric input is rejected by
        # argparse with a usage error. Stays None when the flag is omitted.
        type=int,
    )
    # --wide stores the per-column maximum width used when rendering results:
    # 80 when the flag is given, 30 otherwise.
    parser.add_argument(
        "--wide", help="Enable wide output", action="store_const", const=80, default=30
    )
    parser.add_argument("sql", nargs="+", help="SQL query to execute")
    args = parser.parse_args()

    logging.basicConfig(
        stream=sys.stdout,
        level=args.debug,
        format="%(asctime)s.%(msecs)03d %(levelname)s %(name)20s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # Apply the selected level to the "websockets.protocol" logger as well.
    logging.getLogger("websockets.protocol").setLevel(args.debug)

    # Credentials are read from files; surrounding whitespace (such as a
    # trailing newline) is stripped.
    # NOTE(review): when both files are supplied, the bearer-token header
    # replaces the API-key header, while connect() still receives both values.
    api_key = None
    token = None
    headers = None

    if args.api_key_file:
        with open(args.api_key_file, encoding="utf-8") as f:
            api_key = f.read().strip()
        headers = {"X-API-Key": api_key}

    if args.token_file:
        with open(args.token_file, encoding="utf-8") as f:
            token = f.read().strip()
        headers = {"Authorization": f"Bearer {token}"}

    if args.ws_url:
        # A direct URL was given: connect straight to it, authenticating
        # with the prepared headers.
        conn_func = functools.partial(connect_direct, uri=args.ws_url, headers=headers)
    else:
        # Otherwise go through the API endpoint with the requested options.
        conn_func = functools.partial(
            connect,
            host=args.api_endpoint,
            token=token,
            api_key=api_key,
            shutdown_after_inactive_seconds=args.shutdown_after_inactive_seconds,
            wait_timeout=900,
            # Fall back to AWS_US_WEST_2 when no region is specified.
            region=Region(args.region) if args.region else Region.AWS_US_WEST_2,
            version=args.version,
            session_type=SessionType(args.session_type),
        )
def render(results: pandas.DataFrame) -> None:
    """Print a DataFrame to the console as a rich table.

    The first column, headed "#", holds the DataFrame index; one further
    column is added per DataFrame column. Cell contents are never wrapped,
    and each data column is capped at `args.wide` characters.
    """
    grid = Table()
    grid.add_column("#")

    width_cap = args.wide
    for header in results.columns:
        grid.add_column(header, max_width=width_cap, no_wrap=True)

    # itertuples(name=None) yields plain tuples with the index value first,
    # which lines up with the leading "#" column.
    for record in results.itertuples(name=None):
        grid.add_row(*map(str, record))

    Console().print(grid)
def execute(conn: Connection, sql: str) -> pandas.DataFrame:
    """Run one SQL statement on a new cursor and return all fetched results.

    The cursor is used as a context manager, so it is exited before the
    result is handed back to the caller.
    """
    with conn.cursor() as cur:
        cur.execute(sql)
        fetched = cur.fetchall()
    return fetched
if __name__ == "__main__":
    # Script entry point, part 2: open the connection, submit every SQL
    # statement from the command line to a thread pool (all sharing the one
    # connection), and render each result as soon as its query completes.
    # Results are therefore printed in completion order, not argument order.
    with conn_func() as conn, concurrent.futures.ThreadPoolExecutor() as pool:
        pending = [pool.submit(execute, conn, statement) for statement in args.sql]
        for finished in concurrent.futures.as_completed(pending):
            # result() re-raises any exception the query raised in its worker.
            render(finished.result())