-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathsmoke.py
More file actions
154 lines (139 loc) · 4.92 KB
/
smoke.py
File metadata and controls
154 lines (139 loc) · 4.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# A simple smoke test for the DB driver.
import argparse
import concurrent.futures
import functools
import logging
import sys
from typing import Dict, Optional

import pandas
from rich.console import Console
from rich.table import Table

from wherobots.db import connect, connect_direct
from wherobots.db.connection import Connection
from wherobots.db.constants import (
    DEFAULT_ENDPOINT,
    DEFAULT_SESSION_TYPE,
    DEFAULT_STORAGE_FORMAT,
)
from wherobots.db.region import Region
from wherobots.db.result_storage import StorageFormat, Store
from wherobots.db.session_type import SessionType
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--api-key-file", help="File containing the API key")
parser.add_argument("--token-file", help="File containing the token")
parser.add_argument("--region", help="Region to connect to (ie. aws-us-west-2)")
parser.add_argument("--version", help="Runtime version (ie. latest)")
parser.add_argument(
"--session-type",
help="Type of session to create",
default=DEFAULT_SESSION_TYPE,
choices=[st.value for st in SessionType],
)
parser.add_argument(
"--debug",
help="Enable debug logging",
action="store_const",
const=logging.DEBUG,
default=logging.INFO,
)
parser.add_argument(
"--api-endpoint",
help="Wherobots API endpoint to request a SQL session from",
default=DEFAULT_ENDPOINT,
)
parser.add_argument("--ws-url", help="Direct URL to connect to")
parser.add_argument(
"--shutdown-after-inactive-seconds",
help="Request a specific SQL Session expiration timeout (in seconds)",
)
parser.add_argument(
"--wide", help="Enable wide output", action="store_const", const=80, default=30
)
parser.add_argument(
"-s",
"--store",
help="Store results in temporary storage",
action="store_true",
)
parser.add_argument("sql", nargs="+", help="SQL query to execute")
args, unknown = parser.parse_known_args()
if args.store:
parser.add_argument(
"-sf",
"--storage-format",
help="Storage format for the results",
default=DEFAULT_STORAGE_FORMAT,
choices=[sf.value for sf in StorageFormat],
)
parser.add_argument(
"--single",
help="Generate only a single part file",
action="store_true",
)
parser.add_argument(
"-p",
"--presigned-url",
help="Generate a presigned URL for the results (only when --single is set)",
action="store_true",
)
args = parser.parse_args()
logging.basicConfig(
stream=sys.stdout,
level=args.debug,
format="%(asctime)s.%(msecs)03d %(levelname)s %(name)20s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logging.getLogger("websockets.protocol").setLevel(args.debug)
api_key = None
token = None
headers = None
if args.api_key_file:
with open(args.api_key_file) as f:
api_key = f.read().strip()
headers = {"X-API-Key": api_key}
if args.token_file:
with open(args.token_file) as f:
token = f.read().strip()
headers = {"Authorization": f"Bearer {token}"}
store = None
if args.store:
store = Store(
format=StorageFormat(args.storage_format)
if args.storage_format
else DEFAULT_STORAGE_FORMAT,
single=args.single,
generate_presigned_url=args.presigned_url,
)
if args.ws_url:
conn_func = functools.partial(connect_direct, uri=args.ws_url, headers=headers)
else:
conn_func = functools.partial(
connect,
host=args.api_endpoint,
token=token,
api_key=api_key,
shutdown_after_inactive_seconds=args.shutdown_after_inactive_seconds,
wait_timeout=900,
region=Region(args.region) if args.region else Region.AWS_US_WEST_2,
version=args.version,
session_type=SessionType(args.session_type),
store=store,
)
def render(results: pandas.DataFrame) -> None:
table = Table()
table.add_column("#")
for column in results.columns:
table.add_column(column, max_width=args.wide, no_wrap=True)
for row in results.itertuples(name=None):
r = [str(x) for x in row]
table.add_row(*r)
Console().print(table)
def execute(conn: Connection, sql: str) -> pandas.DataFrame:
with conn.cursor() as cursor:
cursor.execute(sql)
return cursor.fetchall()
with conn_func() as conn:
with concurrent.futures.ThreadPoolExecutor() as pool:
futures = [pool.submit(execute, conn, s) for s in args.sql]
for future in concurrent.futures.as_completed(futures):
render(future.result())