Skip to content

Commit 4804c7f

Browse files
committed
add format for sound_like window_function
-s
1 parent c930736 commit 4804c7f

10 files changed

Lines changed: 5365 additions & 5323 deletions

File tree

sqlgpt_parser/format/formatter.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -421,24 +421,31 @@ def visit_list_expression(self, node, unmangle_names):
421421
else:
422422
return "(%s)" % self._join_expressions(node.values, unmangle_names)
423423

424+
def visit_window_func(self, node, unmangle_names):
425+
args = ", ".join([self.process(arg, unmangle_names) for arg in node.func_args])
426+
ignore_null = f" {node.ignore_null} NULLS" if node.ignore_null else ""
427+
window_spec = " OVER (" + self.process(node.window_spec, unmangle_names) + ")"
428+
return f"{node.func_name.upper()}({args}){ignore_null}{window_spec}"
429+
424430
def visit_window_spec(self, node, unmangle_names):
425431
parts = []
426-
427432
if node.partition_by:
428-
parts.append(
429-
"PARTITION BY "
430-
+ self._join_expressions(node.partition_by, unmangle_names)
431-
)
433+
self.process(node.partition_by, unmangle_names)
432434
if node.order_by:
433435
parts.append("ORDER BY " + format_sort_items(node.order_by, unmangle_names))
434-
if node.frame:
435-
parts.append(self.process(node.frame, unmangle_names))
436+
if node.frame_clause:
437+
parts.append(self.process(node.frame_clause, unmangle_names))
436438

437-
return '(' + ' '.join(parts) + ')'
439+
return ' '.join(parts)
438440

439-
def visit_window_frame(self, node, unmangle_names):
440-
ret = node.type + " "
441+
def visit_partition_by_clause(self, node, unmangle_names):
442+
return "PARTITION BY " + self._join_expressions(node.items, unmangle_names)
443+
444+
def visit_frame_clause(self, node, unmangle_names):
445+
return f"{node.type} {self.process(node.frame_range, unmangle_names)}"
441446

