Skip to content

Commit 3dc2289

Browse files
authored
codegen: Support pipe syntax for Union types (#1336)
From 3.14 onwards, we'll get `foo | bar` instead of `typing.Union[foo, bar]` as the annotation for union types (including optional). This PR prepares the codegen script for this.
1 parent b560ae8 commit 3dc2289

2 files changed

Lines changed: 174 additions & 10 deletions

File tree

libcst/codegen/gen_matcher_classes.py

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,109 @@
1616
OPTIONAL_RE = r"typing\.Union\[([^,]*?), NoneType]"
1717

1818

19+
class NormalizeUnions(cst.CSTTransformer):
20+
"""
21+
Convert a binary operation with | operators into a Union type.
22+
For example, converts `foo | bar | baz` into `typing.Union[foo, bar, baz]`.
23+
Special case: converts `foo | None` or `None | foo` into `typing.Optional[foo]`.
24+
Also flattens nested typing.Union types.
25+
"""
26+
27+
def leave_Subscript(
28+
self, original_node: cst.Subscript, updated_node: cst.Subscript
29+
) -> cst.Subscript:
30+
# Check if this is a typing.Union
31+
if (
32+
isinstance(updated_node.value, cst.Attribute)
33+
and isinstance(updated_node.value.value, cst.Name)
34+
and updated_node.value.attr.value == "Union"
35+
and updated_node.value.value.value == "typing"
36+
):
37+
# Collect all operands from any nested Unions
38+
operands: List[cst.BaseExpression] = []
39+
for slc in updated_node.slice:
40+
if not isinstance(slc.slice, cst.Index):
41+
continue
42+
value = slc.slice.value
43+
# If this is a nested Union, add its elements
44+
if (
45+
isinstance(value, cst.Subscript)
46+
and isinstance(value.value, cst.Attribute)
47+
and isinstance(value.value.value, cst.Name)
48+
and value.value.attr.value == "Union"
49+
and value.value.value.value == "typing"
50+
):
51+
operands.extend(
52+
nested_slc.slice.value
53+
for nested_slc in value.slice
54+
if isinstance(nested_slc.slice, cst.Index)
55+
)
56+
else:
57+
operands.append(value)
58+
59+
# flatten operands into a Union type
60+
return cst.Subscript(
61+
cst.Attribute(cst.Name("typing"), cst.Name("Union")),
62+
[cst.SubscriptElement(cst.Index(operand)) for operand in operands],
63+
)
64+
return updated_node
65+
66+
def leave_BinaryOperation(
67+
self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
68+
) -> Union[cst.BinaryOperation, cst.Subscript]:
69+
if not updated_node.operator.deep_equals(cst.BitOr()):
70+
return updated_node
71+
72+
def flatten_binary_op(node: cst.BaseExpression) -> List[cst.BaseExpression]:
73+
"""Flatten a binary operation tree into a list of operands."""
74+
if not isinstance(node, cst.BinaryOperation):
75+
# If it's a Union type, extract its elements
76+
if (
77+
isinstance(node, cst.Subscript)
78+
and isinstance(node.value, cst.Attribute)
79+
and isinstance(node.value.value, cst.Name)
80+
and node.value.attr.value == "Union"
81+
and node.value.value.value == "typing"
82+
):
83+
return [
84+
slc.slice.value
85+
for slc in node.slice
86+
if isinstance(slc.slice, cst.Index)
87+
]
88+
return [node]
89+
if not node.operator.deep_equals(cst.BitOr()):
90+
return [node]
91+
92+
left_operands = flatten_binary_op(node.left)
93+
right_operands = flatten_binary_op(node.right)
94+
return left_operands + right_operands
95+
96+
# Flatten the binary operation tree into a list of operands
97+
operands = flatten_binary_op(updated_node)
98+
99+
# Check for Optional case (None in union)
100+
none_count = sum(
101+
1 for op in operands if isinstance(op, cst.Name) and op.value == "None"
102+
)
103+
if none_count == 1 and len(operands) == 2:
104+
# This is an Optional case - find the non-None operand
105+
non_none = next(
106+
op
107+
for op in operands
108+
if not (isinstance(op, cst.Name) and op.value == "None")
109+
)
110+
return cst.Subscript(
111+
cst.Attribute(cst.Name("typing"), cst.Name("Optional")),
112+
[cst.SubscriptElement(cst.Index(non_none))],
113+
)
114+
115+
# Regular Union case
116+
return cst.Subscript(
117+
cst.Attribute(cst.Name("typing"), cst.Name("Union")),
118+
[cst.SubscriptElement(cst.Index(operand)) for operand in operands],
119+
)
120+
121+
19122
class CleanseFullTypeNames(cst.CSTTransformer):
20123
def leave_Call(
21124
self, original_node: cst.Call, updated_node: cst.Call
@@ -357,7 +460,9 @@ def _get_clean_type_from_subscript(
357460
elif isinstance(inner_type, (cst.Name, cst.SimpleString)):
358461
clean_inner_type = _get_clean_type_from_expression(aliases, inner_type)
359462
else:
360-
raise Exception("Logic error, unexpected type in Sequence!")
463+
raise Exception(
464+
f"Logic error, unexpected type in Sequence: {type(inner_type)}!"
465+
)
361466

362467
return _get_wrapped_union_type(
363468
typecst.deep_replace(inner_type, clean_inner_type),
@@ -386,9 +491,12 @@ def _get_clean_type_and_aliases(
386491
typestr = re.sub(OPTIONAL_RE, r"typing.Optional[\1]", typestr)
387492

388493
# Now, parse the expression with LibCST.
389-
cleanser = CleanseFullTypeNames()
494+
390495
typecst = parse_expression(typestr)
391-
typecst = typecst.visit(cleanser)
496+
typecst = typecst.visit(NormalizeUnions())
497+
assert isinstance(typecst, cst.BaseExpression)
498+
typecst = typecst.visit(CleanseFullTypeNames())
499+
assert isinstance(typecst, cst.BaseExpression)
392500
aliases: List[Alias] = []
393501

394502
# Now, convert the type to allow for MetadataMatchType and MatchIfTrue values.
@@ -397,7 +505,7 @@ def _get_clean_type_and_aliases(
397505
elif isinstance(typecst, (cst.Name, cst.SimpleString)):
398506
clean_type = _get_clean_type_from_expression(aliases, typecst)
399507
else:
400-
raise Exception("Logic error, unexpected top level type!")
508+
raise Exception(f"Logic error, unexpected top level type: {type(typecst)}!")
401509

402510
# Now, insert OneOf/AllOf and MatchIfTrue into unions so we can typecheck their usage.
403511
# This allows us to put OneOf[SomeType] or MatchIfTrue[cst.SomeType] into any

libcst/codegen/tests/test_codegen_clean.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import difflib
67
import os
78
import os.path
89

@@ -20,12 +21,20 @@ def assert_code_matches(
2021
new_code: str,
2122
module_name: str,
2223
) -> None:
23-
self.assertTrue(
24-
old_code == new_code,
25-
f"{module_name} needs new codegen, see "
26-
+ "`python -m libcst.codegen.generate --help` "
27-
+ "for instructions, or run `python -m libcst.codegen.generate all`",
28-
)
24+
if old_code != new_code:
25+
diff = difflib.unified_diff(
26+
old_code.splitlines(keepends=True),
27+
new_code.splitlines(keepends=True),
28+
fromfile="old_code",
29+
tofile="new_code",
30+
)
31+
diff_str = "".join(diff)
32+
self.fail(
33+
f"{module_name} needs new codegen, see "
34+
+ "`python -m libcst.codegen.generate --help` "
35+
+ "for instructions, or run `python -m libcst.codegen.generate all`. "
36+
+ f"Diff:\n{diff_str}"
37+
)
2938

3039
def test_codegen_clean_visitor_functions(self) -> None:
3140
"""
@@ -123,3 +132,50 @@ def test_codegen_clean_return_types(self) -> None:
123132

124133
# Now that we've done simple codegen, verify that it matches.
125134
self.assert_code_matches(old_code, new_code, "libcst.matchers._return_types")
135+
136+
def test_normalize_unions(self) -> None:
137+
"""
138+
Verifies that NormalizeUnions correctly converts binary operations with |
139+
into Union types, with special handling for Optional cases.
140+
"""
141+
import libcst as cst
142+
from libcst.codegen.gen_matcher_classes import NormalizeUnions
143+
144+
def assert_transforms_to(input_code: str, expected_code: str) -> None:
145+
input_cst = cst.parse_expression(input_code)
146+
expected_cst = cst.parse_expression(expected_code)
147+
148+
result = input_cst.visit(NormalizeUnions())
149+
assert isinstance(
150+
result, cst.BaseExpression
151+
), f"Expected BaseExpression, got {type(result)}"
152+
153+
result_code = cst.Module(body=()).code_for_node(result)
154+
expected_code_str = cst.Module(body=()).code_for_node(expected_cst)
155+
156+
self.assertEqual(
157+
result_code,
158+
expected_code_str,
159+
f"Expected {expected_code_str}, got {result_code}",
160+
)
161+
162+
# Test regular union case
163+
assert_transforms_to("foo | bar | baz", "typing.Union[foo, bar, baz]")
164+
165+
# Test Optional case (None on right)
166+
assert_transforms_to("foo | None", "typing.Optional[foo]")
167+
168+
# Test Optional case (None on left)
169+
assert_transforms_to("None | foo", "typing.Optional[foo]")
170+
171+
# Test case with more than 2 operands including None (should remain Union)
172+
assert_transforms_to("foo | bar | None", "typing.Union[foo, bar, None]")
173+
174+
# Flatten existing Union types
175+
assert_transforms_to(
176+
"typing.Union[foo, typing.Union[bar, baz]]", "typing.Union[foo, bar, baz]"
177+
)
178+
# Merge two kinds of union types
179+
assert_transforms_to(
180+
"foo | typing.Union[bar, baz]", "typing.Union[foo, bar, baz]"
181+
)

0 commit comments

Comments
 (0)