Skip to content

Commit ee48e14

Browse files
committed
AST based code diff
1 parent 27cbd2d commit ee48e14

3 files changed

Lines changed: 270 additions & 0 deletions

File tree

code_diff/__init__.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from .ast import parse_ast
2+
3+
4+
# Main method --------------------------------------------------------
5+
6+
def difference(source, target, lang="guess", **kwargs):
    """Build an ``ASTDiff`` describing where ``source`` and ``target`` differ.

    Both inputs are parsed with ``parse_ast`` using the same ``lang`` and the
    same extra keyword arguments. The two resulting trees are handed to
    ``diff_search`` and the pair of nodes it returns is wrapped in an
    ``ASTDiff``.
    """
    parsed_source = parse_ast(source, lang=lang, **kwargs)
    parsed_target = parse_ast(target, lang=lang, **kwargs)

    # Narrow both trees down to the node pair that carries the change.
    diff_pair = diff_search(parsed_source, parsed_target)
    return ASTDiff(*diff_pair)
15+
16+
17+
# Diff Search --------------------------------------------------------
18+
# Run BFS until we find a node with at least two diffs
19+
20+
def diff_search(source_ast, target_ast):
    """Descend to the deepest aligned node pair that still contains the diff.

    Starting at the two roots, the search keeps following the single
    differing child pair for as long as exactly one aligned child pair is
    not isomorphic.

    Args:
        source_ast: Root node of the source tree, or None.
        target_ast: Root node of the target tree, or None.

    Returns:
        ``(None, None)`` when either tree is missing or the roots are
        isomorphic. Otherwise ``(source_node, target_node)``: the first pair
        on the way down whose child counts differ, or whose aligned children
        contain zero or more than one differing pairs.
    """
    # A missing tree on either side leaves nothing to descend into.
    if source_ast is None or target_ast is None:
        return None, None
    if source_ast.isomorph(target_ast):
        return None, None

    source_node, target_node = source_ast, target_ast
    while True:
        # Children can only be aligned position by position if counts match.
        if len(source_node.children) != len(target_node.children):
            return source_node, target_node

        differing = [
            (source_child, target_child)
            for source_child, target_child
            in zip(source_node.children, target_node.children)
            if not source_child.isomorph(target_child)
        ]

        # Zero differing children: the change is in this node itself.
        # Two or more: this node is the smallest common container.
        if len(differing) != 1:
            return source_node, target_node

        source_node, target_node = differing[0]
41+
42+
43+
44+
45+
# AST Difference --------------------------------------------------------
46+
47+
class ASTDiff:
    """Holds the source-side and target-side nodes of a difference."""

    def __init__(self, source_ast, target_ast):
        # Both sides are stored unchanged; either may be None.
        self.source_ast, self.target_ast = source_ast, target_ast

code_diff/ast.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import code_tokenize as ct
2+
3+
from collections import defaultdict, deque
4+
5+
# AST Node ----------------------------------------------------------------
6+
7+
8+
class ASTNode(object):
    """Tree node carrying a type, optional text and cached subtree metrics."""

    def __init__(self, type, text = None, parent = None, children = None):

        # Basic node attributes
        self.type = type
        self.children = children if children is not None else []
        self.parent = parent
        self.text = text  # If text is not None, then leaf node

        # Tree based attributes; the defaults describe a single-node subtree
        self.subtree_hash = None
        self.subtree_height = 0
        self.subtree_weight = 1

    def isomorph(self, other):
        """Return True if ``other`` has the same hash, type, height and weight.

        ``None`` is never isomorphic to a node, so callers may pass the
        result of a failed parse without guarding first.
        """
        if other is None:
            return False

        own_signature = (self.subtree_hash, self.type,
                         self.subtree_height, self.subtree_weight)
        other_signature = (other.subtree_hash, other.type,
                           other.subtree_height, other.subtree_weight)
        return own_signature == other_signature

    def sexp(self):
        """Render the subtree as a nested ``name { ... }`` string.

        A node is named by its text when it has one, otherwise by its type.
        A node without children renders as just its name.
        """
        name = self.text if self.text is not None else self.type

        # Prefix every line of each child's rendering with one space.
        child_sexp = [
            "\n".join(" " + line for line in child.sexp().splitlines())
            for child in self.children
        ]

        if not child_sexp:
            return name

        return "%s {\n%s\n}" % (name, " ".join(child_sexp))

    def __repr__(self):
        # Only attributes that are set (not None) are shown.
        attrs = {"type": self.type, "text": self.text}
        shown = ", ".join("%s=%s" % (k, v) for k, v in attrs.items() if v is not None)
        return "ASTNode(%s)" % shown
