-
Notifications
You must be signed in to change notification settings - Fork 109
Expand file tree
/
Copy pathsqlserver_adapter.py
More file actions
290 lines (256 loc) · 10.7 KB
/
sqlserver_adapter.py
File metadata and controls
290 lines (256 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from typing import List, Optional
import agate
import dbt_common.exceptions
from dbt.adapters.base.column import Column as BaseColumn
from dbt.adapters.base.impl import ConstraintSupport
from dbt.adapters.base.meta import available
from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support
from dbt.adapters.events.types import SchemaCreation
from dbt.adapters.reference_keys import _make_ref_key_dict
from dbt.adapters.sql.impl import CREATE_SCHEMA_MACRO_NAME, SQLAdapter
from dbt_common.behavior_flags import BehaviorFlag
from dbt_common.contracts.constraints import (
ColumnLevelConstraint,
ConstraintType,
ModelLevelConstraint,
)
from dbt_common.events.functions import fire_event
from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn
from dbt.adapters.sqlserver.sqlserver_configs import SQLServerConfigs
from dbt.adapters.sqlserver.sqlserver_connections import SQLServerConnectionManager
from dbt.adapters.sqlserver.sqlserver_relation import SQLServerRelation
class SQLServerAdapter(SQLAdapter):
    """
    Controls the actual implementation of the SQL Server adapter and overrides
    selected methods inherited from dbt's ``SQLAdapter``.
    """

    ConnectionManager = SQLServerConnectionManager
    Column = SQLServerColumn
    AdapterSpecificConfigs = SQLServerConfigs
    Relation = SQLServerRelation

    _capabilities: CapabilityDict = CapabilityDict(
        {
            Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full),
            Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full),
        }
    )

    # Every constraint type dbt knows about is declared as enforced by this adapter.
    CONSTRAINT_SUPPORT = {
        ConstraintType.check: ConstraintSupport.ENFORCED,
        ConstraintType.not_null: ConstraintSupport.ENFORCED,
        ConstraintType.unique: ConstraintSupport.ENFORCED,
        ConstraintType.primary_key: ConstraintSupport.ENFORCED,
        ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
    }

    @property
    def _behavior_flags(self) -> List[BehaviorFlag]:
        """Adapter-specific behavior flags; both default to ``False``."""
        return [
            {
                "name": "empty",
                "default": False,
                "description": (
                    "When enabled, table and view materializations will be created as empty "
                    "structures (no data)."
                ),
            },
            {
                "name": "dbt_sqlserver_use_default_schema_concat",
                "default": False,
                "description": (
                    "When True, uses dbt-core's standard schema concatenation "
                    "(`target.schema` + `_` + `custom_schema_name`). "
                    "When False (default), uses legacy adapter behaviour: "
                    "`custom_schema_name` is used directly without prefixing `target.schema`. "
                    "For a permanent solution, override the `sqlserver__generate_schema_name` "
                    "macro in your project instead."
                ),
            },
        ]

    @available.parse(lambda *a, **k: [])
    def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
        """Get a list of the Columns with names and data types from the given sql.

        The query is executed and the column metadata is read from the cursor's
        DB-API ``description``. The parse-time stub passed to ``available.parse``
        returns an empty list instead.
        """
        _, cursor = self.connections.add_select_query(sql)
        columns = [
            self.Column.create(
                column_name, self.connections.data_type_code_to_name(column_type_code)
            )
            # https://peps.python.org/pep-0249/#description
            for column_name, column_type_code, *_ in cursor.description
        ]
        return columns

    @classmethod
    def convert_boolean_type(cls, agate_table, col_idx):
        """Boolean agate columns are stored as ``bit``."""
        return "bit"

    @classmethod
    def convert_datetime_type(cls, agate_table, col_idx):
        """Datetime agate columns are stored as ``datetime2(6)``."""
        return "datetime2(6)"

    @classmethod
    def convert_number_type(cls, agate_table, col_idx):
        """Return ``float`` if any value in the column has decimal places, else ``int``."""
        decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
        return "float" if decimals else "int"

    def create_schema(self, relation: BaseRelation) -> None:
        """Create the schema of ``relation`` (its identifier is ignored).

        If the credentials define ``schema_authorization``, the
        ``sqlserver__create_schema_with_authorization`` macro is used and receives
        that value. Otherwise the standard create-schema macro is used.
        """
        relation = relation.without_identifier()
        fire_event(SchemaCreation(relation=_make_ref_key_dict(relation)))
        macro_name = CREATE_SCHEMA_MACRO_NAME
        kwargs = {
            "relation": relation,
        }
        if self.config.credentials.schema_authorization:
            kwargs["schema_authorization"] = self.config.credentials.schema_authorization
            macro_name = "sqlserver__create_schema_with_authorization"
        self.execute_macro(macro_name, kwargs=kwargs)
        self.commit_if_has_connection()

    @classmethod
    def convert_text_type(cls, agate_table, col_idx):
        """Return a ``varchar`` type wide enough for the longest value in the column.

        The width is measured in UTF-8 bytes and is floored at 16. It defaults to 64
        when the column holds only nulls. SQL Server caps ``varchar(n)`` at n = 8000,
        so wider columns fall back to ``varchar(max)`` instead of producing invalid DDL.
        """
        column = agate_table.columns[col_idx]
        # see https://github.com/fishtown-analytics/dbt/pull/2255
        lens = [len(d.encode("utf-8")) for d in column.values_without_nulls()]
        max_len = max(lens) if lens else 64
        length = max_len if max_len > 16 else 16
        if length > 8000:
            # varchar(n) only accepts 1 <= n <= 8000; longer data needs varchar(max).
            return "varchar(max)"
        return "varchar({})".format(length)

    @classmethod
    def convert_time_type(cls, agate_table, col_idx):
        """Time agate columns are stored as ``time(6)``."""
        return "time(6)"

    @classmethod
    def date_function(cls):
        """T-SQL expression for the current date/time."""
        return "getdate()"

    # Methods used in adapter tests
    def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
        """Return SQL adding ``number`` units of ``interval`` to the expression ``add_to``.

        T-SQL has no ``+ interval`` literal syntax, so ``DATEADD(datepart, number, date)``
        is used. ``interval`` is passed through verbatim as the DATEADD datepart
        (e.g. ``hour``, ``day``). The ``"hour"`` default is kept for compatibility
        with the base adapter signature.
        """
        return f"DATEADD({interval},{number},{add_to})"

    def string_add_sql(
        self,
        add_to: str,
        value: str,
        location="append",
    ) -> str:
        """
        Concatenate the literal ``value`` onto the expression ``add_to``.

        `+` is T-SQL's string concatenation operator. ``location`` selects
        ``"append"`` or ``"prepend"``; any other value raises ``ValueError``.
        ``value`` is wrapped in single quotes without escaping.
        """
        if location == "append":
            return f"{add_to} + '{value}'"
        elif location == "prepend":
            return f"'{value}' + {add_to}"
        else:
            raise ValueError(f'Got an unexpected location value of "{location}"')

    def get_rows_different_sql(
        self,
        relation_a: BaseRelation,
        relation_b: BaseRelation,
        column_names: Optional[List[str]] = None,
        except_operator: str = "EXCEPT",
    ) -> str:
        """
        Generate SQL for a query that returns a single row with two columns:
        ``row_count_difference`` (row count of ``relation_a`` minus row count of
        ``relation_b``) and ``num_mismatched`` (rows found in one relation but not
        the other, using ``except_operator`` in both directions).

        Only ``column_names`` are compared. When it is omitted, all columns of
        ``relation_a`` are used; an empty list compares ``*``. Column names are
        quoted and sorted.

        Note: join ``using`` is not supported on Synapse, so COLUMNS_EQUAL_SQL
        joins its aggregates on a constant id instead.
        """
        # This method only really exists for test reasons.
        names: List[str]
        if column_names is None:
            columns = self.get_columns_in_relation(relation_a)
            names = sorted(self.quote(c.name) for c in columns)
        else:
            names = sorted(self.quote(n) for n in column_names)
        columns_csv = ", ".join(names) or "*"
        return COLUMNS_EQUAL_SQL.format(
            columns=columns_csv,
            relation_a=str(relation_a),
            relation_b=str(relation_b),
            except_op=except_operator,
        )

    def valid_incremental_strategies(self):
        """The set of standard builtin strategies which this adapter supports out-of-the-box.
        Not used to validate custom strategies defined by end users.
        """
        return ["append", "delete+insert", "merge", "microbatch"]

    # This is for use in the test suite
    def run_sql_for_tests(self, sql, fetch, conn):
        """Execute ``sql`` on a new cursor of ``conn.handle``.

        The statement is committed only when ``fetch`` is falsy. ``fetch == "one"``
        returns ``fetchone()``, ``fetch == "all"`` returns ``fetchall()``, and any
        other value returns ``None``. On any exception the connection is rolled back
        (when the handle reports itself open) and the exception is re-raised.
        ``conn.transaction_open`` is always reset to ``False``.
        """
        cursor = conn.handle.cursor()
        try:
            cursor.execute(sql)
            if not fetch:
                conn.handle.commit()
            if fetch == "one":
                return cursor.fetchone()
            elif fetch == "all":
                return cursor.fetchall()
            else:
                return
        except BaseException:
            # Handles without a `closed` attribute are treated as closed (no rollback).
            if conn.handle and not getattr(conn.handle, "closed", True):
                conn.handle.rollback()
            raise
        finally:
            conn.transaction_open = False

    @available
    @classmethod
    def render_column_constraint(cls, constraint: ColumnLevelConstraint) -> Optional[str]:
        """Render an inline column constraint.

        Only ``not null`` is rendered inline; every other column-level constraint
        type yields an empty string.
        """
        if constraint.type == ConstraintType.not_null:
            return "not null"
        return ""

    @classmethod
    def render_model_constraint(cls, constraint: ModelLevelConstraint) -> Optional[str]:
        """Render a model-level constraint as an ``add constraint <name> ...`` clause.

        Unique and primary-key constraints are rendered ``nonclustered``. Foreign-key,
        check and custom constraints require ``expression``. Anything that cannot be
        rendered returns ``None``.

        Raises:
            DbtDatabaseError: if the constraint has no name (``None`` or empty).
        """
        column_list = ", ".join(constraint.columns)
        # An empty string is as unusable as None: it would render "add constraint  ...".
        if not constraint.name:
            raise dbt_common.exceptions.DbtDatabaseError(
                "Constraint name cannot be empty. Provide constraint name - column "
                + column_list
                + " and run the project again."
            )
        prefix = f"add constraint {constraint.name}"
        if constraint.type == ConstraintType.unique:
            return f"{prefix} unique nonclustered({column_list})"
        elif constraint.type == ConstraintType.primary_key:
            return f"{prefix} primary key nonclustered({column_list})"
        elif constraint.type == ConstraintType.foreign_key and constraint.expression:
            return f"{prefix} foreign key({column_list}) references {constraint.expression}"
        elif constraint.type == ConstraintType.check and constraint.expression:
            return f"{prefix} check ({constraint.expression})"
        elif constraint.type == ConstraintType.custom and constraint.expression:
            return f"{prefix} {constraint.expression}"
        else:
            return None