447+
def visit_window_frame(self, node, unmangle_names):
448+
ret = ""
442449
if node.end:
443450
ret += "BETWEEN %s AND %s" % (
444451
self.process(node.start, unmangle_names),
@@ -449,6 +456,19 @@ def visit_window_frame(self, node, unmangle_names):
449456

450457
return ret
451458

459+
def visit_frame_bound(self, node, unmangle_names):
460+
if node.type.upper() == "ROW":
461+
return "CURRENT ROW"
462+
expr = (
463+
self.process(node.expr, unmangle_names)
464+
if node.expr is not None
465+
else "UNBOUNDED "
466+
)
467+
return f"{expr} {node.type.upper()}"
468+
469+
def visit_frame_expr(self, node, unmangle_names):
470+
return self.process(node.value, unmangle_names)
471+
452472
def visit_single_column(self, node, indent):
453473
format_expression(node.expression)
454474

@@ -468,6 +488,9 @@ def visit_match_against_expression(self, node, unmangle_names):
468488
full_text_search_modifier = full_text_search_modifier.upper()
469489
return f"MATCH({columns}) AGAINST ({self.process(node.expr, unmangle_names)}{full_text_search_modifier})"
470490

491+
def visit_sound_like(self, node, unmangle_names):
492+
return f"{self.process(node.arguments[0])} SOUNDS LIKE {self.process(node.arguments[1])}"
493+
471494
def _format_binary_expression(self, operator, left, right, unmangle_names):
472495
return "%s %s %s" % (
473496
self.process(left, unmangle_names),
@@ -689,13 +712,14 @@ def visit_table_subquery(self, node, indent):
689712
return None
690713

691714
def visit_union(self, node, indent):
692-
all = node.all
693715
for i, relation in enumerate(node.relations):
694716
self._process_relation(relation, indent)
695717
self.builder.append("\n")
696718
if i != len(node.relations) - 1:
697-
if all:
719+
if node.all:
698720
self._append(indent, "UNION ALL")
721+
elif node.distinct:
722+
self._append(indent, "UNION DISTINCT")
699723
else:
700724
self._append(indent, "UNION")
701725
self.builder.append("\n")
@@ -704,7 +728,12 @@ def visit_union(self, node, indent):
704728

705729
def visit_except(self, node, indent):
706730
self._process_relation(node.left, indent)
707-
self.builder.append("EXCEPT " + "ALL " if not node.distinct else "")
731+
if node.all is not None:
732+
self._append(indent, "EXCEPT ALL")
733+
elif node.distinct is not None:
734+
self._append(indent, "EXCEPT DISTINCT")
735+
else:
736+
self._append(indent, "EXCEPT")
708737
self._process_relation(node.right, indent)
709738

710739
return None
@@ -756,7 +785,11 @@ def visit_intersect(self, node, indent):
756785
relations = [
757786
self._process_relation(relation, indent) for relation in node.relations
758787
]
759-
intersect = "INTERSECT " + "ALL " if not node.distinct else ""
788+
intersect = "INTERSECT"
789+
if node.all is not None:
790+
intersect += " ALL"
791+
elif node.distinct is not None:
792+
intersect += " DISTINCT"
760793
self.builder.append(intersect.join(relations))
761794
return None
762795

sqlgpt_parser/parser/mysql_parser/parser.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,7 +1570,7 @@ def p_window_func_call(p):
15701570
| ROW_NUMBER LPAREN RPAREN over_clause
15711571
"""
15721572
length = len(p)
1573-
window_spec = p[-1]
1573+
window_spec = p[length-1]
15741574
args = []
15751575
ignore_null = None
15761576

@@ -1711,7 +1711,10 @@ def p_frame_start(p):
17111711
| frame_expr PRECEDING
17121712
| frame_expr FOLLOWING
17131713
"""
1714-
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
1714+
if p.slice[1].type == 'frame_expr':
1715+
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
1716+
else:
1717+
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=None)
17151718

17161719

17171720
def p_frame_end(p):
@@ -1726,13 +1729,8 @@ def p_frame_between(p):
17261729

17271730
def p_frame_expr(p):
17281731
r"""frame_expr : figure
1729-
| QM
1730-
| INTERVAL expression time_unit
1731-
|"""
1732-
if len(p) == 4:
1733-
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[2], unit=p[3])
1734-
else:
1735-
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])
1732+
| time_interval"""
1733+
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])
17361734

17371735

17381736
def p_lead_lag_info_opt(p):

sqlgpt_parser/parser/mysql_parser/parser_table.py

Lines changed: 1483 additions & 1485 deletions
Large diffs are not rendered by default.

sqlgpt_parser/parser/oceanbase_parser/parser.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,7 +1649,7 @@ def p_window_func_call(p):
16491649
| ROW_NUMBER LPAREN RPAREN over_clause
16501650
"""
16511651
length = len(p)
1652-
window_spec = p[-1]
1652+
window_spec = p[length-1]
16531653
args = []
16541654
ignore_null = None
16551655

@@ -1790,7 +1790,10 @@ def p_frame_start(p):
17901790
| frame_expr PRECEDING
17911791
| frame_expr FOLLOWING
17921792
"""
1793-
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
1793+
if p.slice[1].type == 'frame_expr':
1794+
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
1795+
else:
1796+
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=None)
17941797

17951798

17961799
def p_frame_end(p):
@@ -1802,12 +1805,9 @@ def p_frame_between(p):
18021805
r"""frame_between : BETWEEN frame_start AND frame_end"""
18031806
p[0] = WindowFrame(p.lineno(1), p.lexpos(1), start=p[2], end=p[4])
18041807

1805-
18061808
def p_frame_expr(p):
18071809
r"""frame_expr : figure
1808-
| QM
1809-
| time_interval
1810-
|"""
1810+
| time_interval"""
18111811
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])
18121812

18131813

sqlgpt_parser/parser/oceanbase_parser/parser_table.py

Lines changed: 1894 additions & 1896 deletions
Large diffs are not rendered by default.

