Skip to content

Commit 9fd67bc

Browse files
authored
fix certain matchers breaking under multiprocessing by initializing them late (#1204)
* Add is_property check Skip properties to prevent exceptions * Delayed initialization of matchers To support multiprocessing on Windows/macOS Issue #1181 * Add a test for matcher decorators with multiprocessing
1 parent 6a059be commit 9fd67bc

2 files changed

Lines changed: 83 additions & 43 deletions

File tree

libcst/codemod/tests/test_codemod_cli.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,34 @@ def baz() -> str:
9393
"- 3 warnings were generated.",
9494
output.stderr,
9595
)
96+
97+
def test_matcher_decorators_multiprocessing(self) -> None:
98+
file_count = 5
99+
code = """
100+
def baz(): # type: int
101+
return 5
102+
"""
103+
with tempfile.TemporaryDirectory() as tmpdir:
104+
p = Path(tmpdir)
105+
# Using more than chunksize=4 files to trigger multiprocessing
106+
for i in range(file_count):
107+
(p / f"mod{i}.py").write_text(CodemodTest.make_fixture_data(code))
108+
output = subprocess.run(
109+
[
110+
sys.executable,
111+
"-m",
112+
"libcst.tool",
113+
"codemod",
114+
# Good candidate since it uses matcher decorators
115+
"convert_type_comments.ConvertTypeComments",
116+
str(p),
117+
"--jobs",
118+
str(file_count),
119+
],
120+
encoding="utf-8",
121+
stderr=subprocess.PIPE,
122+
)
123+
self.assertIn(
124+
f"Transformed {file_count} files successfully.",
125+
output.stderr,
126+
)

libcst/matchers/_visitors.py

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ class UnionType:
6060
}
6161

6262

63+
def is_property(obj: object, attr_name: str) -> bool:
64+
"""Check if obj.attr is a property without evaluating it."""
65+
return isinstance(getattr(type(obj), attr_name, None), property)
66+
67+
6368
# pyre-ignore We don't care about Any here, its not exposed.
6469
def _match_decorator_unpickler(kwargs: Any) -> "MatchDecoratorMismatch":
6570
return MatchDecoratorMismatch(**kwargs)
@@ -265,20 +270,22 @@ def _check_types(
265270
)
266271

267272

268-
def _gather_matchers(obj: object) -> Set[BaseMatcherNode]:
269-
visit_matchers: Set[BaseMatcherNode] = set()
273+
def _gather_matchers(obj: object) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]:
274+
"""
275+
Set of gating matchers that we need to track and evaluate. We use these
276+
in conjunction with the call_if_inside and call_if_not_inside decorators
277+
to determine whether to call a visit/leave function.
278+
"""
270279

271-
for func in dir(obj):
272-
try:
273-
for matcher in getattr(getattr(obj, func), VISIT_POSITIVE_MATCHER_ATTR, []):
274-
visit_matchers.add(cast(BaseMatcherNode, matcher))
275-
for matcher in getattr(getattr(obj, func), VISIT_NEGATIVE_MATCHER_ATTR, []):
276-
visit_matchers.add(cast(BaseMatcherNode, matcher))
277-
except Exception:
278-
# This could be a caculated property, and calling getattr() evaluates it.
279-
# We have no control over the implementation detail, so if it raises, we
280-
# should not crash.
281-
pass
280+
visit_matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {}
281+
282+
for attr_name in dir(obj):
283+
if not is_property(obj, attr_name):
284+
func = getattr(obj, attr_name)
285+
for matcher in getattr(func, VISIT_POSITIVE_MATCHER_ATTR, []):
286+
visit_matchers[cast(BaseMatcherNode, matcher)] = None
287+
for matcher in getattr(func, VISIT_NEGATIVE_MATCHER_ATTR, []):
288+
visit_matchers[cast(BaseMatcherNode, matcher)] = None
282289

283290
return visit_matchers
284291

@@ -302,16 +309,12 @@ def _gather_constructed_visit_funcs(
302309
] = {}
303310

