Skip to content
2 changes: 2 additions & 0 deletions loky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .reusable_executor import get_reusable_executor
from .cloudpickle_wrapper import wrap_non_picklable_objects
from .process_executor import BrokenProcessPool, ProcessPoolExecutor
from .worker_id import get_worker_id


__all__ = [
Expand All @@ -37,6 +38,7 @@
"FIRST_EXCEPTION",
"ALL_COMPLETED",
"wrap_non_picklable_objects",
"get_worker_id",
"set_loky_pickler",
]

Expand Down
31 changes: 29 additions & 2 deletions loky/process_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def _process_worker(
timeout,
worker_exit_lock,
current_depth,
worker_id,
):
"""Evaluates calls from call_queue and places the results in result_queue.

Expand Down Expand Up @@ -420,6 +421,9 @@ def _process_worker(
_last_memory_leak_check = None
pid = os.getpid()

# set the worker_id environment variable
os.environ["LOKY_WORKER_ID"] = str(worker_id)

mp.util.debug(f"Worker started with timeout={timeout}")
while True:
try:
Expand Down Expand Up @@ -562,6 +566,9 @@ def weakref_cb(
# A list of the ctx.Process instances used as workers.
self.processes = executor._processes

# A dict mapping worker pids to worker IDs
self.process_worker_ids = executor._process_worker_ids

# A ctx.Queue that will be filled with _CallItems derived from
# _WorkItems for processing by the process workers.
self.call_queue = executor._call_queue
Expand Down Expand Up @@ -727,6 +734,7 @@ def process_result_item(self, result_item):
# itself: we should not mark the executor as broken.
with self.processes_management_lock:
p = self.processes.pop(result_item, None)
self.process_worker_ids.pop(result_item, None)

# p can be None if the executor is concurrently shutting down.
if p is not None:
Expand Down Expand Up @@ -830,7 +838,9 @@ def kill_workers(self, reason=""):
# terminates descendant workers of the children in case there is some
# nested parallelism.
while self.processes:
_, p = self.processes.popitem()
pid, p = self.processes.popitem()
self.process_worker_ids.pop(pid, None)

mp.util.debug(f"terminate process {p.name}, reason: {reason}")
try:
kill_process_tree(p)
Expand Down Expand Up @@ -1101,8 +1111,10 @@ def __init__(
# Map of pids to processes
self._processes = {}

# Map of pids to process worker IDs
self._process_worker_ids = {}

# Internal variables of the ProcessPoolExecutor
self._processes = {}
self._queue_count = 0
self._pending_work_items = {}
self._running_work_items = []
Expand Down Expand Up @@ -1183,9 +1195,21 @@ def _start_executor_manager_thread(self):
_python_exit
)

def _get_available_worker_id(self):
    """Return a worker ID that no live worker process currently holds.

    Worker IDs are integers in ``[0, max_workers)``. Returns -1 when every
    ID is taken, or when running in a nested executor (depth > 0), where
    stable worker IDs are not supported.
    """
    # Worker IDs are only meaningful at the first parallelization level.
    if _CURRENT_DEPTH > 0:
        return -1

    used_ids = set(self._process_worker_ids.values())
    available_ids = set(range(self._max_workers)) - used_ids
    # Hand out the smallest free ID so assignment is deterministic
    # (set.pop() would return an arbitrary element).
    return min(available_ids) if available_ids else -1

def _adjust_process_count(self):
while len(self._processes) < self._max_workers:
worker_exit_lock = self._context.BoundedSemaphore(1)
worker_id = self._get_available_worker_id()
args = (
self._call_queue,
self._result_queue,
Expand All @@ -1195,8 +1219,10 @@ def _adjust_process_count(self):
self._timeout,
worker_exit_lock,
_CURRENT_DEPTH + 1,
worker_id,
)
worker_exit_lock.acquire()

try:
# Try to spawn the process with some environment variable to
# overwrite but it only works with the loky context for now.
Expand All @@ -1208,6 +1234,7 @@ def _adjust_process_count(self):
p._worker_exit_lock = worker_exit_lock
p.start()
self._processes[p.pid] = p
self._process_worker_ids[p.pid] = worker_id
mp.util.debug(
f"Adjusted process count to {self._max_workers}: "
f"{[(p.name, pid) for pid, p in self._processes.items()]}"
Expand Down
15 changes: 15 additions & 0 deletions loky/worker_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os


def get_worker_id():
    """Get the worker ID of the current process.

    For a `ReusableExecutor` with `max_workers=n`, the worker ID is in the
    range [0..n). This is suited for reuse of persistent objects such as GPU
    IDs. This function only works at the first level of parallelization (i.e.
    not for nested parallelization). Resizing the `ReusableExecutor` will
    result in unpredictable return values.

    Returns -1 when the process is not a worker.
    """
    # The executor exports LOKY_WORKER_ID into each worker's environment
    # before it starts processing tasks; a process without the variable
    # is not a loky worker.
    return int(os.environ.get("LOKY_WORKER_ID", -1))
36 changes: 36 additions & 0 deletions tests/test_worker_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import time
import pytest
import numpy as np
from collections import defaultdict
from loky import get_reusable_executor, get_worker_id


def random_sleep(args):
    """Sleep for a seed-determined random duration and report the interval.

    ``args`` is a ``(k, max_duration)`` pair; ``k`` seeds the RNG so the
    sleep length is reproducible per task. Returns a tuple
    ``(worker_id, start, end)`` of the worker ID and the wall-clock
    timestamps taken around the sleep.
    """
    seed, max_duration = args
    delay = np.random.RandomState(seed=seed).uniform(0, max_duration)
    start = time.time()
    time.sleep(delay)
    end = time.time()
    return (get_worker_id(), start, end)


@pytest.mark.parametrize("max_duration,timeout,kmax", [(0.05, 2, 100),
                                                       (1, 0.01, 4)])
def test_worker_ids(max_duration, timeout, kmax):
    """Test that worker IDs are always unique, with re-use over time"""
    num_workers = 4
    executor = get_reusable_executor(max_workers=num_workers, timeout=timeout)
    outputs = executor.map(random_sleep,
                           [(k, max_duration) for k in range(kmax)])

    # Group the busy intervals of each task by the worker ID it reported.
    intervals_by_worker = defaultdict(list)
    for wid, start, end in outputs:
        assert wid in set(range(num_workers))
        intervals_by_worker[wid].append((start, end))

    # Two tasks reporting the same worker ID must never overlap in time:
    # an ID belongs to at most one live worker at any instant.
    for spans in intervals_by_worker.values():
        spans.sort()
        for (_, prev_end), (next_start, _) in zip(spans, spans[1:]):
            assert next_start >= prev_end