# SQL template formatted by SQLServerAdapter.get_rows_different_sql.
# Placeholders: {columns}, {relation_a}, {relation_b}, {except_op}.
# The query returns a single row with:
#   row_count_difference - COUNT(*) of relation_a minus COUNT(*) of relation_b
#   num_mismatched       - rows of (a {except_op} b) plus rows of (b {except_op} a)
# The two aggregates are joined on a constant `id = 1`, not on a join `using`
# clause, which the adapter notes is unsupported on Synapse.
# Leading/trailing whitespace is removed by .strip().
COLUMNS_EQUAL_SQL = """
with diff_count as (
SELECT
1 as id,
COUNT(*) as num_missing FROM (
(SELECT {columns} FROM {relation_a} {except_op}
SELECT {columns} FROM {relation_b})
UNION ALL
(SELECT {columns} FROM {relation_b} {except_op}
SELECT {columns} FROM {relation_a})
) as a
), table_a as (
SELECT COUNT(*) as num_rows FROM {relation_a}
), table_b as (
SELECT COUNT(*) as num_rows FROM {relation_b}
), row_count_diff as (
select
1 as id,
table_a.num_rows - table_b.num_rows as difference
from table_a, table_b
)
select
row_count_diff.difference as row_count_difference,
diff_count.num_missing as num_mismatched
from row_count_diff
join diff_count on row_count_diff.id = diff_count.id
""".strip()