sqlgpt_parser/parser/odps_parser/parser.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1657,7 +1657,7 @@ def p_window_func_call(p):
16571657
| ROW_NUMBER LPAREN RPAREN over_clause
16581658
"""
16591659
length = len(p)
1660-
window_spec = p[-1]
1660+
window_spec = p[length-1]
16611661
args = []
16621662
ignore_null = None
16631663

@@ -1798,8 +1798,10 @@ def p_frame_start(p):
17981798
| frame_expr PRECEDING
17991799
| frame_expr FOLLOWING
18001800
"""
1801-
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
1802-
1801+
if p.slice[1].type == 'frame_expr':
1802+
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
1803+
else:
1804+
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=None)
18031805

18041806
def p_frame_end(p):
18051807
r"""frame_end : frame_start"""
@@ -1810,12 +1812,9 @@ def p_frame_between(p):
18101812
r"""frame_between : BETWEEN frame_start AND frame_end"""
18111813
p[0] = WindowFrame(p.lineno(1), p.lexpos(1), start=p[2], end=p[4])
18121814

1813-
18141815
def p_frame_expr(p):
18151816
r"""frame_expr : figure
1816-
| QM
1817-
| time_interval
1818-
|"""
1817+
| time_interval"""
18191818
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])
18201819

18211820

sqlgpt_parser/parser/odps_parser/parser_table.py

Lines changed: 1898 additions & 1900 deletions
Large diffs are not rendered by default.

sqlgpt_parser/parser/tree/window.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
self.window_spec = window_spec
3131

3232
def accept(self, visitor, context):
33-
return super().accept(visitor, context)
33+
return visitor.visit_window_func(self, context)
3434

3535

3636
class WindowSpec(Node):
@@ -64,9 +64,8 @@ def accept(self, visitor, context):
6464

6565

6666
class WindowFrame(Node):
67-
def __init__(self, line=None, pos=None, type=None, start=None, end=None):
67+
def __init__(self, line=None, pos=None, start=None, end=None):
6868
super(WindowFrame, self).__init__(line, pos)
69-
self.type = type
7069
self.start = start
7170
self.end = end
7271

@@ -85,7 +84,9 @@ def accept(self, visitor, context):
8584

8685

8786
class FrameExpr(Node):
88-
def __init__(self, line=None, pos=None, value=None, unit=None):
87+
def __init__(self, line=None, pos=None, value=None):
8988
super(FrameExpr, self).__init__(line, pos)
9089
self.value = value
91-
self.unit = unit
90+
91+
def accept(self, visitor, context):
92+
return visitor.visit_frame_expr(self, context)

test/format/test_sql_formatter.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,23 @@ def test_time_interval(self):
197197
after_sql_rewrite_format = format_sql(statement, 0)
198198
assert after_sql_rewrite_format == except_sql
199199

200+
def test_windows_function(self):
201+
test_sqls_except = {
202+
"SELECT first_value(value) OVER (PARTITION BY id ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND "
203+
"UNBOUNDED FOLLOWING) AS first_val FROM my_table": "SELECT\n"
204+
" FIRST_VALUE(value) OVER (ORDER BY date ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS first_val"
205+
"\nFROM\n my_table",
206+
"SELECT first_value(value) OVER (PARTITION BY id ORDER BY date RANGE UNBOUNDED PRECEDING) "
207+
"AS first_val FROM my_table": "SELECT"
208+
"\n FIRST_VALUE(value) OVER (ORDER BY date ASC RANGE UNBOUNDED PRECEDING) AS first_val"
209+
"\nFROM"
210+
"\n my_table",
211+
}
212+
for sql, except_sql in test_sqls_except.items():
213+
statement = parser.parse(sql)
214+
after_sql_rewrite_format = format_sql(statement, 0)
215+
assert after_sql_rewrite_format == except_sql
216+
200217

201218
if __name__ == '__main__':
202219
unittest.main()

test/parser/test_parser_dml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def test_with_operation(self):
250250
result = oceanbase_parser.parse(sql)
251251
assert isinstance(result, WithHasQuery)
252252

253-
def test_windows_func(self):
253+
def test_window_func(self):
254254
test_sqls = [
255255
"""
256256
SELECT

0 commit comments

Comments
 (0)