diff --git a/libcst/codemod/_cli.py b/libcst/codemod/_cli.py index 2481bf9d1..d091ad8b0 100644 --- a/libcst/codemod/_cli.py +++ b/libcst/codemod/_cli.py @@ -14,16 +14,17 @@ import sys import time import traceback +from concurrent.futures import as_completed, Executor, ProcessPoolExecutor from copy import deepcopy from dataclasses import dataclass, replace -from multiprocessing import cpu_count, Pool +from multiprocessing import cpu_count from pathlib import Path from typing import Any, AnyStr, cast, Dict, List, Optional, Sequence, Union from libcst import parse_module, PartialParserConfig from libcst.codemod._codemod import Codemod from libcst.codemod._context import CodemodContext -from libcst.codemod._dummy_pool import DummyPool +from libcst.codemod._dummy_pool import DummyExecutor from libcst.codemod._runner import ( SkipFile, SkipReason, @@ -607,13 +608,14 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901 python_version=python_version, ) + pool_impl: type[Executor] if total == 1 or jobs == 1: # Simple case, we should not pay for process overhead. - # Let's just use a dummy synchronous pool. + # Let's just use a dummy synchronous executor. jobs = 1 - pool_impl = DummyPool + pool_impl = DummyExecutor else: - pool_impl = Pool + pool_impl = ProcessPoolExecutor # Warm the parser, pre-fork. parse_module( "", @@ -629,7 +631,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901 warnings: int = 0 skips: int = 0 - with pool_impl(processes=jobs) as p: # type: ignore + with pool_impl(max_workers=jobs) as executor: # type: ignore args = [ { "transformer": transform, @@ -640,9 +642,9 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901 for filename in files ] try: - for result in p.imap_unordered( - _execute_transform_wrap, args, chunksize=chunksize - ): + futures = [executor.submit(_execute_transform_wrap, arg) for arg in args] + for future in as_completed(futures): + result = future.result() # Print an execution result, keep track of failures _print_parallel_result( result, diff --git a/libcst/codemod/_dummy_pool.py b/libcst/codemod/_dummy_pool.py index c4a249326..34c911bdc 100644 --- a/libcst/codemod/_dummy_pool.py +++ b/libcst/codemod/_dummy_pool.py @@ -3,37 +3,50 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import sys +from concurrent.futures import Executor, Future from types import TracebackType -from typing import Callable, Generator, Iterable, Optional, Type, TypeVar +from typing import Callable, Optional, Type, TypeVar -RetT = TypeVar("RetT") -ArgT = TypeVar("ArgT") +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec +Return = TypeVar("Return") +Params = ParamSpec("Params") -class DummyPool: + +class DummyExecutor(Executor): """ - Synchronous dummy `multiprocessing.Pool` analogue. + Synchronous dummy `concurrent.futures.Executor` analogue. """ - def __init__(self, processes: Optional[int] = None) -> None: + def __init__(self, max_workers: Optional[int] = None) -> None: pass - def imap_unordered( + def submit( self, - func: Callable[[ArgT], RetT], - iterable: Iterable[ArgT], - chunksize: Optional[int] = None, - ) -> Generator[RetT, None, None]: - for args in iterable: - yield func(args) - - def __enter__(self) -> "DummyPool": + fn: Callable[Params, Return], + /, + *args: Params.args, + **kwargs: Params.kwargs, + ) -> Future[Return]: + future: Future[Return] = Future() + try: + result = fn(*args, **kwargs) + future.set_result(result) + except Exception as exc: + future.set_exception(exc) + return future + + def __enter__(self) -> "DummyExecutor": return self def __exit__( self, - exc_type: Optional[Type[Exception]], - exc: Optional[Exception], - tb: Optional[TracebackType], + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], ) -> None: pass