304311
for funcname in dir(obj):
305-
try:
306-
possible_func = getattr(obj, funcname)
307-
if not ismethod(possible_func):
308-
continue
309-
func = cast(Callable[[cst.CSTNode], None], possible_func)
310-
except Exception:
311-
# This could be a caculated property, and calling getattr() evaluates it.
312-
# We have no control over the implementation detail, so if it raises, we
313-
# should not crash.
312+
if is_property(obj, funcname):
313+
continue
314+
possible_func = getattr(obj, funcname)
315+
if not ismethod(possible_func):
314316
continue
317+
func = cast(Callable[[cst.CSTNode], None], possible_func)
315318
matchers = getattr(func, CONSTRUCTED_VISIT_MATCHER_ATTR, [])
316319
if matchers:
317320
# Make sure that we aren't accidentally putting a @visit on a visit_Node.
@@ -337,16 +340,12 @@ def _gather_constructed_leave_funcs(
337340
] = {}
338341

339342
for funcname in dir(obj):
340-
try:
341-
possible_func = getattr(obj, funcname)
342-
if not ismethod(possible_func):
343-
continue
344-
func = cast(Callable[[cst.CSTNode], None], possible_func)
345-
except Exception:
346-
# This could be a caculated property, and calling getattr() evaluates it.
347-
# We have no control over the implementation detail, so if it raises, we
348-
# should not crash.
343+
if is_property(obj, funcname):
344+
continue
345+
possible_func = getattr(obj, funcname)
346+
if not ismethod(possible_func):
349347
continue
348+
func = cast(Callable[[cst.CSTNode], None], possible_func)
350349
matchers = getattr(func, CONSTRUCTED_LEAVE_MATCHER_ATTR, [])
351350
if matchers:
352351
# Make sure that we aren't accidentally putting a @leave on a leave_Node.
@@ -448,12 +447,7 @@ class MatcherDecoratableTransformer(CSTTransformer):
448447

449448
def __init__(self) -> None:
450449
CSTTransformer.__init__(self)
451-
# List of gating matchers that we need to track and evaluate. We use these
452-
# in conjuction with the call_if_inside and call_if_not_inside decorators
453-
# to determine whether or not to call a visit/leave function.
454-
self._matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {
455-
m: None for m in _gather_matchers(self)
456-
}
450+
self.__matchers: Optional[Dict[BaseMatcherNode, Optional[cst.CSTNode]]] = None
457451
# Mapping of matchers to functions. If in the course of visiting the tree,
458452
# a node matches one of these matchers, the corresponding function will be
459453
# called as if it was a visit_* method.
@@ -486,6 +480,16 @@ def __init__(self) -> None:
486480
expected_none_return=False,
487481
)
488482

483+
@property
484+
def _matchers(self) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]:
485+
if self.__matchers is None:
486+
self.__matchers = _gather_matchers(self)
487+
return self.__matchers
488+
489+
@_matchers.setter
490+
def _matchers(self, value: Dict[BaseMatcherNode, Optional[cst.CSTNode]]) -> None:
491+
self.__matchers = value
492+
489493
def on_visit(self, node: cst.CSTNode) -> bool:
490494
# First, evaluate any matchers that we have which we are not inside already.
491495
self._matchers = _visit_matchers(self._matchers, node, self)
@@ -660,12 +664,7 @@ class MatcherDecoratableVisitor(CSTVisitor):
660664

661665
def __init__(self) -> None:
662666
CSTVisitor.__init__(self)
663-
# List of gating matchers that we need to track and evaluate. We use these
664-
# in conjuction with the call_if_inside and call_if_not_inside decorators
665-
# to determine whether or not to call a visit/leave function.
666-
self._matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {
667-
m: None for m in _gather_matchers(self)
668-
}
667+
self.__matchers: Optional[Dict[BaseMatcherNode, Optional[cst.CSTNode]]] = None
669668
# Mapping of matchers to functions. If in the course of visiting the tree,
670669
# a node matches one of these matchers, the corresponding function will be
671670
# called as if it was a visit_* method.
@@ -693,6 +692,16 @@ def __init__(self) -> None:
693692
expected_none_return=True,
694693
)
695694

695+
@property
696+
def _matchers(self) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]:
697+
if self.__matchers is None:
698+
self.__matchers = _gather_matchers(self)
699+
return self.__matchers
700+
701+
@_matchers.setter
702+
def _matchers(self, value: Dict[BaseMatcherNode, Optional[cst.CSTNode]]) -> None:
703+
self.__matchers = value
704+
696705
def on_visit(self, node: cst.CSTNode) -> bool:
697706
# First, evaluate any matchers that we have which we are not inside already.
698707
self._matchers = _visit_matchers(self._matchers, node, self)

0 commit comments

Comments
 (0)