1616import traceback
1717from concurrent .futures import as_completed , Executor , ProcessPoolExecutor
1818from copy import deepcopy
19- from dataclasses import dataclass , replace
19+ from dataclasses import dataclass
2020from multiprocessing import cpu_count
2121from pathlib import Path
22- from typing import Any , AnyStr , cast , Dict , List , Optional , Sequence , Union
22+ from typing import AnyStr , cast , Dict , List , Optional , Sequence , Type , Union
23+ from warnings import warn
2324
2425from libcst import parse_module , PartialParserConfig
2526from libcst .codemod ._codemod import Codemod
@@ -213,12 +214,52 @@ class ExecutionConfig:
213214 unified_diff : Optional [int ] = None
214215
215216
216- def _execute_transform ( # noqa: C901
217- transformer : Codemod ,
217+ def _prepare_context (
218+ repo_root : str ,
218219 filename : str ,
219- config : ExecutionConfig ,
220220 scratch : Dict [str , object ],
221- ) -> ExecutionResult :
221+ repo_manager : Optional [FullRepoManager ],
222+ ) -> CodemodContext :
223+ # determine the module and package name for this file
224+ try :
225+ module_name_and_package = calculate_module_and_package (repo_root , filename )
226+ mod_name = module_name_and_package .name
227+ pkg_name = module_name_and_package .package
228+ except ValueError as ex :
229+ print (f"Failed to determine module name for { filename } : { ex } " , file = sys .stderr )
230+ mod_name = None
231+ pkg_name = None
232+ return CodemodContext (
233+ scratch = scratch ,
234+ filename = filename ,
235+ full_module_name = mod_name ,
236+ full_package_name = pkg_name ,
237+ metadata_manager = repo_manager ,
238+ )
239+
240+
241+ def _instantiate_transformer (
242+ transformer : Union [Codemod , Type [Codemod ]],
243+ repo_root : str ,
244+ filename : str ,
245+ original_scratch : Dict [str , object ],
246+ codemod_kwargs : Dict [str , object ],
247+ repo_manager : Optional [FullRepoManager ],
248+ ) -> Codemod :
249+ if isinstance (transformer , type ):
250+ return transformer ( # type: ignore
251+ context = _prepare_context (repo_root , filename , {}, repo_manager ),
252+ ** codemod_kwargs ,
253+ )
254+ transformer .context = _prepare_context (
255+ repo_root , filename , deepcopy (original_scratch ), repo_manager
256+ )
257+ return transformer
258+
259+
260+ def _check_for_skip (
261+ filename : str , config : ExecutionConfig
262+ ) -> Union [ExecutionResult , bytes ]:
222263 for pattern in config .blacklist_patterns :
223264 if re .fullmatch (pattern , filename ):
224265 return ExecutionResult (
@@ -230,45 +271,46 @@ def _execute_transform( # noqa: C901
230271 ),
231272 )
232273
233- try :
234- with open (filename , "rb" ) as fp :
235- oldcode = fp .read ()
274+ with open (filename , "rb" ) as fp :
275+ oldcode = fp .read ()
236276
237- # Skip generated files
238- if (
239- not config .include_generated
240- and config .generated_code_marker .encode ("utf-8" ) in oldcode
241- ):
242- return ExecutionResult (
243- filename = filename ,
244- changed = False ,
245- transform_result = TransformSkip (
246- skip_reason = SkipReason .GENERATED ,
247- skip_description = "Generated file." ,
248- ),
249- )
277+ # Skip generated files
278+ if (
279+ not config .include_generated
280+ and config .generated_code_marker .encode ("utf-8" ) in oldcode
281+ ):
282+ return ExecutionResult (
283+ filename = filename ,
284+ changed = False ,
285+ transform_result = TransformSkip (
286+ skip_reason = SkipReason .GENERATED ,
287+ skip_description = "Generated file." ,
288+ ),
289+ )
290+ return oldcode
250291
251- # determine the module and package name for this file
252- try :
253- module_name_and_package = calculate_module_and_package (
254- config .repo_root or "." , filename
255- )
256- mod_name = module_name_and_package .name
257- pkg_name = module_name_and_package .package
258- except ValueError as ex :
259- print (
260- f"Failed to determine module name for { filename } : { ex } " , file = sys .stderr
261- )
262- mod_name = None
263- pkg_name = None
264292
265- # Apart from metadata_manager, every field of context should be reset per file
266- transformer .context = CodemodContext (
267- scratch = deepcopy (scratch ),
268- filename = filename ,
269- full_module_name = mod_name ,
270- full_package_name = pkg_name ,
271- metadata_manager = transformer .context .metadata_manager ,
293+ def _execute_transform (
294+ transformer : Union [Codemod , Type [Codemod ]],
295+ filename : str ,
296+ config : ExecutionConfig ,
297+ original_scratch : Dict [str , object ],
298+ codemod_args : Optional [Dict [str , object ]],
299+ repo_manager : Optional [FullRepoManager ],
300+ ) -> ExecutionResult :
301+ warnings : list [str ] = []
302+ try :
303+ oldcode = _check_for_skip (filename , config )
304+ if isinstance (oldcode , ExecutionResult ):
305+ return oldcode
306+
307+ transformer_instance = _instantiate_transformer (
308+ transformer ,
309+ config .repo_root or "." ,
310+ filename ,
311+ original_scratch ,
312+ codemod_args or {},
313+ repo_manager ,
272314 )
273315
274316 # Run the transform, bail if we failed or if we aren't formatting code
@@ -281,55 +323,26 @@ def _execute_transform( # noqa: C901
281323 else PartialParserConfig ()
282324 ),
283325 )
284- output_tree = transformer .transform_module (input_tree )
326+ output_tree = transformer_instance .transform_module (input_tree )
285327 newcode = output_tree .bytes
286328 encoding = output_tree .encoding
287- except KeyboardInterrupt :
288- return ExecutionResult (
289- filename = filename , changed = False , transform_result = TransformExit ()
290- )
329+ warnings .extend (transformer_instance .context .warnings )
291330 except SkipFile as ex :
331+ warnings .extend (transformer_instance .context .warnings )
292332 return ExecutionResult (
293333 filename = filename ,
294334 changed = False ,
295335 transform_result = TransformSkip (
296336 skip_reason = SkipReason .OTHER ,
297337 skip_description = str (ex ),
298- warning_messages = transformer .context .warnings ,
299- ),
300- )
301- except Exception as ex :
302- return ExecutionResult (
303- filename = filename ,
304- changed = False ,
305- transform_result = TransformFailure (
306- error = ex ,
307- traceback_str = traceback .format_exc (),
308- warning_messages = transformer .context .warnings ,
338+ warning_messages = warnings ,
309339 ),
310340 )
311341
312342 # Call formatter if needed, but only if we actually changed something in this
313343 # file
314344 if config .format_code and newcode != oldcode :
315- try :
316- newcode = invoke_formatter (config .formatter_args , newcode )
317- except KeyboardInterrupt :
318- return ExecutionResult (
319- filename = filename ,
320- changed = False ,
321- transform_result = TransformExit (),
322- )
323- except Exception as ex :
324- return ExecutionResult (
325- filename = filename ,
326- changed = False ,
327- transform_result = TransformFailure (
328- error = ex ,
329- traceback_str = traceback .format_exc (),
330- warning_messages = transformer .context .warnings ,
331- ),
332- )
345+ newcode = invoke_formatter (config .formatter_args , newcode )
333346
334347 # Format as unified diff if needed, otherwise save it back
335348 changed = oldcode != newcode
@@ -352,13 +365,14 @@ def _execute_transform( # noqa: C901
352365 return ExecutionResult (
353366 filename = filename ,
354367 changed = changed ,
355- transform_result = TransformSuccess (
356- warning_messages = transformer .context .warnings , code = newcode
357- ),
368+ transform_result = TransformSuccess (warning_messages = warnings , code = newcode ),
358369 )
370+
359371 except KeyboardInterrupt :
360372 return ExecutionResult (
361- filename = filename , changed = False , transform_result = TransformExit ()
373+ filename = filename ,
374+ changed = False ,
375+ transform_result = TransformExit (warning_messages = warnings ),
362376 )
363377 except Exception as ex :
364378 return ExecutionResult (
@@ -367,7 +381,7 @@ def _execute_transform( # noqa: C901
367381 transform_result = TransformFailure (
368382 error = ex ,
369383 traceback_str = traceback .format_exc (),
370- warning_messages = transformer . context . warnings ,
384+ warning_messages = warnings ,
371385 ),
372386 )
373387
@@ -504,15 +518,8 @@ class ParallelTransformResult:
504518 skips : int
505519
506520
507- # Unfortunate wrapper required since there is no `istarmap_unordered`...
508- def _execute_transform_wrap (
509- job : Dict [str , Any ],
510- ) -> ExecutionResult :
511- return _execute_transform (** job )
512-
513-
514521def parallel_exec_transform_with_prettyprint ( # noqa: C901
515- transform : Codemod ,
522+ transform : Union [ Codemod , Type [ Codemod ]] ,
516523 files : Sequence [str ],
517524 * ,
518525 jobs : Optional [int ] = None ,
@@ -528,38 +535,49 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
528535 blacklist_patterns : Sequence [str ] = (),
529536 python_version : Optional [str ] = None ,
530537 repo_root : Optional [str ] = None ,
538+ codemod_args : Optional [Dict [str , object ]] = None ,
531539) -> ParallelTransformResult :
532540 """
533- Given a list of files and an instantiated codemod we should apply to them,
534- fork and apply the codemod in parallel to all of the files, including any
535- configured formatter. The ``jobs`` parameter controls the maximum number of
536- in-flight transforms, and needs to be at least 1. If not included, the number
537- of jobs will automatically be set to the number of CPU cores. If ``unified_diff``
538- is set to a number, changes to files will be printed to stdout with
539- ``unified_diff`` lines of context. If it is set to ``None`` or left out, files
540- themselves will be updated with changes and formatting. If a
541- ``python_version`` is provided, then we will parse each source file using
542- this version. Otherwise, we will use the version of the currently executing python
541+ Given a list of files and a codemod we should apply to them, fork and apply the
542+ codemod in parallel to all of the files, including any configured formatter. The
543+ ``jobs`` parameter controls the maximum number of in-flight transforms, and needs to
544+ be at least 1. If not included, the number of jobs will automatically be set to the
545+ number of CPU cores. If ``unified_diff`` is set to a number, changes to files will
546+ be printed to stdout with ``unified_diff`` lines of context. If it is set to
547+ ``None`` or left out, files themselves will be updated with changes and formatting.
548+ If a ``python_version`` is provided, then we will parse each source file using this
549+ version. Otherwise, we will use the version of the currently executing python
543550 binary.
544551
545- A progress indicator as well as any generated warnings will be printed to stderr.
546- To supress the interactive progress indicator, set ``hide_progress`` to ``True``.
547- Files that include the generated code marker will be skipped unless the
548- ``include_generated`` parameter is set to ``True``. Similarly, files that match
549- a supplied blacklist of regex patterns will be skipped. Warnings for skipping
550- both blacklisted and generated files will be printed to stderr along with
551- warnings generated by the codemod unless ``hide_blacklisted`` and
552- ``hide_generated`` are set to ``True``. Files that were successfully codemodded
553- will not be printed to stderr unless ``show_successes`` is set to ``True``.
554-
555- To make this API possible, we take an instantiated transform. This is due to
556- the fact that lambdas are not pickleable and pickling functions is undefined.
557- This means we're implicitly relying on fork behavior on UNIX-like systems, and
558- this function will not work on Windows systems. To create a command-line utility
559- that runs on Windows, please instead see
560- :func:`~libcst.codemod.exec_transform_with_prettyprint`.
552+ A progress indicator as well as any generated warnings will be printed to stderr. To
553+ supress the interactive progress indicator, set ``hide_progress`` to ``True``. Files
554+ that include the generated code marker will be skipped unless the
555+ ``include_generated`` parameter is set to ``True``. Similarly, files that match a
556+ supplied blacklist of regex patterns will be skipped. Warnings for skipping both
557+ blacklisted and generated files will be printed to stderr along with warnings
558+ generated by the codemod unless ``hide_blacklisted`` and ``hide_generated`` are set
559+ to ``True``. Files that were successfully codemodded will not be printed to stderr
560+ unless ``show_successes`` is set to ``True``.
561+
562+ We take a :class:`~libcst.codemod._codemod.Codemod` class, or an instantiated
563+ :class:`~libcst.codemod._codemod.Codemod`. In the former case, the codemod will be
564+ instantiated for each file, with ``codemod_args`` passed in to the constructor.
565+ Passing an already instantiated :class:`~libcst.codemod._codemod.Codemod` is
566+ deprecated, because it leads to sharing of the
567+ :class:`~libcst.codemod._codemod.Codemod` instance across files, which is a common
568+ source of hard-to-track-down bugs when the :class:`~libcst.codemod._codemod.Codemod`
569+ tracks its state on the instance.
561570 """
562571
572+ if isinstance (transform , Codemod ):
573+ warn (
574+ "Passing transformer instances to `parallel_exec_transform_with_prettyprint` "
575+ "is deprecated and will break in a future version. "
576+ "Please pass the transformer class instead." ,
577+ DeprecationWarning ,
578+ stacklevel = 2 ,
579+ )
580+
563581 # Ensure that we have no duplicates, otherwise we might get race conditions
564582 # on write.
565583 files = sorted ({os .path .abspath (f ) for f in files })
@@ -579,6 +597,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
579597 if total == 0 :
580598 return ParallelTransformResult (successes = 0 , failures = 0 , skips = 0 , warnings = 0 )
581599
600+ metadata_manager : Optional [FullRepoManager ] = None
582601 if repo_root is not None :
583602 # Make sure if there is a root that we have the absolute path to it.
584603 repo_root = os .path .abspath (repo_root )
@@ -591,10 +610,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
591610 transform .get_inherited_dependencies (),
592611 )
593612 metadata_manager .resolve_cache ()
594- transform .context = replace (
595- transform .context ,
596- metadata_manager = metadata_manager ,
597- )
613+
598614 print ("Executing codemod..." , file = sys .stderr )
599615
600616 config = ExecutionConfig (
@@ -630,19 +646,24 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
630646 failures : int = 0
631647 warnings : int = 0
632648 skips : int = 0
649+ original_scratch = (
650+ deepcopy (transform .context .scratch ) if isinstance (transform , Codemod ) else {}
651+ )
633652
634653 with pool_impl (max_workers = jobs ) as executor : # type: ignore
635- args = [
636- {
637- "transformer" : transform ,
638- "filename" : filename ,
639- "config" : config ,
640- "scratch" : transform .context .scratch ,
641- }
642- for filename in files
643- ]
644654 try :
645- futures = [executor .submit (_execute_transform_wrap , arg ) for arg in args ]
655+ futures = [
656+ executor .submit (
657+ _execute_transform ,
658+ transformer = transform ,
659+ filename = filename ,
660+ config = config ,
661+ original_scratch = original_scratch ,
662+ codemod_args = codemod_args ,
663+ repo_manager = metadata_manager ,
664+ )
665+ for filename in files
666+ ]
646667 for future in as_completed (futures ):
647668 result = future .result ()
648669 # Print an execution result, keep track of failures
0 commit comments