Skip to content

Commit 8374ef0

Browse files
committed
minor cleanup
1 parent 286d8ec commit 8374ef0

5 files changed

Lines changed: 35 additions & 32 deletions

File tree

sql_metadata/column_extractor.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,7 @@ def _is_date_part_unit(node: exp.Column) -> bool:
148148

149149

150150
class _Collector:
151-
"""Mutable accumulator for metadata gathered during the AST walk.
152-
153-
:param table_aliases: Pre-computed table alias → real name mapping.
154-
"""
151+
"""Mutable accumulator for metadata gathered during the AST walk."""
155152

156153
__slots__ = (
157154
"columns",

sql_metadata/dialect_parser.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,6 @@ def _parse_with_dialect(clean_sql: str, dialect: Any) -> exp.Expression | None:
254254

255255
if not results or results[0] is None:
256256
return None
257-
# sqlglot.parse's stub returns list[Expression | None]; the None case
258-
# is filtered one line above but mypy does not narrow through the
259-
# indexed access.
260257
return results[0] # type: ignore[return-value]
261258

262259
# -- quality checks -----------------------------------------------------

sql_metadata/nested_resolver.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,10 @@ def not_sql(self, expression: exp.Expression) -> str:
122122
return f"{self.sql(child, 'this')} IS NOT NULL"
123123
if isinstance(child, exp.In):
124124
return f"{self.sql(child, 'this')} NOT IN ({self.expressions(child)})"
125-
# sqlglot's Generator.not_sql is typed to take exp.Not; we widen the
126-
# parameter to exp.Expression to match the override signature across
127-
# all custom *_sql methods, and sqlglot's return type is inferred as
128-
# Any from partially-typed stubs.
125+
# sqlglot stubs under-type not_sql's parameter and return type.
129126
return super().not_sql(expression) # type: ignore[arg-type, no-any-return]
130127

131128

132-
133129
# ---------------------------------------------------------------------------
134130
# Resolution helpers
135131
# ---------------------------------------------------------------------------

sql_metadata/sql_cleaner.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
outer-parenthesis removal.
77
"""
88

9-
import itertools
109
import re
1110
from typing import NamedTuple
1211

@@ -36,34 +35,39 @@ def _is_wrapped(text: str) -> bool:
3635
"""
3736
if len(text) < 2 or text[0] != "(" or text[-1] != ")":
3837
return False
39-
inner = text[1:-1]
40-
depths = list(
41-
itertools.accumulate(
42-
(1 if c == "(" else -1 if c == ")" else 0) for c in inner
43-
)
44-
)
45-
return not depths or min(depths) >= 0
46-
47-
48-
def _strip_outer_parens(sql: str, _depth: int = 0) -> str:
38+
depth = 0
39+
for c in text[1:-1]:
40+
if c == "(":
41+
depth += 1
42+
elif c == ")":
43+
depth -= 1
44+
if depth < 0:
45+
return False
46+
return True
47+
48+
49+
def _strip_outer_parens(sql: str) -> str:
4950
"""Strip redundant outer parentheses from *sql*.
5051
5152
Needed because sqlglot cannot parse double-wrapped non-SELECT
52-
statements like ``((UPDATE ...))``. Uses ``itertools.accumulate``
53-
to verify balanced parens in one pass, with recursion for nesting.
54-
A depth guard prevents stack overflow on pathological input.
53+
statements like ``((UPDATE ...))``. A depth guard prevents stack
54+
overflow on pathological input.
5555
5656
:param sql: SQL string that may be wrapped in outer parentheses.
5757
:type sql: str
5858
:returns: The unwrapped SQL string.
5959
:rtype: str
6060
"""
61-
if _depth > 100:
62-
return sql
63-
s = sql.strip()
64-
if _is_wrapped(s):
65-
return _strip_outer_parens(s[1:-1].strip(), _depth + 1)
66-
return s
61+
62+
def _recur(s: str, depth: int) -> str:
63+
if depth > 100:
64+
return s
65+
s = s.strip()
66+
if _is_wrapped(s):
67+
return _recur(s[1:-1], depth + 1)
68+
return s
69+
70+
return _recur(sql, 0)
6771

6872

6973
def _normalize_cte_names(sql: str) -> tuple[str, dict[str, str]]:

test/test_edge_cases.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,12 @@ def test_strip_outer_parens_depth_guard():
5656
# RecursionError.
5757
parser = Parser("(" * 150 + "SELECT 1" + ")" * 150)
5858
assert parser.columns == []
59+
60+
61+
def test_strip_outer_parens_unbalanced_middle():
62+
"""Queries that look paren-wrapped but aren't (UNION of parenthesised SELECTs)."""
63+
# "(SELECT ...) UNION (SELECT ...)" starts with "(" and ends with ")" but the
64+
# inner parens go negative — _is_wrapped must short-circuit and leave the SQL
65+
# intact so both SELECT branches parse.
66+
parser = Parser("(SELECT a FROM t1) UNION (SELECT b FROM t2)")
67+
assert parser.tables == ["t1", "t2"]

0 commit comments

Comments
 (0)