diff --git a/loky/__init__.py b/loky/__init__.py
index a3b30c09..2bf079d3 100644
--- a/loky/__init__.py
+++ b/loky/__init__.py
@@ -1,8 +1,9 @@
-r"""The :mod:`loky` module manages a pool of worker that can be re-used across time.
-It provides a robust and dynamic implementation os the
+r"""The :mod:`loky` module manages a pool of workers that can be re-used
+across time. It provides a robust and dynamic implementation of the
 :class:`ProcessPoolExecutor` and a function :func:`get_reusable_executor` which
 hide the pool management under the hood.
 """
+
 from concurrent.futures import (
     ALL_COMPLETED,
     FIRST_COMPLETED,
@@ -20,6 +21,7 @@
 from .reusable_executor import get_reusable_executor
 from .cloudpickle_wrapper import wrap_non_picklable_objects
 from .process_executor import BrokenProcessPool, ProcessPoolExecutor
+from .process_executor import get_worker_rank


 __all__ = [
@@ -37,6 +39,7 @@
     "FIRST_EXCEPTION",
     "ALL_COMPLETED",
     "wrap_non_picklable_objects",
+    "get_worker_rank",
     "set_loky_pickler",
 ]
diff --git a/loky/process_executor.py b/loky/process_executor.py
index 1e08cc21..77266270 100644
--- a/loky/process_executor.py
+++ b/loky/process_executor.py
@@ -115,6 +115,33 @@ def _get_memory_usage(pid, force_gc=False):
 except ImportError:
     _USE_PSUTIL = False

+# Mechanism to obtain the rank of a worker and the total number of workers
+# in the executor.
+_WORKER_RANK = None
+_WORKER_WORLD = None
+
+
+def get_worker_rank():
+    """Return the rank of the worker and the number of workers in the executor.
+
+    This helper should only be called inside a worker; otherwise it raises
+    a RuntimeError.
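+
+    A minimal usage sketch (``whoami`` is an illustrative task function)::
+
+        from loky import get_reusable_executor, get_worker_rank
+
+        def whoami(_):
+            rank, world = get_worker_rank()
+            return rank, world
+
+        executor = get_reusable_executor(max_workers=2)
+        ranks = sorted(set(executor.map(whoami, range(10))))
+        # ranks == [(0, 2), (1, 2)] when both workers get tasks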
+    """
+    if _WORKER_RANK is None:
+        raise RuntimeError(
+            "get_worker_rank should only be called in a worker, not in the "
+            "main process."
+        )
+    return _WORKER_RANK, _WORKER_WORLD
+
+
+def set_worker_rank(pid, rank_mapper):
+    """Set the worker's rank and world size from its pid and a rank_mapper."""
+    global _WORKER_RANK, _WORKER_WORLD
+    if pid in rank_mapper:
+        _WORKER_RANK = rank_mapper[pid]
+        _WORKER_WORLD = rank_mapper["world"]
+

 class _ThreadWakeup:
     def __init__(self):
@@ -277,11 +304,12 @@ def __init__(self, work_id, exception=None, result=None):


 class _CallItem:
-    def __init__(self, work_id, fn, args, kwargs):
+    def __init__(self, work_id, fn, args, kwargs, rank_mapper):
         self.work_id = work_id
         self.fn = fn
         self.args = args
         self.kwargs = kwargs
+        self.rank_mapper = rank_mapper

         # Store the current loky_pickler so it is correctly set in the worker
         self.loky_pickler = get_loky_pickler_name()
@@ -384,6 +412,7 @@ def _process_worker(
     timeout,
     worker_exit_lock,
     current_depth,
+    rank_mapper,
 ):
     """Evaluates calls from call_queue and places the results in result_queue.

@@ -403,6 +432,8 @@
         worker_exit_lock: Lock to avoid flagging the executor as broken on
             workers timeout.
         current_depth: Nested parallelism level, to avoid infinite spawning.
+        rank_mapper: Initial rank and world size for this worker, as a dict
+            with the keys None and 'world'.
     """
     if initializer is not None:
         try:
@@ -420,6 +451,13 @@
     _last_memory_leak_check = None
     pid = os.getpid()

+    # Passing an initial value is necessary because some jobs can be sent
+    # and serialized before this worker is created. In that case, no rank is
+    # available for this worker in call_item.rank_mapper and this initial
+    # rank is the correct one. As the main process does not know the pid
+    # before spawning the worker, it passes the rank under the None key.
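+    # For instance, the first worker of a fresh executor with 4 workers
+    # would receive rank_mapper == {None: 0, "world": 4} (illustrative
+    # values).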
+    set_worker_rank(None, rank_mapper)
+
     mp.util.debug(f"Worker started with timeout={timeout}")
     while True:
         try:
@@ -447,6 +485,7 @@
         if call_item is None:
             # Notify queue management thread about worker shutdown
             result_queue.put(pid)
+
             is_clean = worker_exit_lock.acquire(True, timeout=30)

             # Early notify any loky executor running in this worker process
@@ -459,6 +498,10 @@
             else:
                 mp.util.info("Main process did not release worker_exit")
             return
+
+        # If the executor has been resized, this new rank_mapper might hold
+        # updated rank/world info. Correct the values before running the task.
+        set_worker_rank(pid, call_item.rank_mapper)
         try:
             r = call_item()
         except BaseException as e:
@@ -583,6 +626,10 @@ def weakref_cb(
         # of new processes or shut down
         self.processes_management_lock = executor._processes_management_lock

+        # A dict mapping each worker's pid to its rank. It also contains the
+        # current size of the executor, stored under the 'world' key.
+        self.rank_mapper = executor._rank_mapper
+
         super().__init__(name="ExecutorManagerThread")
         if sys.version_info < (3, 9):
             self.daemon = True
@@ -634,6 +681,7 @@ def add_call_item_to_queue(self):
                     work_item.fn,
                     work_item.args,
                     work_item.kwargs,
+                    self.rank_mapper,
                 ),
                 block=True,
             )
@@ -727,6 +775,7 @@ def process_result_item(self, result_item):
             # itself: we should not mark the executor as broken.
             with self.processes_management_lock:
                 p = self.processes.pop(result_item, None)
+                self.rank_mapper.pop(result_item, None)

                 # p can be None if the executor is concurrently shutting down.
                 if p is not None:
@@ -1014,7 +1063,6 @@ class TerminatedWorkerError(BrokenProcessPool):


 class ShutdownExecutorError(RuntimeError):
-
     """
     Raised when a ProcessPoolExecutor is shutdown while a future was in the
     running or pending state.
@@ -1128,6 +1176,10 @@ def __init__(
         # Finally setup the queues for interprocess communication
         self._setup_queues(job_reducers, result_reducers)

+        # A dict mapping each worker's pid to its rank. The current size of
+        # the executor is stored under the 'world' key.
+        self._rank_mapper = {"world": max_workers}
+
         mp.util.debug("ProcessPoolExecutor is setup")

     def _setup_queues(self, job_reducers, result_reducers, queue_size=None):
@@ -1184,8 +1236,16 @@ def _start_executor_manager_thread(self):
         )

     def _adjust_process_count(self):
+        # Compute the worker ranks available for newly spawned workers
+        given_ranks = set(
+            v for k, v in self._rank_mapper.items() if k != "world"
+        )
+        all_ranks = set(range(self._max_workers))
+        available_ranks = all_ranks - given_ranks
+
         while len(self._processes) < self._max_workers:
             worker_exit_lock = self._context.BoundedSemaphore(1)
+            rank = available_ranks.pop()
             args = (
                 self._call_queue,
                 self._result_queue,
@@ -1195,6 +1255,7 @@ def _adjust_process_count(self):
                 self._timeout,
                 worker_exit_lock,
                 _CURRENT_DEPTH + 1,
+                {None: rank, "world": self._max_workers},
             )
             worker_exit_lock.acquire()
             try:
@@ -1208,6 +1269,14 @@
             p._worker_exit_lock = worker_exit_lock
             p.start()
             self._processes[p.pid] = p
+            self._rank_mapper[p.pid] = rank
+
+        # Reassign ranks that are too high to ranks that are still
+        # available. The new values will be passed to the workers with the
+        # tasks, via the _CallItem.
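+        # For instance, after resizing from 4 workers down to 2, a surviving
+        # worker that still holds rank 3 is remapped to a free rank in
+        # {0, 1} (illustrative sizes).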
+        for pid, rank in list(self._rank_mapper.items()):
+            if pid != "world" and rank >= self._max_workers:
+                self._rank_mapper[pid] = available_ranks.pop()

         mp.util.debug(
             f"Adjusted process count to {self._max_workers}: "
             f"{[(p.name, pid) for pid, p in self._processes.items()]}"
         )
diff --git a/loky/reusable_executor.py b/loky/reusable_executor.py
index d879da9e..2755a14e 100644
--- a/loky/reusable_executor.py
+++ b/loky/reusable_executor.py
@@ -236,10 +236,14 @@ def _resize(self, max_workers):
             # then no processes have been spawned and we can just
             # update _max_workers and return
             self._max_workers = max_workers
+            self._rank_mapper["world"] = max_workers
             return

         self._wait_job_completion()

+        # Set the new size to be broadcast to the workers
+        self._rank_mapper["world"] = max_workers
+
         # Some process might have returned due to timeout so check how many
         # children are still alive. Use the _process_management_lock to
         # ensure that no process are spawned or timeout during the resize.
diff --git a/tests/_test_process_executor.py b/tests/_test_process_executor.py
index f58af9f8..3c07cada 100644
--- a/tests/_test_process_executor.py
+++ b/tests/_test_process_executor.py
@@ -16,7 +16,9 @@
 from math import sqrt
 from pickle import PicklingError
 from threading import Thread
+import multiprocessing as mp
 from collections import defaultdict
+
 from concurrent import futures
 from concurrent.futures._base import (
     PENDING,
@@ -1125,6 +1127,32 @@ def test_child_env_executor(self):

         executor.shutdown(wait=True)

+    @staticmethod
+    def _worker_rank(x):
+        time.sleep(0.2)
+        rank, world = loky.get_worker_rank()
+        return dict(
+            pid=os.getpid(),
+            name=mp.current_process().name,
+            rank=rank,
+            world=world,
+        )
+
+    @pytest.mark.parametrize("max_workers", [1, 5, 13])
+    @pytest.mark.parametrize("timeout", [None, 0.01])
+    def test_workers_rank(self, max_workers, timeout):
+        executor = self.executor_type(max_workers, timeout=timeout)
+        results = executor.map(self._worker_rank, range(max_workers * 5))
+        workers_rank = {}
+        for res in results:
+            assert res["world"] == max_workers
+            rank = workers_rank.get(res["pid"], None)
+            assert rank is None or rank == res["rank"]
+            workers_rank[res["pid"]] = res["rank"]
+        msg = ", ".join(f"{k}: {v}" for k, v in executor._rank_mapper.items())
+        assert set(workers_rank.values()) == set(range(max_workers)), msg
+        executor.shutdown(wait=True, kill_workers=True)
+
     def test_viztracer_profiler(self):
         # Check that viztracer profiler is initialzed in workers when
         # installed.
diff --git a/tests/test_loky_module.py b/tests/test_loky_module.py
index b9734eac..a7e5a2a3 100644
--- a/tests/test_loky_module.py
+++ b/tests/test_loky_module.py
@@ -1,15 +1,16 @@
-import multiprocessing as mp
 import os
 import sys
 import shutil
 import tempfile
 import warnings
+import multiprocessing as mp
 from subprocess import check_output

 import pytest

 import loky
 from loky import cpu_count
+from loky import get_worker_rank
 from loky.backend.context import _cpu_count_user, _MAX_WINDOWS_WORKERS
@@ -215,3 +216,8 @@ def test_only_physical_cores_with_user_limitation():
     if cpu_count_user < cpu_count_mp:
         assert cpu_count() == cpu_count_user
         assert cpu_count(only_physical_cores=True) == cpu_count_user
+
+
+def test_worker_rank_in_worker_only():
+    with pytest.raises(RuntimeError):
+        get_worker_rank()
diff --git a/tests/test_reusable_executor.py b/tests/test_reusable_executor.py
index 305ce85f..12355075 100644
--- a/tests/test_reusable_executor.py
+++ b/tests/test_reusable_executor.py
@@ -1,21 +1,24 @@
 import os
-import subprocess
 import sys
 import gc
+import time
 import ctypes
-from tempfile import NamedTemporaryFile
 import pytest
 import warnings
 import threading
+import subprocess
 from time import sleep
+import multiprocessing as mp
 from multiprocessing import util, current_process
 from pickle import PicklingError, UnpicklingError
+from tempfile import NamedTemporaryFile

 import cloudpickle
 from packaging.version import Version

 import loky
 from loky import cpu_count
+from loky import get_worker_rank
 from loky import get_reusable_executor
 from loky.process_executor import _RemoteTraceback, TerminatedWorkerError
 from loky.process_executor import BrokenProcessPool, ShutdownExecutorError
@@ -679,6 +683,38 @@ def test_resize_after_timeout(self):
             expected_msg = "A worker stopped"
             assert expected_msg in recorded_warnings[0].message.args[0]

+    @staticmethod
+    def _worker_rank(x):
+        time.sleep(0.2)
+        rank, world = get_worker_rank()
+        return dict(
+            pid=os.getpid(),
+            name=mp.current_process().name,
+            rank=rank,
+            world=world,
+        )
+
+    def test_workers_rank_resize(self):
+        executor = get_reusable_executor(max_workers=2)
+
+        with warnings.catch_warnings(record=True):
+            # Cause all warnings to always be triggered.
+            warnings.simplefilter("always")
+            for size in [12, 2, 1, 12, 6, 1, 8, 5]:
+                executor = get_reusable_executor(max_workers=size, reuse=True)
+                results = executor.map(self._worker_rank, range(size * 5))
+                executor.map(sleep, [0.01] * 6)
+                workers_rank = {}
+                for res in results:
+                    assert res["world"] == size
+                    rank = workers_rank.get(res["pid"], None)
+                    assert rank is None or rank == res["rank"]
+                    workers_rank[res["pid"]] = res["rank"]
+                msg = ", ".join(
+                    f"{k}: {v}" for k, v in executor._rank_mapper.items()
+                )
+                assert set(workers_rank.values()) == set(range(size)), msg
+

 class TestGetReusableExecutor(ReusableExecutorMixin):
     def test_invalid_process_number(self):