45+
46+
47+
def default_create_node(type, children, text = None):
    """Create an ``ASTNode`` over ``children`` and fill in its subtree metrics.

    Every child gets the new node as its parent. The height is one more than
    the tallest child (0 without children), the weight is 1 plus the weights
    of all children, and the hash is Python's ``hash`` of the node label
    (its text if present, otherwise its type) joined by "_" with the
    children's hashes in order.
    """
    node = ASTNode(type, text = text, children = children)

    # WL hash subtree representation: own label first, then child hashes
    label_parts = [node.type if node.text is None else node.text]
    tallest = 0
    total_weight = 1

    for child in children:
        child.parent = node  # Set parent relation
        tallest = max(child.subtree_height + 1, tallest)
        total_weight += child.subtree_weight
        label_parts.append(str(child.subtree_hash))

    node.subtree_height = tallest
    node.subtree_weight = total_weight
    node.subtree_hash = hash("_".join(label_parts))

    return node
71+
72+
73+
def _node_key(node):
74+
return (node.type, node.start_point, node.end_point)
75+
76+
77+
class TokensToAST:
    """Rebuild a tree bottom-up from a stream of tokens.

    Tokens are expected to expose ``text`` and ``ast_node``; tokens without
    an ``ast_node`` attribute are skipped. Nodes are created by calling
    ``create_node_fn(type, children, text=...)``.
    """

    def __init__(self, create_node_fn):
        self.create_node_fn = create_node_fn

        self.root_node = None                # most recently created parentless node
        # deque: nodes are taken from the front, which is O(n) on a list
        self.waitlist = deque()              # parser nodes whose children are all created
        self.node_index = {}                 # _node_key(parser node) -> created node
        self.child_count = defaultdict(int)  # parent key -> children created so far

    def _create_node(self, ast_node, text = None):
        """Create the node for ``ast_node`` and queue its parent once complete."""
        node_key = _node_key(ast_node)

        # Only children that have already been created are attached.
        children = [self.node_index[_node_key(c)] for c in ast_node.children
                    if _node_key(c) in self.node_index]

        current_node = self.create_node_fn(ast_node.type, children, text = text)
        self.node_index[node_key] = current_node

        # Add parent if ready
        if ast_node.parent:
            parent_ast = ast_node.parent
            parent_key = _node_key(parent_ast)
            self.child_count[parent_key] += 1

            # The parent is ready once every one of its children was created.
            if len(parent_ast.children) == self.child_count[parent_key]:
                self.waitlist.append(parent_ast)

        else:
            self.root_node = current_node

    def __call__(self, tokens):
        """Consume ``tokens`` and return the root node (None if none was built)."""
        token_nodes = ((t.text, t.ast_node) for t in tokens if hasattr(t, "ast_node"))
        for token_text, token_ast in token_nodes:
            self._create_node(token_ast, text = token_text)

        # Inner nodes are created in the order in which they became ready.
        while self.waitlist:
            self._create_node(self.waitlist.popleft())

        return self.root_node
