1+ import code_tokenize as ct
2+
3+ from collections import defaultdict
4+
5+ # AST Node ----------------------------------------------------------------
6+
7+
8+ class ASTNode (object ):
9+
10+ def __init__ (self , type , text = None , parent = None , children = None ):
11+
12+ # Basic node attributes
13+ self .type = type
14+ self .children = children if children is not None else []
15+ self .parent = parent
16+ self .text = text # If text is not None, then leaf node
17+
18+ # Tree based attributes
19+ self .subtree_hash = None
20+ self .subtree_height = 0
21+ self .subtree_weight = 1
22+
23+ def isomorph (self , other ):
24+ return ((self .subtree_hash , self .type , self .subtree_height , self .subtree_weight ) ==
25+ (other .subtree_hash , other .type , other .subtree_height , other .subtree_weight ))
26+
27+
28+ def sexp (self ):
29+ name = self .text if self .text is not None else self .type
30+
31+ child_sexp = []
32+ for child in self .children :
33+ text = child .sexp ()
34+ text = [" " + t for t in text .splitlines ()]
35+ child_sexp .append ("\n " .join (text ))
36+
37+ if len (child_sexp ) == 0 :
38+ return name
39+
40+ return "%s {\n %s\n }" % (name , " " .join (child_sexp ))
41+
42+ def __repr__ (self ):
43+ attrs = {"type" : self .type , "text" : self .text }
44+ return "ASTNode(%s)" % (", " .join (["%s=%s" % (k , v ) for k , v in attrs .items () if v is not None ]))
45+
46+
47+ def default_create_node (type , children , text = None ):
48+ new_node = ASTNode (type , text = text , children = children )
49+
50+ # Subtree metrics
51+ height = 0
52+ weight = 1
53+ hash_str = []
54+
55+ for child in children :
56+ child .parent = new_node # Set parent relation
57+ height = max (child .subtree_height + 1 , height )
58+ weight += child .subtree_weight
59+ hash_str .append (str (child .subtree_hash ))
60+
61+ new_node .subtree_height = height
62+ new_node .subtree_weight = weight
63+
64+ # WL hash subtree representation
65+ base_str = new_node .type if new_node .text is None else new_node .text
66+ hash_str .insert (0 , base_str )
67+ hash_str = "_" .join (hash_str )
68+ new_node .subtree_hash = hash (hash_str )
69+
70+ return new_node
71+
72+
73+ def _node_key (node ):
74+ return (node .type , node .start_point , node .end_point )
75+
76+
77+ class TokensToAST :
78+
79+ def __init__ (self , create_node_fn ):
80+ self .create_node_fn = create_node_fn
81+
82+ self .root_node = None
83+ self .waitlist = []
84+ self .node_index = {}
85+ self .child_count = defaultdict (int )
86+
87+ def _create_node (self , ast_node , text = None ):
88+ node_key = _node_key (ast_node )
89+ children = [self .node_index [_node_key (c )] for c in ast_node .children
90+ if _node_key (c ) in self .node_index ]
91+
92+ current_node = self .create_node_fn (ast_node .type , children , text = text )
93+ self .node_index [node_key ] = current_node
94+
95+ # Add parent if ready
96+ if ast_node .parent :
97+ parent_ast = ast_node .parent
98+ parent_key = _node_key (parent_ast )
99+ self .child_count [parent_key ] += 1
100+
101+ if len (parent_ast .children ) == self .child_count [parent_key ]:
102+ self .waitlist .append (parent_ast )
103+
104+ else :
105+ self .root_node = current_node
106+
107+
108+ def __call__ (self , tokens ):
109+
110+ token_nodes = ((t .text , t .ast_node ) for t in tokens if hasattr (t , "ast_node" ))
111+ for token_text , token_ast in token_nodes :
112+ self ._create_node (token_ast , text = token_text )
113+
114+ while len (self .waitlist ) > 0 :
115+ self ._create_node (self .waitlist .pop (0 ))
116+
117+ return self .root_node
118+
119+
120+
121+ # Interface ----------------------------------------------------------------
122+
123+ def parse_ast (source_code , lang = "guess" , ** kwargs ):
124+
125+ # Parse AST
126+ kwargs ["lang" ] = lang
127+ kwargs ["syntax_error" ] = "ignore"
128+
129+ ast_tokens = ct .tokenize (source_code , ** kwargs )
130+
131+ return TokensToAST (default_create_node )(ast_tokens )
0 commit comments