Skip to content

Commit 52acdf4

Browse files
authored
cli: Instantiate Codemods per file (#1334)
Instead of sharing instances of a Codemod across many files, this PR allows passing a Codemod class to `parallel_exec_transform_with_prettyprint`, which then instantiates the Codemod for each file. `tool._codemod_impl` now uses this API. The old behavior is deprecated because sharing codemod instances across files is surprising and causes hard-to-diagnose bugs when a Codemod keeps track of its state via instance variables.
1 parent d002c14 commit 52acdf4

3 files changed

Lines changed: 192 additions & 169 deletions

File tree

libcst/codemod/_cli.py

Lines changed: 150 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
import traceback
1717
from concurrent.futures import as_completed, Executor, ProcessPoolExecutor
1818
from copy import deepcopy
19-
from dataclasses import dataclass, replace
19+
from dataclasses import dataclass
2020
from multiprocessing import cpu_count
2121
from pathlib import Path
22-
from typing import Any, AnyStr, cast, Dict, List, Optional, Sequence, Union
22+
from typing import AnyStr, cast, Dict, List, Optional, Sequence, Type, Union
23+
from warnings import warn
2324

2425
from libcst import parse_module, PartialParserConfig
2526
from libcst.codemod._codemod import Codemod
@@ -213,12 +214,52 @@ class ExecutionConfig:
213214
unified_diff: Optional[int] = None
214215

215216

216-
def _execute_transform( # noqa: C901
217-
transformer: Codemod,
217+
def _prepare_context(
218+
repo_root: str,
218219
filename: str,
219-
config: ExecutionConfig,
220220
scratch: Dict[str, object],
221-
) -> ExecutionResult:
221+
repo_manager: Optional[FullRepoManager],
222+
) -> CodemodContext:
223+
# determine the module and package name for this file
224+
try:
225+
module_name_and_package = calculate_module_and_package(repo_root, filename)
226+
mod_name = module_name_and_package.name
227+
pkg_name = module_name_and_package.package
228+
except ValueError as ex:
229+
print(f"Failed to determine module name for {filename}: {ex}", file=sys.stderr)
230+
mod_name = None
231+
pkg_name = None
232+
return CodemodContext(
233+
scratch=scratch,
234+
filename=filename,
235+
full_module_name=mod_name,
236+
full_package_name=pkg_name,
237+
metadata_manager=repo_manager,
238+
)
239+
240+
241+
def _instantiate_transformer(
242+
transformer: Union[Codemod, Type[Codemod]],
243+
repo_root: str,
244+
filename: str,
245+
original_scratch: Dict[str, object],
246+
codemod_kwargs: Dict[str, object],
247+
repo_manager: Optional[FullRepoManager],
248+
) -> Codemod:
249+
if isinstance(transformer, type):
250+
return transformer( # type: ignore
251+
context=_prepare_context(repo_root, filename, {}, repo_manager),
252+
**codemod_kwargs,
253+
)
254+
transformer.context = _prepare_context(
255+
repo_root, filename, deepcopy(original_scratch), repo_manager
256+
)
257+
return transformer
258+
259+
260+
def _check_for_skip(
261+
filename: str, config: ExecutionConfig
262+
) -> Union[ExecutionResult, bytes]:
222263
for pattern in config.blacklist_patterns:
223264
if re.fullmatch(pattern, filename):
224265
return ExecutionResult(
@@ -230,45 +271,46 @@ def _execute_transform( # noqa: C901
230271
),
231272
)
232273

233-
try:
234-
with open(filename, "rb") as fp:
235-
oldcode = fp.read()
274+
with open(filename, "rb") as fp:
275+
oldcode = fp.read()
236276

237-
# Skip generated files
238-
if (
239-
not config.include_generated
240-
and config.generated_code_marker.encode("utf-8") in oldcode
241-
):
242-
return ExecutionResult(
243-
filename=filename,
244-
changed=False,
245-
transform_result=TransformSkip(
246-
skip_reason=SkipReason.GENERATED,
247-
skip_description="Generated file.",
248-
),
249-
)
277+
# Skip generated files
278+
if (
279+
not config.include_generated
280+
and config.generated_code_marker.encode("utf-8") in oldcode
281+
):
282+
return ExecutionResult(
283+
filename=filename,
284+
changed=False,
285+
transform_result=TransformSkip(
286+
skip_reason=SkipReason.GENERATED,
287+
skip_description="Generated file.",
288+
),
289+
)
290+
return oldcode
250291

251-
# determine the module and package name for this file
252-
try:
253-
module_name_and_package = calculate_module_and_package(
254-
config.repo_root or ".", filename
255-
)
256-
mod_name = module_name_and_package.name
257-
pkg_name = module_name_and_package.package
258-
except ValueError as ex:
259-
print(
260-
f"Failed to determine module name for {filename}: {ex}", file=sys.stderr
261-
)
262-
mod_name = None
263-
pkg_name = None
264292

265-
# Apart from metadata_manager, every field of context should be reset per file
266-
transformer.context = CodemodContext(
267-
scratch=deepcopy(scratch),
268-
filename=filename,
269-
full_module_name=mod_name,
270-
full_package_name=pkg_name,
271-
metadata_manager=transformer.context.metadata_manager,
293+
def _execute_transform(
294+
transformer: Union[Codemod, Type[Codemod]],
295+
filename: str,
296+
config: ExecutionConfig,
297+
original_scratch: Dict[str, object],
298+
codemod_args: Optional[Dict[str, object]],
299+
repo_manager: Optional[FullRepoManager],
300+
) -> ExecutionResult:
301+
warnings: list[str] = []
302+
try:
303+
oldcode = _check_for_skip(filename, config)
304+
if isinstance(oldcode, ExecutionResult):
305+
return oldcode
306+
307+
transformer_instance = _instantiate_transformer(
308+
transformer,
309+
config.repo_root or ".",
310+
filename,
311+
original_scratch,
312+
codemod_args or {},
313+
repo_manager,
272314
)
273315

274316
# Run the transform, bail if we failed or if we aren't formatting code
@@ -281,55 +323,26 @@ def _execute_transform( # noqa: C901
281323
else PartialParserConfig()
282324
),
283325
)
284-
output_tree = transformer.transform_module(input_tree)
326+
output_tree = transformer_instance.transform_module(input_tree)
285327
newcode = output_tree.bytes
286328
encoding = output_tree.encoding
287-
except KeyboardInterrupt:
288-
return ExecutionResult(
289-
filename=filename, changed=False, transform_result=TransformExit()
290-
)
329+
warnings.extend(transformer_instance.context.warnings)
291330
except SkipFile as ex:
331+
warnings.extend(transformer_instance.context.warnings)
292332
return ExecutionResult(
293333
filename=filename,
294334
changed=False,
295335
transform_result=TransformSkip(
296336
skip_reason=SkipReason.OTHER,
297337
skip_description=str(ex),
298-
warning_messages=transformer.context.warnings,
299-
),
300-
)
301-
except Exception as ex:
302-
return ExecutionResult(
303-
filename=filename,
304-
changed=False,
305-
transform_result=TransformFailure(
306-
error=ex,
307-
traceback_str=traceback.format_exc(),
308-
warning_messages=transformer.context.warnings,
338+
warning_messages=warnings,
309339
),
310340
)
311341

312342
# Call formatter if needed, but only if we actually changed something in this
313343
# file
314344
if config.format_code and newcode != oldcode:
315-
try:
316-
newcode = invoke_formatter(config.formatter_args, newcode)
317-
except KeyboardInterrupt:
318-
return ExecutionResult(
319-
filename=filename,
320-
changed=False,
321-
transform_result=TransformExit(),
322-
)
323-
except Exception as ex:
324-
return ExecutionResult(
325-
filename=filename,
326-
changed=False,
327-
transform_result=TransformFailure(
328-
error=ex,
329-
traceback_str=traceback.format_exc(),
330-
warning_messages=transformer.context.warnings,
331-
),
332-
)
345+
newcode = invoke_formatter(config.formatter_args, newcode)
333346

334347
# Format as unified diff if needed, otherwise save it back
335348
changed = oldcode != newcode
@@ -352,13 +365,14 @@ def _execute_transform( # noqa: C901
352365
return ExecutionResult(
353366
filename=filename,
354367
changed=changed,
355-
transform_result=TransformSuccess(
356-
warning_messages=transformer.context.warnings, code=newcode
357-
),
368+
transform_result=TransformSuccess(warning_messages=warnings, code=newcode),
358369
)
370+
359371
except KeyboardInterrupt:
360372
return ExecutionResult(
361-
filename=filename, changed=False, transform_result=TransformExit()
373+
filename=filename,
374+
changed=False,
375+
transform_result=TransformExit(warning_messages=warnings),
362376
)
363377
except Exception as ex:
364378
return ExecutionResult(
@@ -367,7 +381,7 @@ def _execute_transform( # noqa: C901
367381
transform_result=TransformFailure(
368382
error=ex,
369383
traceback_str=traceback.format_exc(),
370-
warning_messages=transformer.context.warnings,
384+
warning_messages=warnings,
371385
),
372386
)
373387

@@ -504,15 +518,8 @@ class ParallelTransformResult:
504518
skips: int
505519

506520

507-
# Unfortunate wrapper required since there is no `istarmap_unordered`...
508-
def _execute_transform_wrap(
509-
job: Dict[str, Any],
510-
) -> ExecutionResult:
511-
return _execute_transform(**job)
512-
513-
514521
def parallel_exec_transform_with_prettyprint( # noqa: C901
515-
transform: Codemod,
522+
transform: Union[Codemod, Type[Codemod]],
516523
files: Sequence[str],
517524
*,
518525
jobs: Optional[int] = None,
@@ -528,38 +535,49 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
528535
blacklist_patterns: Sequence[str] = (),
529536
python_version: Optional[str] = None,
530537
repo_root: Optional[str] = None,
538+
codemod_args: Optional[Dict[str, object]] = None,
531539
) -> ParallelTransformResult:
532540
"""
533-
Given a list of files and an instantiated codemod we should apply to them,
534-
fork and apply the codemod in parallel to all of the files, including any
535-
configured formatter. The ``jobs`` parameter controls the maximum number of
536-
in-flight transforms, and needs to be at least 1. If not included, the number
537-
of jobs will automatically be set to the number of CPU cores. If ``unified_diff``
538-
is set to a number, changes to files will be printed to stdout with
539-
``unified_diff`` lines of context. If it is set to ``None`` or left out, files
540-
themselves will be updated with changes and formatting. If a
541-
``python_version`` is provided, then we will parse each source file using
542-
this version. Otherwise, we will use the version of the currently executing python
541+
Given a list of files and a codemod we should apply to them, fork and apply the
542+
codemod in parallel to all of the files, including any configured formatter. The
543+
``jobs`` parameter controls the maximum number of in-flight transforms, and needs to
544+
be at least 1. If not included, the number of jobs will automatically be set to the
545+
number of CPU cores. If ``unified_diff`` is set to a number, changes to files will
546+
be printed to stdout with ``unified_diff`` lines of context. If it is set to
547+
``None`` or left out, files themselves will be updated with changes and formatting.
548+
If a ``python_version`` is provided, then we will parse each source file using this
549+
version. Otherwise, we will use the version of the currently executing python
543550
binary.
544551
545-
A progress indicator as well as any generated warnings will be printed to stderr.
546-
To suppress the interactive progress indicator, set ``hide_progress`` to ``True``.
547-
Files that include the generated code marker will be skipped unless the
548-
``include_generated`` parameter is set to ``True``. Similarly, files that match
549-
a supplied blacklist of regex patterns will be skipped. Warnings for skipping
550-
both blacklisted and generated files will be printed to stderr along with
551-
warnings generated by the codemod unless ``hide_blacklisted`` and
552-
``hide_generated`` are set to ``True``. Files that were successfully codemodded
553-
will not be printed to stderr unless ``show_successes`` is set to ``True``.
554-
555-
To make this API possible, we take an instantiated transform. This is due to
556-
the fact that lambdas are not pickleable and pickling functions is undefined.
557-
This means we're implicitly relying on fork behavior on UNIX-like systems, and
558-
this function will not work on Windows systems. To create a command-line utility
559-
that runs on Windows, please instead see
560-
:func:`~libcst.codemod.exec_transform_with_prettyprint`.
552+
A progress indicator as well as any generated warnings will be printed to stderr. To
553+
suppress the interactive progress indicator, set ``hide_progress`` to ``True``. Files
554+
that include the generated code marker will be skipped unless the
555+
``include_generated`` parameter is set to ``True``. Similarly, files that match a
556+
supplied blacklist of regex patterns will be skipped. Warnings for skipping both
557+
blacklisted and generated files will be printed to stderr along with warnings
558+
generated by the codemod unless ``hide_blacklisted`` and ``hide_generated`` are set
559+
to ``True``. Files that were successfully codemodded will not be printed to stderr
560+
unless ``show_successes`` is set to ``True``.
561+
562+
We take a :class:`~libcst.codemod._codemod.Codemod` class, or an instantiated
563+
:class:`~libcst.codemod._codemod.Codemod`. In the former case, the codemod will be
564+
instantiated for each file, with ``codemod_args`` passed in to the constructor.
565+
Passing an already instantiated :class:`~libcst.codemod._codemod.Codemod` is
566+
deprecated, because it leads to sharing of the
567+
:class:`~libcst.codemod._codemod.Codemod` instance across files, which is a common
568+
source of hard-to-track-down bugs when the :class:`~libcst.codemod._codemod.Codemod`
569+
tracks its state on the instance.
561570
"""
562571

572+
if isinstance(transform, Codemod):
573+
warn(
574+
"Passing transformer instances to `parallel_exec_transform_with_prettyprint` "
575+
"is deprecated and will break in a future version. "
576+
"Please pass the transformer class instead.",
577+
DeprecationWarning,
578+
stacklevel=2,
579+
)
580+
563581
# Ensure that we have no duplicates, otherwise we might get race conditions
564582
# on write.
565583
files = sorted({os.path.abspath(f) for f in files})
@@ -579,6 +597,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
579597
if total == 0:
580598
return ParallelTransformResult(successes=0, failures=0, skips=0, warnings=0)
581599

600+
metadata_manager: Optional[FullRepoManager] = None
582601
if repo_root is not None:
583602
# Make sure if there is a root that we have the absolute path to it.
584603
repo_root = os.path.abspath(repo_root)
@@ -591,10 +610,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
591610
transform.get_inherited_dependencies(),
592611
)
593612
metadata_manager.resolve_cache()
594-
transform.context = replace(
595-
transform.context,
596-
metadata_manager=metadata_manager,
597-
)
613+
598614
print("Executing codemod...", file=sys.stderr)
599615

600616
config = ExecutionConfig(
@@ -630,19 +646,24 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
630646
failures: int = 0
631647
warnings: int = 0
632648
skips: int = 0
649+
original_scratch = (
650+
deepcopy(transform.context.scratch) if isinstance(transform, Codemod) else {}
651+
)
633652

634653
with pool_impl(max_workers=jobs) as executor: # type: ignore
635-
args = [
636-
{
637-
"transformer": transform,
638-
"filename": filename,
639-
"config": config,
640-
"scratch": transform.context.scratch,
641-
}
642-
for filename in files
643-
]
644654
try:
645-
futures = [executor.submit(_execute_transform_wrap, arg) for arg in args]
655+
futures = [
656+
executor.submit(
657+
_execute_transform,
658+
transformer=transform,
659+
filename=filename,
660+
config=config,
661+
original_scratch=original_scratch,
662+
codemod_args=codemod_args,
663+
repo_manager=metadata_manager,
664+
)
665+
for filename in files
666+
]
646667
for future in as_completed(futures):
647668
result = future.result()
648669
# Print an execution result, keep track of failures

0 commit comments

Comments
 (0)