118+
119+
120+
121+
# Interface ----------------------------------------------------------------
122+
123+
def parse_ast(source_code, lang = "guess", **kwargs):
    """Tokenize ``source_code`` and assemble the tokens into a tree.

    ``lang`` is forwarded to the tokenizer and ``syntax_error`` is always set
    to "ignore", overriding any value supplied through ``kwargs``. The tree
    is built by ``TokensToAST`` with ``default_create_node`` and its root is
    returned.
    """
    # Parse AST
    tokenizer_options = dict(kwargs, lang = lang, syntax_error = "ignore")
    tokens = ct.tokenize(source_code, **tokenizer_options)

    build_tree = TokensToAST(default_create_node)
    return build_tree(tokens)

code_diff/diff_utils.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import re
2+
3+
4+
# Diff parsing -----------------------------------------------------------------
5+
6+
class Hunk:
    """Lines of one diff hunk plus the indices of its added and removed lines."""

    def __init__(self, lines, added_lines, rm_lines):
        self.lines = lines
        # Sets give constant-time membership tests while rendering.
        self.added_lines = set(added_lines)
        self.rm_lines = set(rm_lines)

    def _render(self, dropped, unmarked):
        """Join the hunk lines for one side of the diff.

        Lines whose index is in ``dropped`` are left out; lines whose index
        is in ``unmarked`` get their first character replaced by a space.
        """
        return "".join(
            " " + line[1:] if index in unmarked else line
            for index, line in enumerate(self.lines)
            if index not in dropped
        )

    @property
    def after(self):
        """Hunk text with removed lines dropped and added lines unmarked."""
        return self._render(self.rm_lines, self.added_lines)

    @property
    def before(self):
        """Hunk text with added lines dropped and removed lines unmarked."""
        return self._render(self.added_lines, self.rm_lines)

    def __repr__(self):
        return "".join(self.lines)
45+
46+
47+
def _parse_hunk(lines, start, end):
    """Build a ``Hunk`` from ``lines[start + 1:end]``.

    ``start`` is the index of the hunk header, which is not part of the hunk
    body. Indices recorded as added ("+" prefix) or removed ("-" prefix) are
    relative to the body, not to ``lines``.
    """
    body = lines[start + 1:end]

    plus_ix = [ix for ix, text in enumerate(body) if text.startswith("+")]
    minus_ix = [ix for ix, text in enumerate(body) if text.startswith("-")]

    return Hunk(body, plus_ix, minus_ix)
59+
60+
61+
# Matches a unified-diff hunk header such as "@@ -12,7 +12,9 @@ context".
# Raw string: "\d" and "\+" are regex escapes, not string escapes.
hunk_pat = re.compile(r"@@ -(\d+)(,\d+)? \+(\d+)(,\d+)? @@.*")
62+
63+
def parse_hunks(diff):
    """Split diff text into its hunks.

    Every line matching ``hunk_pat`` (an ``@@ -a,b +c,d @@`` header) starts a
    new hunk. A hunk's body runs from the line after its header up to, but
    not including, the next header or the end of the text. Text before the
    first header is ignored.

    Args:
        diff: The diff as a single string.

    Returns:
        A list of ``Hunk`` objects in order of appearance. A header that is
        immediately followed by another header yields no hunk; the last
        header always yields one, even if nothing follows it.
    """
    lines = diff.splitlines(True)

    hunks = []
    start_ix = -1  # index of the current hunk header; -1 until one is seen

    for line_ix, line in enumerate(lines):
        if not hunk_pat.match(line):
            continue

        # _parse_hunk treats its end index as exclusive, so the next header's
        # own index is the right bound. Using line_ix - 1 here would drop the
        # last body line of every hunk that is followed by another hunk.
        if start_ix >= 0 and start_ix + 1 < line_ix:
            hunks.append(_parse_hunk(lines, start_ix, line_ix))

        start_ix = line_ix

    # The final hunk extends to the end of the text.
    if start_ix >= 0:
        hunks.append(_parse_hunk(lines, start_ix, len(lines)))

    return hunks
88+

0 commit comments

Comments
 (0)