1616OPTIONAL_RE = r"typing\.Union\[([^,]*?), NoneType]"
1717
1818
19+ class NormalizeUnions (cst .CSTTransformer ):
20+ """
21+ Convert a binary operation with | operators into a Union type.
22+ For example, converts `foo | bar | baz` into `typing.Union[foo, bar, baz]`.
23+ Special case: converts `foo | None` or `None | foo` into `typing.Optional[foo]`.
24+ Also flattens nested typing.Union types.
25+ """
26+
27+ def leave_Subscript (
28+ self , original_node : cst .Subscript , updated_node : cst .Subscript
29+ ) -> cst .Subscript :
30+ # Check if this is a typing.Union
31+ if (
32+ isinstance (updated_node .value , cst .Attribute )
33+ and isinstance (updated_node .value .value , cst .Name )
34+ and updated_node .value .attr .value == "Union"
35+ and updated_node .value .value .value == "typing"
36+ ):
37+ # Collect all operands from any nested Unions
38+ operands : List [cst .BaseExpression ] = []
39+ for slc in updated_node .slice :
40+ if not isinstance (slc .slice , cst .Index ):
41+ continue
42+ value = slc .slice .value
43+ # If this is a nested Union, add its elements
44+ if (
45+ isinstance (value , cst .Subscript )
46+ and isinstance (value .value , cst .Attribute )
47+ and isinstance (value .value .value , cst .Name )
48+ and value .value .attr .value == "Union"
49+ and value .value .value .value == "typing"
50+ ):
51+ operands .extend (
52+ nested_slc .slice .value
53+ for nested_slc in value .slice
54+ if isinstance (nested_slc .slice , cst .Index )
55+ )
56+ else :
57+ operands .append (value )
58+
59+ # flatten operands into a Union type
60+ return cst .Subscript (
61+ cst .Attribute (cst .Name ("typing" ), cst .Name ("Union" )),
62+ [cst .SubscriptElement (cst .Index (operand )) for operand in operands ],
63+ )
64+ return updated_node
65+
66+ def leave_BinaryOperation (
67+ self , original_node : cst .BinaryOperation , updated_node : cst .BinaryOperation
68+ ) -> Union [cst .BinaryOperation , cst .Subscript ]:
69+ if not updated_node .operator .deep_equals (cst .BitOr ()):
70+ return updated_node
71+
72+ def flatten_binary_op (node : cst .BaseExpression ) -> List [cst .BaseExpression ]:
73+ """Flatten a binary operation tree into a list of operands."""
74+ if not isinstance (node , cst .BinaryOperation ):
75+ # If it's a Union type, extract its elements
76+ if (
77+ isinstance (node , cst .Subscript )
78+ and isinstance (node .value , cst .Attribute )
79+ and isinstance (node .value .value , cst .Name )
80+ and node .value .attr .value == "Union"
81+ and node .value .value .value == "typing"
82+ ):
83+ return [
84+ slc .slice .value
85+ for slc in node .slice
86+ if isinstance (slc .slice , cst .Index )
87+ ]
88+ return [node ]
89+ if not node .operator .deep_equals (cst .BitOr ()):
90+ return [node ]
91+
92+ left_operands = flatten_binary_op (node .left )
93+ right_operands = flatten_binary_op (node .right )
94+ return left_operands + right_operands
95+
96+ # Flatten the binary operation tree into a list of operands
97+ operands = flatten_binary_op (updated_node )
98+
99+ # Check for Optional case (None in union)
100+ none_count = sum (
101+ 1 for op in operands if isinstance (op , cst .Name ) and op .value == "None"
102+ )
103+ if none_count == 1 and len (operands ) == 2 :
104+ # This is an Optional case - find the non-None operand
105+ non_none = next (
106+ op
107+ for op in operands
108+ if not (isinstance (op , cst .Name ) and op .value == "None" )
109+ )
110+ return cst .Subscript (
111+ cst .Attribute (cst .Name ("typing" ), cst .Name ("Optional" )),
112+ [cst .SubscriptElement (cst .Index (non_none ))],
113+ )
114+
115+ # Regular Union case
116+ return cst .Subscript (
117+ cst .Attribute (cst .Name ("typing" ), cst .Name ("Union" )),
118+ [cst .SubscriptElement (cst .Index (operand )) for operand in operands ],
119+ )
120+
121+
19122class CleanseFullTypeNames (cst .CSTTransformer ):
20123 def leave_Call (
21124 self , original_node : cst .Call , updated_node : cst .Call
@@ -357,7 +460,9 @@ def _get_clean_type_from_subscript(
357460 elif isinstance (inner_type , (cst .Name , cst .SimpleString )):
358461 clean_inner_type = _get_clean_type_from_expression (aliases , inner_type )
359462 else :
360- raise Exception ("Logic error, unexpected type in Sequence!" )
463+ raise Exception (
464+ f"Logic error, unexpected type in Sequence: { type (inner_type )} !"
465+ )
361466
362467 return _get_wrapped_union_type (
363468 typecst .deep_replace (inner_type , clean_inner_type ),
@@ -386,9 +491,12 @@ def _get_clean_type_and_aliases(
386491 typestr = re .sub (OPTIONAL_RE , r"typing.Optional[\1]" , typestr )
387492
388493 # Now, parse the expression with LibCST.
389- cleanser = CleanseFullTypeNames ()
494+
390495 typecst = parse_expression (typestr )
391- typecst = typecst .visit (cleanser )
496+ typecst = typecst .visit (NormalizeUnions ())
497+ assert isinstance (typecst , cst .BaseExpression )
498+ typecst = typecst .visit (CleanseFullTypeNames ())
499+ assert isinstance (typecst , cst .BaseExpression )
392500 aliases : List [Alias ] = []
393501
394502 # Now, convert the type to allow for MetadataMatchType and MatchIfTrue values.
@@ -397,7 +505,7 @@ def _get_clean_type_and_aliases(
397505 elif isinstance (typecst , (cst .Name , cst .SimpleString )):
398506 clean_type = _get_clean_type_from_expression (aliases , typecst )
399507 else :
400- raise Exception ("Logic error, unexpected top level type!" )
508+ raise Exception (f "Logic error, unexpected top level type: { type ( typecst ) } !" )
401509
402510 # Now, insert OneOf/AllOf and MatchIfTrue into unions so we can typecheck their usage.
403511 # This allows us to put OneOf[SomeType] or MatchIfTrue[cst.SomeType] into any
0 commit comments