From 7ef293650a64522faef2517ff1a5e25c3f3713ba Mon Sep 17 00:00:00 2001
From: Grain Team
Date: Mon, 1 Dec 2025 10:06:42 -0800
Subject: [PATCH] Internal

PiperOrigin-RevId: 838825293
---
 grain/_src/python/BUILD                        |  43 -
 grain/_src/python/data_loader.py               |  84 +-
 grain/_src/python/data_loader_test.py          |  76 --
 grain/_src/python/dataset/BUILD                |   2 +-
 grain/_src/python/dataset/dataset.py           |   9 +-
 grain/_src/python/dataset/dataset_test.py      |  34 -
 .../dataset/transformations/interleave.py      |   8 +-
 .../dataset/transformations/prefetch.py        | 432 +--------
 .../dataset/transformations/prefetch_test.py   | 580 ------------
 .../transformations/process_prefetch.py        | 185 ++--
 .../transformations/process_prefetch_test.py   |  64 +-
 grain/_src/python/grain_pool.py                | 834 ------------------
 grain/_src/python/grain_pool_test.py           | 471 ----------
 grain/_src/python/shared_memory_array.py       |   9 +
 grain/_src/python/shared_memory_array_test.py  |  21 +
 15 files changed, 224 insertions(+), 2628 deletions(-)
 delete mode 100644 grain/_src/python/grain_pool.py
 delete mode 100644 grain/_src/python/grain_pool_test.py

diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD
index cb57e0116..ac1980520 100644
--- a/grain/_src/python/BUILD
+++ b/grain/_src/python/BUILD
@@ -208,49 +208,6 @@ py_test(
     ],
 )
 
-py_library(
-    name = "grain_pool",
-    srcs = ["grain_pool.py"],
-    srcs_version = "PY3",
-    target_compatible_with = select({
-        "@platforms//os:windows": ["@platforms//:incompatible"],
-        "//conditions:default": [],
-    }),
-    deps = [
-        ":grain_logging",
-        ":multiprocessing_common",
-        ":options",
-        ":record",
-        ":shared_memory_array",
-        "//grain/_src/core:config",
-        "//grain/_src/core:monitoring",
-        "//grain/_src/core:parallel",
-        "//grain/_src/core:tree_lib",
-        "@abseil-py//absl/flags",
-        "@abseil-py//absl/logging",
-        "@pypi//cloudpickle:pkg",
-    ],
-)
-
-py_test(
-    name = "grain_pool_test",
-    srcs = ["grain_pool_test.py"],
-    shard_count = 20,
-    srcs_version = "PY3",
-    tags = ["not_run:arm"],
-    deps = [
-        ":data_sources",
-        ":grain_pool",
-        ":options",
-        ":record",
-        "//grain/_src/core:config",
-        "//grain/_src/core:monitoring",
-        "@abseil-py//absl/flags",
-        "@abseil-py//absl/testing:absltest",
-        "@abseil-py//absl/testing:parameterized",
-    ],
-)
-
 py_library(
     name = "checkpoint_handlers",
     srcs = ["checkpoint_handlers.py"],
diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py
index fffd939b2..7485bfe98 100644
--- a/grain/_src/python/data_loader.py
+++ b/grain/_src/python/data_loader.py
@@ -24,13 +24,11 @@
 import sys
 from typing import Any, Awaitable, Optional, Sequence, TypeVar
 
-from absl import logging
 from etils import epath
 from grain._src.core import monitoring as grain_monitoring
 from grain._src.core import sharding
 from grain._src.core import transforms
 from grain._src.core import tree_lib
-import multiprocessing as mp
 from grain._src.python import checkpointing
 from grain._src.python import operations as ops
 from grain._src.python import options
@@ -39,8 +37,6 @@
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset.transformations import batch as batch_ds
 from grain._src.python.dataset.transformations import flatmap
-from grain._src.python.dataset.transformations import prefetch
-from grain._src.python.operations import BatchOperation
 from grain._src.python.operations import Operation
 from grain._src.python.samplers import Sampler
 from grain._src.python.shared_memory_array import SharedMemoryArray
@@ -260,14 +256,15 @@ def get_state(self):
     # structure and switch to native dataset
     # checkpointing. This class can be
     # removed afterwards.
     dataset_state = self._parent.get_state()
-    if "workers_state" not in dataset_state:
-      dataset_state = {
-          "workers_state": {"0": dataset_state},
-          "last_worker_index": -1,
-      }
-    workers_state = dataset_state["workers_state"]
-    last_worker_index = dataset_state["last_worker_index"]
-    worker_count = len(dataset_state["workers_state"])
+    if "iterators_in_use_states" not in dataset_state:
+      next_index_in_cycle = 0
+      workers_state = [dataset_state]
+    else:
+      next_index_in_cycle = dataset_state["next_index_in_cycle"]
+      workers_state = dataset_state["iterators_in_use_states"]
+
+    last_worker_index = next_index_in_cycle - 1
+    worker_count = len(workers_state)
     shard_index = self._shard_options.shard_index if self._shard_options else 0
     shard_count = self._shard_options.shard_count if self._shard_options else 1
@@ -278,7 +275,7 @@ def get_state(self):
         str(i): (
             local_offset
             + i * shard_count
-            + workers_state[str(i)]["next_index"] * global_worker_count
+            + workers_state[i]["next_index"] * global_worker_count
         )
         for i in range(worker_count)
     }
@@ -311,25 +308,28 @@ def set_state(self, state):
       self._parent.set_state(dataset_state)
       return
 
-    iterations_to_skip = {str(i): 0 for i in range(worker_count)}
-    workers_state = {
-        str(i): {
+    iterators_in_use_indices = list(range(worker_count))
+    iterators_in_use_states = [
+        {
             "next_index": (
                 (
-                    last_seen_indices[str(i)]
+                    last_seen_indices[str(worker_index)]
                     + global_worker_count
                     - shard_index
-                    - i * shard_count
+                    - worker_index * shard_count
                 )
                 // global_worker_count
            )
         }
-        for i in range(worker_count)
-    }
+        for worker_index in range(worker_count)
+    ]
+
     dataset_state = {
-        "workers_state": workers_state,
-        "iterations_to_skip": iterations_to_skip,
-        "last_worker_index": last_worker_index,
+        "next_index_in_cycle": (last_worker_index + 1) % worker_count,
+        "next_index_in_datasets": worker_count,
+        "iterators_in_use_indices": iterators_in_use_indices,
+        "iterators_in_use_states": iterators_in_use_states,
+        "exhausted": [False] * worker_count,
     }
     self._parent.set_state(dataset_state)
 
@@ -389,31 +389,17 @@ def __init__(
           f"Current worker_buffer_size is {worker_buffer_size}."
      )
 
-    worker_count = _determine_worker_count(worker_count)
-    if worker_count > 0:
-
-      # Shared memory should be enabled iff worker_count > 0.
-      # This replaces Batch Transform with a BatchOperation in operations list
-      # if shared memory is enabled.
- if operations and isinstance( - (last_op := operations[-1]), transforms.Batch - ): - logging.info("Creating BatchOperation to enable SharedMemoryArray.") - batch_operation = BatchOperation( - batch_size=last_op.batch_size, - drop_remainder=last_op.drop_remainder, - batch_fn=last_op.batch_fn, + operations = list(operations) + for i in range(len(operations)): + op = operations[i] + if type(op) is ops.BatchOperation: # pylint: disable=unidiomatic-typecheck + operations[i] = transforms.Batch( + batch_size=op.batch_size, + drop_remainder=op.drop_remainder, + batch_fn=op.batch_fn, ) - batch_operation.disable_deprecation_message() - operations = list(operations) - operations[-1] = batch_operation - if operations and isinstance(operations[-1], BatchOperation): - logging.info("Enabling SharedMemoryArray for BatchOperation.") - operations[-1]._enable_shared_memory() - else: - logging.info("Adding CopyNumPyArrayToSharedMemory MapTransform.") - operations = list(operations) + [CopyNumPyArrayToSharedMemory()] + worker_count = _determine_worker_count(worker_count) self._data_source = data_source self._sampler = sampler @@ -471,11 +457,7 @@ def _create_dataset(self) -> dataset.IterDataset: ds = _apply_transform_to_dataset(operation, ds) ds = ds.map(lambda r: r.data) if self.multiprocessing_options.num_workers > 0: - ds = prefetch.MultiprocessPrefetchIterDataset( - ds, - self.multiprocessing_options, - always_report_worker_state=True, - ) + ds = ds.mp_prefetch(self.multiprocessing_options) if not self._use_native_dataset_checkpointing: ds = _DataLoaderStateIterDataset( ds, diff --git a/grain/_src/python/data_loader_test.py b/grain/_src/python/data_loader_test.py index a825ba049..985df9829 100644 --- a/grain/_src/python/data_loader_test.py +++ b/grain/_src/python/data_loader_test.py @@ -705,82 +705,6 @@ def test_batch_transform_mapped_to_batch_operation(self): actual = list(data_loader) np.testing.assert_equal(actual, expected) - @mock.patch.object(data_loader_lib, "CopyNumPyArrayToSharedMemory") - def test_shared_memory_for_batch_operation( - self, mock_copy_numpy_array_to_shared_memory - ): - range_data_source = RangeDataSource(start=0, stop=8, step=1) - sampler = samplers.SequentialSampler( - num_records=len(range_data_source), shard_options=sharding.NoSharding() - ) - - operations = [ - PlusOne(), - FilterEven(), - ] - - batch_operation = mock.MagicMock(BatchOperation(batch_size=2)) - - data_loader = data_loader_lib.DataLoader( - data_source=range_data_source, - sampler=sampler, - operations=operations, - worker_count=0, - read_options=self.read_options, - ) - batch_operation._enable_shared_memory.assert_not_called() - self.assertTrue( - data_loader._operations[-1], mock_copy_numpy_array_to_shared_memory - ) - - data_loader = data_loader_lib.DataLoader( - data_source=range_data_source, - sampler=sampler, - operations=operations + [batch_operation], - worker_count=2, - read_options=self.read_options, - ) - batch_operation._enable_shared_memory.assert_called_once() - self.assertTrue(data_loader._operations[-1], batch_operation) - - @mock.patch.object(BatchOperation, "_enable_shared_memory", autospec=True) - def test_shared_memory_for_batch_transform(self, mock_enable_shared_memory): - range_data_source = RangeDataSource(start=0, stop=8, step=1) - sampler = samplers.SequentialSampler( - num_records=len(range_data_source), shard_options=sharding.NoSharding() - ) - operations = [ - PlusOne(), - FilterEven(), - ] - - data_loader = data_loader_lib.DataLoader( - data_source=range_data_source, - 
sampler=sampler, - operations=operations, - worker_count=2, - read_options=self.read_options, - ) - mock_enable_shared_memory.assert_not_called() - self.assertIsInstance( - data_loader._operations[-1], - data_loader_lib.CopyNumPyArrayToSharedMemory, - ) - - batch_transform = transforms.Batch(batch_size=2) - - data_loader = data_loader_lib.DataLoader( - data_source=range_data_source, - sampler=sampler, - operations=operations + [batch_transform], - worker_count=2, - read_options=self.read_options, - ) - mock_enable_shared_memory.assert_called_once_with( - data_loader._operations[-1] - ) - self.assertIsInstance(data_loader._operations[-1], BatchOperation) - def test_data_loader_with_batch_fn(self): # Map transforms elements to be [1, 2, 3, 4, 5, 6, 7, 8] # Filter keeps only even elements [2, 4, 6, 8] diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index f93d41dd2..3e6894545 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -53,7 +53,7 @@ py_library( "//grain/_src/core:tree_lib", "//grain/_src/python:checkpointing", "//grain/_src/python:grain_logging", - "//grain/_src/python:grain_pool", + "//grain/_src/python:multiprocessing_common", "//grain/_src/python:options", "//grain/_src/python:shared_memory_array", "//grain/proto:execution_summary_py_pb2", diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index ba04c95bf..152bbb8f0 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -1353,13 +1353,14 @@ def mp_prefetch( A dataset prefetching input elements in separate processes. """ options = options or grain_options.MultiprocessingOptions(num_workers=10) - # Loaded lazily due to a circular dependency (dataset <-> prefetch). + # Loaded lazily due to a circular dependency (dataset <-> process_prefetch). 
# pylint: disable=g-import-not-at-top - from grain._src.python.dataset.transformations import prefetch + from grain._src.python.dataset.transformations import process_prefetch # pylint: enable=g-import-not-at-top - return prefetch.MultiprocessPrefetchIterDataset( + return process_prefetch.multiprocess_prefetch( self, - multiprocessing_options=options, + num_workers=options.num_workers, + buffer_size=options.per_worker_buffer_size, worker_init_fn=worker_init_fn, sequential_slice=sequential_slice, ) diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py index b46cd1be7..1929f388d 100644 --- a/grain/_src/python/dataset/dataset_test.py +++ b/grain/_src/python/dataset/dataset_test.py @@ -1393,40 +1393,6 @@ def test_execution_summary_with_no_logging(self): log_value = "Grain Dataset Execution Summary" self.assertNotIn(log_value, "".join(logs.output)) - @flagsaver.flagsaver(grain_py_debug_mode=True) - @mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05) - def test_execution_summary_with_mp_prefetch(self): - def worker_init_fn_wrapper(worker_index, worker_count): - del worker_index, worker_count - dataset_stats._REPORTING_PERIOD_SEC = 0.05 - - ds = dataset.MapDataset.range(10000).map(MapTransformAddingOne()) - ds = ds.to_iter_dataset() - ds = ds.mp_prefetch( - options.MultiprocessingOptions(num_workers=1), - worker_init_fn=worker_init_fn_wrapper, - ) - it = ds.__iter__() - _ = list(it) - all_nodes_present = False - while not all_nodes_present: - time.sleep(1) - all_nodes_present = True - summary = dataset.get_execution_summary(it) - node_names = {node.name for node in summary.nodes.values()} - all_nodes_present = all_nodes_present and any( - "RangeMapDataset" in name for name in node_names - ) - all_nodes_present = all_nodes_present and any( - "MapMapDataset" in name for name in node_names - ) - all_nodes_present = all_nodes_present and any( - "PrefetchDatasetIterator" in name for name in node_names - ) - all_nodes_present = all_nodes_present and any( - "MultiprocessPrefetchDatasetIterator" in name for name in node_names - ) - class GetElementSpecTest(parameterized.TestCase): diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 722a4d38b..5d387910a 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -28,7 +28,7 @@ T = TypeVar("T") -class _InterleaveDatasetIterator(dataset.DatasetIterator[T]): +class InterleaveDatasetIterator(dataset.DatasetIterator[T]): """Iterates over the interleaved datasets.""" def __init__( @@ -250,7 +250,7 @@ def __str__(self) -> str: def _add_prefetch_and_make_iterator( ds: dataset.IterDataset[T] | dataset.MapDataset[T], - interleave_iterator: weakref.ref[_InterleaveDatasetIterator[T]], + interleave_iterator: weakref.ref[InterleaveDatasetIterator[T]], start_prefetch: bool, ) -> dataset.DatasetIterator[T]: """Adds prefetching to an IterDataset and returns an iterator. 
@@ -351,8 +351,8 @@ def __init__( self._make_iter_buffer_size = make_iter_buffer_size self._iter_buffer_size = iter_buffer_size - def __iter__(self) -> _InterleaveDatasetIterator[T]: - return _InterleaveDatasetIterator( + def __iter__(self) -> InterleaveDatasetIterator[T]: + return InterleaveDatasetIterator( self._datasets, cycle_length=self._cycle_length, num_make_iter_threads=self._num_make_iter_threads, diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index e7a6392fa..dfa46e3d9 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -16,35 +16,24 @@ from __future__ import annotations import collections -from collections.abc import Callable, Iterator, Sequence -import contextlib +from collections.abc import Iterator, Sequence import copy import functools -import math from multiprocessing import queues -from multiprocessing import synchronize import queue -import sys import threading -import time import typing -from typing import Any, Generic, Optional, Protocol, TypeVar +from typing import Any, Optional, Protocol, TypeVar -import cloudpickle from concurrent import futures from grain._src.core import monitoring as grain_monitoring -from grain._src.core import tree_lib -import multiprocessing as mp -from grain._src.python import grain_pool from grain._src.python import options as grain_options -from grain._src.python import shared_memory_array from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset import stats as dataset_stats from grain._src.python.dataset.transformations import filter as filter_dataset from grain._src.python.dataset.transformations import interleave from grain._src.python.dataset.transformations import source -import numpy as np T = TypeVar("T") @@ -328,131 +317,6 @@ def close(self) -> None: future.cancel() -def _iterator_with_context( - iterator: contextlib.AbstractContextManager[Iterator[T]], -) -> Iterator[T]: - with iterator as it: - yield from it - - -def _validate_no_double_prefetch( - parent: dataset.MapDataset | dataset.IterDataset, -) -> None: - """Checks that there are no multiple levels of parallelization.""" - to_check: list[dataset.MapDataset | dataset.IterDataset] = [parent] - while to_check: - ds = to_check.pop(0) - if isinstance(ds, MultiprocessPrefetchIterDataset): - raise ValueError( - "Nesting multiprocessing or multithreading is not allowed." - ) - to_check.extend(ds.parents) - - -class MultiprocessPrefetchIterDataset(dataset.IterDataset[T]): - """Uses a pool of processes to prefetch elements ahead of time. - - It usually makes sense to add this transformation in the end of the pipeline - since it will execute the parent IterDataset in multiple processes. - """ - - def __init__( - self, - parent: dataset.IterDataset[T], - multiprocessing_options: grain_options.MultiprocessingOptions, - worker_init_fn: Callable[[int, int], None] | None = None, - sequential_slice: bool = False, - always_report_worker_state: bool = False, - ): - if multiprocessing_options.num_workers < 0: - raise ValueError( - "`num_workers` must be greater than or equal to 0, got " - f"{multiprocessing_options.num_workers}." 
- ) - super().__init__(parent) - self._multiprocessing_options = multiprocessing_options - self._worker_init_fn = worker_init_fn - self._sequential_slice = sequential_slice - _validate_no_double_prefetch(self._parent) - self._always_report_worker_state = always_report_worker_state - - def __str__(self) -> str: - return ( - "MultiprocessPrefetchIterDataset(" - f"multiprocessing_options={self._multiprocessing_options})" - ) - - def __iter__(self) -> dataset.DatasetIterator[T]: - if self._multiprocessing_options.num_workers == 0: - return self._parent.__iter__() - return _MultiprocessPrefetchDatasetIterator( - self._parent, - self._multiprocessing_options, - self._worker_init_fn, - self._sequential_slice, - self._always_report_worker_state, - ) - - @property - def _element_spec(self) -> Any: - return dataset.get_element_spec(self._parent) - - -# Keys in `MultiprocessPrefetchDatasetIterator` checkpoints. -_WORKERS_STATE = "workers_state" -_ITERATIONS_TO_SKIP = "iterations_to_skip" -_LAST_WORKER_INDEX = "last_worker_index" - -# Minimal interval (in seconds) between consecutive state recordings in worker -# processes of `MultiprocessPrefetchDatasetIterator`. We record the state -# periodically to reduce the overhead of sending the state from workers. -# Note that this is also an approximate upper bound on how long it is going to -# take to recover from a checkpointed state. Larger values will decrease the -# overhead of sending the updated state but will also make recovery from a -# checkpoint longer on average. -_RECORD_STATE_INTERVAL_S = 3 - - -def _copy_leaf_to_shm(leaf: Any, min_size: int = 0) -> Any: - """Copies `leaf` to shared memory if it's a big enough numpy array.""" - if isinstance(leaf, shared_memory_array.SharedMemoryArray): - return leaf.metadata - if ( - not isinstance(leaf, np.ndarray) - or leaf.dtype.hasobject - or not leaf.flags.c_contiguous - or math.prod(leaf.shape) == 0 - or leaf.nbytes < min_size - ): - return leaf - - shared_memory_arr = shared_memory_array.SharedMemoryArray( - leaf.shape, leaf.dtype - ) - np.copyto(shared_memory_arr, leaf, casting="no") - return shared_memory_arr.metadata - - -def _copy_struct_to_shm(struct: Any, min_size: int = 0) -> Any: - """Copies leaf ndarrays of the structure to shared memory.""" - return tree_lib.map_structure( - functools.partial(_copy_leaf_to_shm, min_size=min_size), struct - ) - - -def _open_leaf_from_shm(leaf: Any) -> Any: - """Recovers `leaf` from shared memory if it's a numpy array metadata.""" - if isinstance(leaf, shared_memory_array.SharedMemoryArrayMetadata): - leaf = shared_memory_array.SharedMemoryArray.from_metadata(leaf) - leaf.unlink_on_del() - return leaf - - -def _open_struct_from_shm(struct: Any) -> Any: - """Recovers leaf ndarrays of the structure from shared memory.""" - return tree_lib.map_structure(_open_leaf_from_shm, struct) - - def _set_slice_iter_dataset( ds: dataset.IterDataset, sl: slice, @@ -509,123 +373,6 @@ def _set_slice_map_dataset( _set_slice_iter_dataset(parent, sl, sequential_slice) -def _check_picklable( - ds: dataset.IterDataset | dataset.MapDataset, -): - """Detects the first unpickle-able dataset in post-order. - - Args: - ds: IterDataset or MapDataset to check whether it is picklable. - - NOTE: This function's time complexity is O(n^2) where n is the number of - Grain dataset operations because `cloudpickle.dumps(ds)` will trigger - pickling into all the datasets. 
If this naive O(n^2) algorithm takes too - much time, we could consider doing copying `ds`, delete its parents and then - do `cloudpickle.dumps(new_ds)` to reduce the time complexity to O(n). - """ - - # Traverses the graph in post-order to find the first unpickle-able subtree - for parent in ds.parents: - _check_picklable(parent) - - try: - cloudpickle.dumps(ds) - except Exception as e: # pylint: disable=broad-exception-caught - if sys.version_info >= (3, 11): - e.add_note( - f"Dataset: {ds} cannot be pickled!" - ) - raise e - - -class GetElementProducerFn(grain_pool.GetElementProducerFn, Generic[T]): - """Implements `GetElementProducerFn` for `grain_pool.MultiProcessIterator`. - - This class implements `GetElementProducerFn` with `serialize` being overridden - to generate better error messages if user-provided dataset is not pickle-able. - """ - - def __init__( - self, - state: dict[str, dict[str, Any] | int], - ds: dataset.IterDataset[T], - sequential_slice: bool = False, - always_report_worker_state: bool = False, - ): - self._state = state - self._ds = ds - self._sequential_slice = sequential_slice - self._always_report_worker_state = always_report_worker_state - - def __call__( - self, - *, - worker_index: int, - worker_count: int, - start_profiling_event: synchronize.Event | None = None, - stop_profiling_event: synchronize.Event | None = None, - stats_out_queue: queues.Queue | None = None, - ) -> Iterator[tuple[T, Optional[dict[str, Any]]]]: - if worker_count > 1: - _set_slice_iter_dataset( - self._ds, - slice(worker_index, None, worker_count), - self._sequential_slice, - ) - # Prevent OutputDatasetIterator injection in worker processes. - # The injection should only happen in the main process iterator, - # which wraps the _MultiprocessPrefetchDatasetIterator. - it = self._ds.__iter__() - it._ctx.mp_context = base.MultiprocessingContext( - process_index=worker_index, process_count=worker_count - ) - min_shm_size = it._ctx.dataset_options.min_shm_size - # Recover from the last recorded state for the given worker. - worker_state = self._state[_WORKERS_STATE][str(worker_index)] - if worker_state is not None: - it.set_state(worker_state) - # Set the stats queue in worker process to send stats to the main process. - it._stats._config.stats_out_queue = stats_out_queue # pytype: disable=attribute-error - # Skip the required number of iterations after the last recorded state. - for _ in range(self._state[_ITERATIONS_TO_SKIP][str(worker_index)]): - _ = next(it) - last_recorded_state_time = time.time() - for element in it: - now = time.time() - element = _copy_struct_to_shm(element, min_size=min_shm_size) - # If the node is prefetch, we already record the bytes produced in it's - # __next__ method. - if not it._stats._config.is_prefetch: - it._stats.record_bytes_produced(element) - if ( - self._always_report_worker_state - or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S - ): - last_recorded_state_time = now - yield (element, it.get_state()) # pytype: disable=attribute-error - else: - yield (element, None) - - def serialize(self) -> bytes: - """Overrides the default implementation to generate better error messages.""" - - try: - return cloudpickle.dumps(self) - except Exception as e: # pylint: disable=broad-except - # Calls `_check_picklable` to generate useful pickle errors - # - # Note: No need to check `self._state` because it should not generate - # unpicklable errors and it is controlled by us, not from user's code - # in most cases. 
Except for the case when users try to implement their own - # `MapDataset` and `IterDataset` with custom pickle-ing logic that - # contains unpickle-able objects. - _check_picklable(self._ds) - - # If somehow we cannot find the dataset that is causing the pickle - # issues, just raise the original error - raise e - - def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions: result = base.DatasetOptions() to_visit = [ds] @@ -637,175 +384,6 @@ def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions: return result -class _MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]): - """Iterator that performs prefetching using a multiprocessing pool.""" - - def __init__( - self, - parent: dataset.IterDataset[T], - multiprocessing_options: grain_options.MultiprocessingOptions, - worker_init_fn: Callable[[int, int], None] | None = None, - sequential_slice: bool = False, - always_report_worker_state: bool = False, - ): - super().__init__() - self._iter_parent = parent - # Since the parent iterator is going to be created in each subprocess, and - # the options are propagated during iterator creation, we need to manually - # propagate them. - self._ctx.dataset_options = _get_dataset_options(parent) - self._multiprocessing_options = multiprocessing_options - self._worker_init_fn = worker_init_fn - self._sequential_slice = sequential_slice - # The underlying iterator producing elements and workers state. - self._iterator = None - # Raw reference to the underlying iterator that can be used to determine the - # last worker index. - self._raw_iterator = None - # Create initial state. We record state of each worker periodically together - # with the number of iterations without the recorded state and index of the - # last worker. - iterations_to_skip: dict[str, int] = { - str(i): 0 for i in range(multiprocessing_options.num_workers) - } - workers_state: dict[str, Any] = { - str(i): None for i in range(multiprocessing_options.num_workers) - } - self._stats_in_queues = tuple( - mp.get_context("spawn").Queue(maxsize=5) - for _ in range(multiprocessing_options.num_workers) - ) - self._start_profiling_event = mp.get_context("spawn").Event() - self._stop_profiling_event = mp.get_context("spawn").Event() - - self._state: dict[str, dict[str, Any] | int] = { - _WORKERS_STATE: workers_state, - _ITERATIONS_TO_SKIP: iterations_to_skip, - _LAST_WORKER_INDEX: -1, - } - - self._always_report_worker_state = always_report_worker_state - - def _initialize_stats( - self, execution_tracking_mode: base.ExecutionTrackingMode - ): - self._stats = _initialize_prefetch_stats( - self, - execution_tracking_mode, - parent_stats=[], - stats_in_queues=self._stats_in_queues, - ) - return self._stats - - @functools.cached_property - def _stats(self): - return self._initialize_stats( - self._ctx.dataset_options.execution_tracking_mode - ) - - def __iter__(self) -> dataset.DatasetIterator[T]: - return self - - @dataset_stats.record_next_duration_if_output - @dataset_stats.trace_input_pipeline_next( - stage_category=dataset_stats.IPL_CAT_PREFETCH - ) - def __next__(self) -> T: - self._assert_not_closed() - self._ensure_iterator_initialized() - # The time recorded here is the time spent in prefetch node to return an - # element, including the time spent in parent node. 
- timer = dataset_stats.Timer() - result, state = next(self._iterator) - with self._stats.record_self_time(offset_ns=timer.value()): - worker_index = self._raw_iterator.get_last_worker_index() # pytype: disable=attribute-error - - # pytype: disable=annotation-type-mismatch - iterations_to_skip: dict[str, Any] = self._state[_ITERATIONS_TO_SKIP] - worker_state: dict[str, Any] = self._state[_WORKERS_STATE] - # pytype: enable=annotation-type-mismatch - - self._state[_LAST_WORKER_INDEX] = worker_index - worker_index_str = str(worker_index) - if state is None: - iterations_to_skip[worker_index_str] += 1 - else: - iterations_to_skip[worker_index_str] = 0 - worker_state[worker_index_str] = state - result = self._stats.record_bytes_produced(result) - return _open_struct_from_shm(result) - - def start_prefetch(self) -> None: - """Prefetches elements from the iterator. - - This will run background processes for prefetching. To make sure to clean up - the resources, it should be followed by at least one `next` call. - """ - self._ensure_iterator_initialized() - - def set_state(self, state: dict[str, dict[str, Any] | int]) -> None: - self._state = state - self._raw_iterator = None - self._iterator = None - - def get_state(self) -> dict[str, Any]: - result = copy.deepcopy(self._state) - workers_state: dict[str, Any] = result[_WORKERS_STATE] # pytype: disable=annotation-type-mismatch - parent_state = None - for worker_index, worker_state in workers_state.items(): - # Create initial state from the parent iterator. This is to make sure the - # spec of the produced iterator does not change. - if worker_state is None: - parent_state = parent_state or self._iter_parent.__iter__().get_state() - workers_state[worker_index] = copy.deepcopy(parent_state) - return result - - def _ensure_iterator_initialized(self) -> None: - if self._iterator is None: - self._raw_iterator = self._create_iterator_context() - self._raw_iterator.start_prefetch() - self._iterator = _iterator_with_context(self._raw_iterator) - - def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]: - """Creates a `MultiProcessIterator`.""" - # Apply the latest options to the subprocess dataset. We delay this until - # starting subprocesses because child iterators may update them. - ds = dataset.WithOptionsIterDataset( - self._iter_parent, self._ctx.dataset_options - ) - get_element_producer_fn = GetElementProducerFn( - self._state, - ds, - self._sequential_slice, - self._always_report_worker_state, - ) - - return grain_pool.MultiProcessIterator( - get_element_producer_fn, - self._multiprocessing_options, - (self._state[_LAST_WORKER_INDEX] + 1) - % self._multiprocessing_options.num_workers, - self._worker_init_fn, - self._start_profiling_event, - self._stop_profiling_event, - self._stats_in_queues, - ) - - def __str__(self) -> str: - return ( - "MultiprocessPrefetchDatasetIterator(" - f"multiprocessing_options={self._multiprocessing_options})" - ) - - def close(self) -> None: - """Shuts down the prefetching threads and multiprocessing pool.""" - if self._closed: - return - self._closed = True - if self._raw_iterator is not None: - self._raw_iterator.stop_prefetch() - - class ThreadPrefetchIterDataset(dataset.IterDataset[T]): """Iterable dataset that uses a synchronized queue for prefetching. @@ -978,6 +556,8 @@ def close(self): """Stops the iterator. 
No further calls to the iterator are expected.""" self._closed = True self._stop_prefetch() + if isinstance(self._maybe_nonnative_parent, dataset.DatasetIterator): + self._maybe_nonnative_parent.close() def _clear_buffer(self): while True: @@ -1068,8 +648,6 @@ def multithread_prefetch( if num_threads == 0: return ds - _validate_no_double_prefetch(ds) - shards = [] for i in range(num_threads): worker_ds = copy.deepcopy(ds) @@ -1098,6 +676,6 @@ def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool: ( PrefetchDatasetIterator, ThreadPrefetchDatasetIterator, - _MultiprocessPrefetchDatasetIterator, + interleave.InterleaveDatasetIterator, ), ) diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 3e45a8ad6..060deec13 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -11,17 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from concurrent import futures import dataclasses -import logging as std_logging import platform import sys import threading import time from typing import TypeVar, cast -from unittest import mock -from absl import logging from absl.testing import absltest from absl.testing import parameterized from grain._src.core import transforms @@ -433,549 +429,6 @@ def test_element_spec(self): self.assertEqual(spec.dtype, np.int64) -class MultiprocessPrefetchIterDatasetTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - ds = dataset.MapDataset.range(20) - ds = prefetch.PrefetchIterDataset(ds, read_options=options.ReadOptions()) - self.iter_ds = ds.filter(FilterKeepingOddElementsOnly()) - - @parameterized.named_parameters( - dict( - testcase_name='0_workers', - num_workers=0, - per_worker_buffer_size=1, - ), - dict( - testcase_name='1_worker', - num_workers=1, - per_worker_buffer_size=1, - ), - dict( - testcase_name='1_worker_large_buffer', - num_workers=1, - per_worker_buffer_size=20, - ), - dict( - testcase_name='10_workers', - num_workers=10, - per_worker_buffer_size=1, - ), - dict( - testcase_name='10_workers_large_buffer', - num_workers=10, - per_worker_buffer_size=20, - ), - ) - def test_prefetch_data(self, num_workers: int, per_worker_buffer_size: int): - prefetch_lazy_iter_ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - options.MultiprocessingOptions(num_workers, per_worker_buffer_size), - ) - actual = list(prefetch_lazy_iter_ds) - expected = list(range(1, 20, 2)) - self.assertSequenceEqual(actual, expected) - - def test_prefetch_size_zero_data(self): - ds = dataset.MapDataset.source( - [np.zeros(shape=(0,), dtype=np.int64)] - ).repeat(3) - iter_ds = ds.to_iter_dataset() - prefetch_lazy_iter_ds = prefetch.MultiprocessPrefetchIterDataset( - iter_ds, - options.MultiprocessingOptions(num_workers=1), - ) - actual = list(prefetch_lazy_iter_ds) - expected = [np.zeros(shape=(0,), dtype=np.int64)] * 3 - self.assertLen(actual, 3) - self.assertLen(expected, 3) - for i in range(3): - np.testing.assert_array_equal(actual[i], expected[i]) - - @parameterized.product( - ( - dict( - num_workers=0, - record_state_interval=prefetch._RECORD_STATE_INTERVAL_S, - ), - dict( - num_workers=1, - record_state_interval=prefetch._RECORD_STATE_INTERVAL_S, - ), - dict( - num_workers=10, - record_state_interval=prefetch._RECORD_STATE_INTERVAL_S, - ), - dict( - 
num_workers=10, - record_state_interval=0, - ), - ), - step_index=[0, 3, 8], - ) - def test_checkpoint( - self, num_workers: int, record_state_interval: int, step_index: int - ): - with mock.patch.object( - prefetch, '_RECORD_STATE_INTERVAL_S', record_state_interval - ): - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - options.MultiprocessingOptions(num_workers), - ) - ds_iter = ds.__iter__() - - max_steps = 10 - values_without_interruption = [] - checkpoints = [] - for _ in range(max_steps): - checkpoints.append(ds_iter.get_state()) - values_without_interruption.append(next(ds_iter)) - - ds_iter.set_state(checkpoints[step_index]) - for i in range(step_index, max_steps): - value = next(ds_iter) - self.assertEqual(value, values_without_interruption[i]) - - def test_set_state_twice(self): - with mock.patch.object(prefetch, '_RECORD_STATE_INTERVAL_S', 0): - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - options.MultiprocessingOptions(2), - ) - ds_iter = ds.__iter__() - - max_steps = 10 - values_without_interruption = [] - checkpoints = [] - for _ in range(max_steps): - checkpoints.append(ds_iter.get_state()) - values_without_interruption.append(next(ds_iter)) - - for starting_step in [0, 3, 8]: - ds_iter.set_state(checkpoints[starting_step]) - for i in range(starting_step, max_steps): - value = next(ds_iter) - self.assertEqual(value, values_without_interruption[i]) - - def test_fails_with_negative_num_workers(self): - with self.assertRaisesRegex( - ValueError, '`num_workers` must be greater than or equal to 0' - ): - prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - options.MultiprocessingOptions(num_workers=-1), - ) - - def test_fails_with_multiple_prefetches(self): - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - options.MultiprocessingOptions(num_workers=10), - ) - with self.assertRaisesRegex( - ValueError, - 'Nesting multiprocessing or multithreading is not allowed.', - ): - _ = prefetch.MultiprocessPrefetchIterDataset( - ds, - options.MultiprocessingOptions(num_workers=1), - ) - - def test_works_with_iter_source_single_worker(self): - # Even though a pure IterDataset cannot be sliced, we should still be able - # to multiprocess-prefetch it with a single worker, since that doesn't - # require any slicing. - ds = prefetch.MultiprocessPrefetchIterDataset( - RepeatedIntSourceIterDataset().map(lambda x: x + 1), - options.MultiprocessingOptions(num_workers=1), - ) - ds_iter = iter(ds) - self.assertEqual(next(ds_iter), 2) - - def test_fails_with_iter_source_multiple_workers(self): - ds = prefetch.MultiprocessPrefetchIterDataset( - RepeatedIntSourceIterDataset().map(lambda x: x + 1), - options.MultiprocessingOptions(num_workers=2), - ) - ds_iter = iter(ds) - - with self.assertRaisesRegex( - Exception, - 'Cannot slice `IterDataset` source.', - ): - next(ds_iter) - - def test_propagates_transform_error(self): - error_msg = 'I shall fail!' 
- - def failing_transform(element): - del element - raise ValueError(error_msg) - - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds.map(failing_transform), - options.MultiprocessingOptions(num_workers=1), - ) - with self.assertRaisesRegex(Exception, error_msg): - list(ds) - - def test_reports_worker_crash(self): - def failing_transform(element): - del element - sys.exit(123) - - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds.map(failing_transform), - options.MultiprocessingOptions(num_workers=1), - ) - with self.assertRaisesRegex( - RuntimeError, 'was terminated unexpectedly with exit code 123' - ): - list(ds) - - def test_reports_unpicklable_transform(self): - class UnpicklableObject: - - def __getstate__(self): - raise ValueError('UnpicklableObject is not picklable') - - local_state = UnpicklableObject() - - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds.map(lambda _: 1 if local_state is None else 2), - options.MultiprocessingOptions(num_workers=1), - ) - with self.assertRaisesRegex( - ValueError, 'UnpicklableObject is not picklable' - ) as context_manager: - list(ds) - - if sys.version_info >= (3, 11): - self.assertRegex( - ''.join(context_manager.exception.__notes__), - r'Dataset: MapIterDataset.* cannot be pickled!', - ) - - def test_reports_first_unpicklable_dataset_when_with_multiple_parents(self): - class UnpicklableObject: - - def __getstate__(self): - raise ValueError('UnpicklableObject is not picklable') - - local_unpicklable_obj = UnpicklableObject() - - class LeftTransform(transforms.MapTransform): - - def map(self, x): - return x if local_unpicklable_obj else x - - class RightTransform(transforms.MapTransform): - - def map(self, x): - return x if local_unpicklable_obj else x - - ds_left = dataset.MapDataset.range(0, 10) - ds_left = ds_left.map(LeftTransform()) - ds_right = dataset.MapDataset.range(10, 20) - ds_right = ds_right.map(RightTransform()) - - ds = dataset.MapDataset.mix([ds_left, ds_right], [1.0, 1.0]) - - iter_ds = ds.to_iter_dataset( - read_options=options.ReadOptions(prefetch_buffer_size=0) - ) - iter_ds = iter_ds.mp_prefetch() - - with self.assertRaisesRegex( - ValueError, - r'UnpicklableObject is not picklable', - ) as context_manager: - list(iter_ds) - - if sys.version_info >= (3, 11): - self.assertRegex( - ''.join(context_manager.exception.__notes__), - r'Dataset: MapMapDataset\(transform=LeftTransform\) cannot be' - r' pickled!', - ) - - def test_reports_unpicklable_issue_when_only_one_parent_unpicklable(self): - class UnpicklableObject: - - def __getstate__(self): - raise ValueError('UnpicklableObject is not picklable') - - class PickleableTransform(transforms.MapTransform): - - def map(self, x): - return x - - local_unpicklable_obj = UnpicklableObject() - - class RightTransform(transforms.MapTransform): - - def map(self, x): - return x if local_unpicklable_obj else x - - ds_left = dataset.MapDataset.range(0, 10) - ds_left = ds_left.map(PickleableTransform()) - ds_right = dataset.MapDataset.range(10, 20) - ds_right = ds_right.map(RightTransform()) - - ds = dataset.MapDataset.mix([ds_left, ds_right], [1.0, 1.0]) - - iter_ds = ds.to_iter_dataset( - read_options=options.ReadOptions(prefetch_buffer_size=0) - ) - iter_ds = iter_ds.mp_prefetch() - - with self.assertRaisesRegex( - ValueError, 'UnpicklableObject is not picklable' - ) as context_manager: - list(iter_ds) - - if sys.version_info >= (3, 11): - self.assertRegex( - ''.join(context_manager.exception.__notes__), - r'Dataset: 
MapMapDataset\(transform=RightTransform\) cannot be' - r' pickled!', - ) - - @parameterized.product( - start_prefetch_calls=[0, 1, 10], - num_workers=[6], - per_worker_buffer_size=[1, 20], - ) - def test_start_prefetch( - self, - start_prefetch_calls: int, - num_workers: int, - per_worker_buffer_size: int, - ): - class _SleepTransform(transforms.MapTransform): - - def map(self, features): - time.sleep(1) - return features - - ds = dataset.MapDataset.range(10) - ds = ds.map(_SleepTransform()) - ds = prefetch.PrefetchIterDataset(ds, read_options=options.ReadOptions()) - ds = prefetch.MultiprocessPrefetchIterDataset( - ds, - options.MultiprocessingOptions(num_workers, per_worker_buffer_size), - ) - - it = ds.__iter__() - for _ in range(start_prefetch_calls): - it.start_prefetch() - - # Waits for prefetching. - start_time = time.time() - while time.time() - start_time < 30: - time.sleep(2) - - # Measures time to read from the dataset. - start_time = time.time() - self.assertSequenceEqual(list(it), list(range(10))) - - time_to_fetch = time.time() - start_time - logging.info('Reading dataset took %.2f seconds.', time_to_fetch) - # Note that we can't reliably assert the upper bound on the time it takes - # read the dataset elements since worker startup time can vary a lot. - if not start_prefetch_calls: - self.assertGreater(time_to_fetch, 1) - - @parameterized.parameters(0, 0.5, 30) - def test_prefetch_but_no_read(self, sleep_s): - ds = dataset.MapDataset.source([1, 2, 3]).repeat() - ds = ds.filter(lambda x: x > 3) - ds = ds.to_iter_dataset() - ds = ds.mp_prefetch() - it = ds.__iter__() - it.start_prefetch() - time.sleep(sleep_s) - del it - - def test_prefetch_with_random_map(self): - ds = dataset.MapDataset.source([0]).repeat(100).to_iter_dataset() - ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42) - ds = prefetch.MultiprocessPrefetchIterDataset( - ds, - options.MultiprocessingOptions(num_workers=5), - ) - # Make sure that sliced datasets on workers are seeded differently and thus - # produce different random elements. - elements = list(ds) - distinct_elements = set(elements) - self.assertLen(distinct_elements, len(elements)) - - def test_concurrent_start_prefetch(self): - num_iters = 10 # Can't set this much higher without Forge OOMing. 
- - def make_iter(i): - ds = dataset.MapDataset.source([i]) - ds = ds.to_iter_dataset() - ds = ds.mp_prefetch(options=options.MultiprocessingOptions(num_workers=1)) - return ds.__iter__() - - iters = [make_iter(i) for i in range(num_iters)] - with futures.ThreadPoolExecutor(max_workers=num_iters) as executor: - for it in iters: - executor.submit(it.start_prefetch) - for it in iters: - _ = next(it) - - def test_options_before_prefetch(self): - ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000) - ds = ds.to_iter_dataset() - ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) - ds = dataset.WithOptionsIterDataset(ds, ds_options) - ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1)) - ds = ds.filter(lambda x: x > 2) - with self.assertRaises(Exception): - list(ds) - - def test_multiprocess_prefetch_with_sequential_slice(self): - ds = dataset.MapDataset.source(range(10)).to_iter_dataset() - ds = prefetch.MultiprocessPrefetchIterDataset( - ds, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), - sequential_slice=True, - ) - self.assertEqual(list(ds), [0, 4, 7, 1, 5, 8, 2, 6, 9, 3]) - - def test_multiprocess_prefetch_with_default_slice_non_sequential(self): - ds = dataset.MapDataset.source(range(10)).to_iter_dataset() - ds_sequential_off = prefetch.MultiprocessPrefetchIterDataset( - ds, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), - sequential_slice=False, - ) - ds_sequential_default = prefetch.MultiprocessPrefetchIterDataset( - ds, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), - ) - elements_sequential_off = list(ds_sequential_off) - elements_sequential_default = list(ds_sequential_default) - self.assertEqual( - elements_sequential_off, - elements_sequential_default, - ) - self.assertEqual( - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], - elements_sequential_default, - ) - - def test_multiprocess_prefetch_sequential_slice_order_from_source(self): - ds = dataset.MapDataset.source(range(10)).to_iter_dataset() - ds_sequential_on = prefetch.MultiprocessPrefetchIterDataset( - ds, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), - sequential_slice=True, - ) - elements_sequential_on = list(ds_sequential_on) - self.assertEqual([0, 4, 7, 1, 5, 8, 2, 6, 9, 3], elements_sequential_on) - - def test_multiprocess_prefetch_sequential_slice_order_from_range(self): - ds_range = dataset.MapDataset.range(10).to_iter_dataset() - ds_range_sequential_on = prefetch.MultiprocessPrefetchIterDataset( - ds_range, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), - sequential_slice=True, - ) - elements_range_sequential_on = list(ds_range_sequential_on) - self.assertEqual( - [0, 4, 7, 1, 5, 8, 2, 6, 9, 3], - elements_range_sequential_on, - ) - - def test_multiprocess_prefetch_sequential_slice_order_from_range_slice(self): - ds_range = dataset.MapDataset.range( - start=2, stop=21, step=3 - ).to_iter_dataset() - ds_range_sequential_on = prefetch.MultiprocessPrefetchIterDataset( - ds_range, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), - sequential_slice=True, - ) - elements_range_sequential_on = list(ds_range_sequential_on) - self.assertEqual( - [2, 11, 17, 5, 14, 20, 8], - elements_range_sequential_on, - ) - - def test_multiprocess_prefetch_sequential_slice_order_same(self): - ds_source = dataset.MapDataset.source(range(10)).to_iter_dataset() - ds_range = dataset.MapDataset.range(10).to_iter_dataset() - ds_source_mp = 
prefetch.MultiprocessPrefetchIterDataset( - ds_source, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), - sequential_slice=True, - ) - ds_range_mp = prefetch.MultiprocessPrefetchIterDataset( - ds_range, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), - sequential_slice=True, - ) - elements_source = list(ds_source_mp) - elements_range = list(ds_range_mp) - self.assertEqual(elements_source, elements_range) - - def test_options_after_prefetch(self): - ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000) - ds = ds.filter(lambda x: x > 2) - ds = ds.to_iter_dataset() - ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1)) - ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) - ds = dataset.WithOptionsIterDataset(ds, ds_options) - with self.assertRaises(Exception): - list(ds) - - def test_worker_init_fn(self): - def set_worker_index_and_count(worker_index: int, worker_count: int): - log_formatter = std_logging.Formatter( - f'[Worker {worker_index} out of {worker_count}] %(message)s' - ) - logging.get_absl_handler().setFormatter(log_formatter) - - def map_fn(x): - # absl logging from workers is not propagated to the main process in unit - # tests. Therefore, we manually pass the formatted log message. - record = logging.get_absl_logger().makeRecord( - 'grain', - logging.INFO, - 'grain_pool_test', - 123, - f'processing element {x}', - (), - None, - ) - return logging.get_absl_handler().format(record) - - ds = dataset.MapDataset.range(2).map(map_fn) - ds = ds.to_iter_dataset() - ds = ds.mp_prefetch( - options.MultiprocessingOptions(num_workers=2), - worker_init_fn=set_worker_index_and_count, - ) - self.assertEqual( - list(ds), - [ - '[Worker 0 out of 2] processing element 0', - '[Worker 1 out of 2] processing element 1', - ], - ) - - def test_element_spec(self): - ds = dataset.MapDataset.range(2).to_iter_dataset() - ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=2)) - spec = dataset.get_element_spec(ds) - self.assertEqual(spec.dtype, np.int64) - self.assertEqual(spec.shape, ()) - - class ThreadPrefetchIterDatasetTest(parameterized.TestCase): def setUp(self): @@ -1359,24 +812,6 @@ def test_set_state_on_fresh_iterator(self): value = next(ds_iter) self.assertEqual(value, values_without_interruption[i]) - def test_get_state_doesnt_start_prefetch(self): - event = threading.Event() - - def f(x): - event.set() - return x - - ds = dataset.MapDataset.source([1, 2, 3]).map(f).to_iter_dataset() - ds = prefetch.multithread_prefetch( - ds, - num_threads=2, - buffer_size=10, - ) - it = ds.__iter__() - it.get_state() - time.sleep(1) - self.assertFalse(event.is_set()) - def test_does_not_hang_after_stop_iteration(self): ds = dataset.MapDataset.source([1, 2, 3]).repeat(100).to_iter_dataset() ds = prefetch.multithread_prefetch( @@ -1387,21 +822,6 @@ def test_does_not_hang_after_stop_iteration(self): it = ds.__iter__() it.start_prefetch() - def test_fails_with_multiprocess_prefetch_parent(self): - ds = prefetch.MultiprocessPrefetchIterDataset( - self.ds, - options.MultiprocessingOptions(num_workers=2), - ) - with self.assertRaisesRegex( - ValueError, - 'Nesting multiprocessing or multithreading is not allowed.', - ): - _ = prefetch.multithread_prefetch( - ds, - num_threads=1, - buffer_size=1, - ) - def test_mp_context_is_set_correctly(self): num_workers = 4 ds = dataset.MapDataset.range(20).to_iter_dataset() diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py 
b/grain/_src/python/dataset/transformations/process_prefetch.py index 40afa2d8e..d9653d5cb 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch.py +++ b/grain/_src/python/dataset/transformations/process_prefetch.py @@ -16,7 +16,6 @@ from __future__ import annotations from collections.abc import Callable, Sequence -import copy import functools from multiprocessing import queues from multiprocessing import synchronize @@ -29,6 +28,7 @@ from grain._src.core.config import config import multiprocessing as mp from grain._src.python import grain_logging +from grain._src.python import multiprocessing_common from grain._src.python import shared_memory_array from grain._src.python.dataset import base from grain._src.python.dataset import dataset @@ -36,24 +36,12 @@ from grain._src.python.dataset.transformations import interleave from grain._src.python.dataset.transformations import prefetch + T = TypeVar("T") # Type for the iterator state. StateT = dict[str, Any] -# Minimal interval (in seconds) between consecutive state recordings in worker -# processes of `_ProcessPrefetchDatasetIterator`. We record the state -# periodically to reduce the overhead of sending the state from workers. -# Note that this is also an approximate upper bound on how long it is going to -# take to recover from a checkpointed state. Larger values will decrease the -# overhead of sending the updated state but will also make recovery from a -# checkpoint longer on average. -_RECORD_STATE_INTERVAL_S = 3 - -# Keys in `_ProcessPrefetchDatasetIterator` checkpoints. -_WORKER_STATE = "worker_state" -_ITERATIONS_TO_SKIP = "iterations_to_skip" - # Timeout for killing worker processes on iterator close. _PROCESS_KILL_TIMEOUT_S = 10 # Interval to wait in the worker process when the parent iterator is exhausted @@ -62,6 +50,8 @@ # Timeout for getting an element from the worker process. _QUEUE_WAIT_TIMEOUT_S = 1 +_is_in_worker_process = False + def _run_all(fns: Sequence[Callable[[], None]]): for fn in fns: @@ -90,27 +80,6 @@ def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions: return result -def _validate_no_nested_process_prefetch( - ds: dataset.MapDataset | dataset.IterDataset, -): - """Checks that there are no nested process prefetch nodes.""" - to_check: list[dataset.MapDataset | dataset.IterDataset] = [ds] - while to_check: - d = to_check.pop(0) - if isinstance( - d, - ( - ProcessPrefetchIterDataset, - prefetch.MultiprocessPrefetchIterDataset, - ), - ): - raise ValueError( - "Nesting prefetching with processes is not allowed, but found " - f"{type(d).__name__} under a ProcessPrefetchIterDataset." - ) - to_check.extend(d.parents) - - def _check_picklable( ds: dataset.IterDataset | dataset.MapDataset, ): @@ -149,8 +118,22 @@ def _serialize_dataset(ds: dataset.IterDataset) -> bytes: raise e +def _clear_queue_and_maybe_unlink_shm(q: queues.Queue[Any]) -> int: + count = 0 + while True: + try: + shared_memory_array.unlink_shm(q.get_nowait()) + count += 1 + except queue.Empty: + return count + + class ProcessPrefetchIterDataset(dataset.IterDataset[T]): - """Iterable dataset that uses a background process for prefetching.""" + """Iterable dataset that uses a background process for prefetching. + + This dataset transformation accepts an IterDataset and prefetches elements + from it in a separate process, buffering up to `buffer_size` elements. 
+ """ def __init__( self, @@ -158,6 +141,14 @@ def __init__( buffer_size: int, worker_init_fn: Callable[[], None] | None = None, ): + """Initializes the ProcessPrefetchIterDataset. + + Args: + parent: The dataset to prefetch from. + buffer_size: The size of the buffer used for prefetching. + worker_init_fn: An optional function to run in the worker process at + startup. + """ if buffer_size <= 0: raise ValueError( f"`buffer_size` must be greater than 0, got {buffer_size}." @@ -165,7 +156,6 @@ def __init__( super().__init__(parent) self._buffer_size = buffer_size self._worker_init_fn = worker_init_fn - _validate_no_nested_process_prefetch(self._parent) def __str__(self) -> str: return f"ProcessPrefetchIterDataset(buffer_size={self._buffer_size})" @@ -196,6 +186,8 @@ def _put_dataset_elements_in_buffer( debug_flags: dict[str, Any], ): """Prefetches elements in a separate process.""" + global _is_in_worker_process + _is_in_worker_process = True try: parse_debug_flags_fn = cloudpickle.loads(pickled_parse_debug_flags_fn) parse_debug_flags_fn(debug_flags) @@ -212,12 +204,13 @@ def _put_dataset_elements_in_buffer( if set_state_event.is_set(): set_state_event.clear() parent_exhausted = False - new_state, iterations_to_skip_after_set_state = set_state_queue.get() + new_state = set_state_queue.get() if new_state is not None: it.set_state(new_state) - for _ in range(iterations_to_skip_after_set_state): - _ = next(it) - buffer.put((_SetStateIsDone(), None, None)) + if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + (_SetStateIsDone(), None, None), buffer, should_stop.is_set + ): + continue if parent_exhausted: # Avoid busy-waiting when parent iterator is exhausted due to an # error. Wait until set_state_event or should_stop is set. @@ -226,7 +219,9 @@ def _put_dataset_elements_in_buffer( try: element = it.__next__() except Exception as e: # pylint: disable=broad-except - buffer.put((None, None, e)) + multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + (None, None, e), buffer, should_stop.is_set + ) parent_exhausted = True continue element = shared_memory_array.copy_to_shm(element, min_size=min_shm_size) @@ -234,9 +229,22 @@ def _put_dataset_elements_in_buffer( # __next__ method. if not it._stats._config.is_prefetch: # pylint: disable=protected-access it._stats.record_bytes_produced(element) # pylint: disable=protected-access - buffer.put((element, it.get_state(), None)) + if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + (element, it.get_state(), None), buffer, should_stop.is_set + ): + # We failed to put the element into the output queue because the + # should_stop event was set. The element may contain a shared memory + # block reference that has to be cleaned up. 
+ shared_memory_array.unlink_shm(element) except Exception as e: # pylint: disable=broad-except - buffer.put((None, None, e)) + _clear_queue_and_maybe_unlink_shm(buffer) + _clear_queue_and_maybe_unlink_shm(set_state_queue) + multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + (None, None, e), buffer, should_stop.is_set + ) + return + _clear_queue_and_maybe_unlink_shm(buffer) + _clear_queue_and_maybe_unlink_shm(set_state_queue) class _SetStateIsDone: @@ -275,7 +283,6 @@ def __init__( self._stats_in_queue = self._process_ctx.Queue(maxsize=5) self._start_profiling_event = self._process_ctx.Event() self._stop_profiling_event = self._process_ctx.Event() - self._iterations_to_skip = 0 self._set_state_count = 0 self._exhausted = False self._prefetch_ds_iter = None @@ -392,16 +399,12 @@ def __next__(self): break elif element is not None: # Unlink shared memory for the discarded element. - shared_memory_array.open_from_shm(element) + shared_memory_array.unlink_shm(element) if err is not None: self._stop_prefetch() self._exhausted = True raise err - if state is None: - self._iterations_to_skip += 1 - else: - self._iterations_to_skip = 0 - self._state = state + self._state = state with self._stats.record_self_time(offset_ns=timer.value()): element = self._stats.record_bytes_produced(element) return shared_memory_array.open_from_shm(element) @@ -411,21 +414,9 @@ def close(self): self._closed = True self._stop_prefetch() - def _clear_buffer(self): - while True: - try: - element, _, _ = self._buffer.get_nowait() - if element is not None and not isinstance(element, _SetStateIsDone): - shared_memory_array.open_from_shm(element) - except queue.Empty: - return - def _clear_set_state_queue(self): - try: - self._set_state_queue.get_nowait() + if _clear_queue_and_maybe_unlink_shm(self._set_state_queue): self._set_state_count -= 1 - except queue.Empty: - return def _stop_prefetch(self): """Stops the prefetching process if it's currently running.""" @@ -433,35 +424,30 @@ def _stop_prefetch(self): return self._prefetch_should_stop.set() - # Remove entries from the buffer to unblock the producer, so that it checks - # producer_running.is_set() and exits. - self._clear_buffer() - self._prefetch_process.join(_PROCESS_KILL_TIMEOUT_S) + + # Not joining here will cause the children to be zombie after they finish. + # Need to join or call active_children. + self._prefetch_process.join(timeout=_PROCESS_KILL_TIMEOUT_S) + + # In case all our attempts to terminate the system fails, we forcefully + # kill the child processes. if self._prefetch_process.is_alive(): self._prefetch_process.kill() self._prefetch_process = None - # Clear the buffer again in case the prefetch loop added more elements on - # exit. - self._clear_buffer() + _clear_queue_and_maybe_unlink_shm(self._buffer) self._clear_set_state_queue() self._set_state_count = 0 def get_state(self) -> StateT: if self._state is None: - worker_state = self._iter_parent.__iter__().get_state() - else: - worker_state = self._state - return { - _WORKER_STATE: worker_state, - _ITERATIONS_TO_SKIP: self._iterations_to_skip, - } + return self._iter_parent.__iter__().get_state() + return self._state def set_state(self, state: StateT): - self._state = state[_WORKER_STATE] - self._iterations_to_skip = state[_ITERATIONS_TO_SKIP] + self._state = state # Remove any pending set_state calls. 
self._clear_set_state_queue() - self._set_state_queue.put((self._state, self._iterations_to_skip)) + self._set_state_queue.put(self._state) # Signal the prefetch process to start processing set_state calls. self._set_state_event.set() # Increment the number of _SetStateIsDone that need to be skipped to @@ -473,6 +459,35 @@ def __str__(self) -> str: return f"ProcessPrefetchDatasetIterator(buffer_size={self._buffer_size})" +class _LazyWorkerSliceIterDataset(dataset.IterDataset[T]): + """Applies slice to the parent dataset in the worker process.""" + + def __init__( + self, + parent: dataset.IterDataset[T], + sl: slice, + sequential_slice: bool, + ): + super().__init__(parent) + self._slice = sl + self._sequential_slice = sequential_slice + + def __iter__(self) -> dataset.DatasetIterator[T]: + if not _is_in_worker_process: + return self._parent.__iter__() + prefetch._set_slice_iter_dataset( + self._parent, self._slice, self._sequential_slice + ) + return self._parent.__iter__() + + @property + def _element_spec(self) -> Any: + return dataset.get_element_spec(self._parent) + + def __str__(self) -> str: + return f"_LazyWorkerSliceIterDataset(slice={self._slice})" + + def multiprocess_prefetch( ds: dataset.IterDataset[T], num_workers: int = 0, @@ -507,10 +522,12 @@ def multiprocess_prefetch( if num_workers == 1: worker_ds = ds else: - worker_ds = copy.deepcopy(ds) - prefetch._set_slice_iter_dataset( # pylint: disable=protected-access - worker_ds, slice(i, None, num_workers), sequential_slice + worker_ds = _LazyWorkerSliceIterDataset( + ds, + slice(i, None, num_workers), + sequential_slice, ) + worker_ds = prefetch._MpContextIterDataset( # pylint: disable=protected-access worker_ds, base.MultiprocessingContext( diff --git a/grain/_src/python/dataset/transformations/process_prefetch_test.py b/grain/_src/python/dataset/transformations/process_prefetch_test.py index 63f99fac9..164f86a02 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch_test.py +++ b/grain/_src/python/dataset/transformations/process_prefetch_test.py @@ -28,7 +28,6 @@ from grain._src.python import options from grain._src.python.dataset import base from grain._src.python.dataset import dataset -from grain._src.python.dataset.transformations import prefetch from grain._src.python.dataset.transformations import process_prefetch import numpy as np @@ -109,6 +108,20 @@ def test_checkpoint(self, warm_start: bool): value = next(ds_iter) self.assertEqual(value, values_without_interruption[i]) + def test_set_state_after_stop_iteration(self): + ds = process_prefetch.ProcessPrefetchIterDataset( + self.ds, + buffer_size=5, + ) + ds_iter = ds.__iter__() + state = ds_iter.get_state() + ds_iter.start_prefetch() + lst1 = list(ds_iter) + self.assertEmpty(list(ds_iter)) + ds_iter.set_state(state) + lst2 = list(ds_iter) + self.assertSequenceEqual(lst1, lst2) + def test_set_state_does_not_restart_process(self): ds = process_prefetch.ProcessPrefetchIterDataset( self.ds.map(lambda i: (i, os.getpid())), @@ -265,23 +278,6 @@ def __getstate__(self): ): list(ds) - def test_fails_with_nested_prefetch(self): - ds1 = process_prefetch.ProcessPrefetchIterDataset(self.ds, buffer_size=1) - with self.assertRaisesRegex( - ValueError, - 'Nesting prefetching with processes is not allowed', - ): - process_prefetch.ProcessPrefetchIterDataset(ds1, buffer_size=1) - - ds2 = prefetch.MultiprocessPrefetchIterDataset( - self.ds, options.MultiprocessingOptions(num_workers=1) - ) - with self.assertRaisesRegex( - ValueError, - 'Nesting prefetching with 
processes is not allowed', - ): - process_prefetch.ProcessPrefetchIterDataset(ds2, buffer_size=1) - def test_reports_worker_crash(self): def failing_transform(element): del element @@ -317,6 +313,22 @@ def test_options_after_prefetch(self): with self.assertRaises(Exception): list(ds) + def test_doesnt_hang_on_process_kill(self): + def map_fn(x): + if x == 3: + while True: + pass + return x + + ds = dataset.MapDataset.source([1, 2, 3]) + ds = ds.map(map_fn) + ds = ds.to_iter_dataset() + ds = process_prefetch.ProcessPrefetchIterDataset(ds, buffer_size=1) + ds_iter = ds.__iter__() + self.assertEqual(next(ds_iter), 1) + self.assertEqual(next(ds_iter), 2) + ds_iter.close() + @dataclasses.dataclass(frozen=True) class FilterAllElements(transforms.Filter): @@ -325,6 +337,19 @@ def filter(self, element: int): return False +class RandomTripletSource: + + def __len__(self) -> int: + return 100_000 + + def __getitem__(self, record_key: int): + return { + 'data': ( + np.random.uniform(size=(3, 224, 224, 3)).astype(dtype=np.float32) + ) + } + + class RepeatedIntSourceIterDataset(dataset.IterDataset[int]): def __iter__(self) -> dataset.DatasetIterator[int]: @@ -469,10 +494,11 @@ def test_fails_with_iter_source_multiple_workers(self): ValueError, 'Cannot slice `IterDataset` source.', ): - process_prefetch.multiprocess_prefetch( + ds = process_prefetch.multiprocess_prefetch( RepeatedIntSourceIterDataset().map(lambda x: x + 1), num_workers=2, ) + list(ds) def test_propagates_transform_error(self): error_msg = 'I shall fail!' diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py deleted file mode 100644 index 0bcf381e3..000000000 --- a/grain/_src/python/grain_pool.py +++ /dev/null @@ -1,834 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""This module provides a way to distribute processing across multiple workers. - -In the context of Grain we use the term "process" similar to JAX, where usually -each machine runs one Python process (identified by `jax.process_index()`). -In Grain each "process" can create additional Python child processes that we -call "workers". - -GrainPool manages a set of Python processes. It's similar to -`multiprocessing.Pool` but optimises communication between the processes to -enable high throughput data pipelines. -The GrainPool works as follows: -* Parent process launches a set of "num_workers" child processes. -* Each child process produces elements by reading data and transforming it. The - resulting elements are added to a queue (each child process has its queue). -* Parent process reads data from the children queues in a strict round-robin - fashion. - -Shutdown logic considerations: -* Child processes are launched as Daemon processes. In case of (unexpected) - parent termination, child processes will be terminated by OS. -* System uses a multiprocessing event ("termination_event") for termination. 
- Parent and child processes continuously check if the "termination_event" and - if set, they break from what they are doing. -* We never block indefinitely when calling get() or put() on a queue. This - ensures parent and child processes continue to check the termination_event. - -MultiProcessIterator wraps GrainPool adding lifecycle management, checkpointing -support and multithreaded elements read. -""" - -from __future__ import annotations - -from collections.abc import Iterator -import cProfile -import dataclasses -from multiprocessing import context -from multiprocessing import pool -from multiprocessing import queues -from multiprocessing import synchronize -import pstats -import queue -import sys -import threading -import traceback -from typing import Any, Callable, Protocol, Type, TypeVar, Union, runtime_checkable - -from absl import flags -from absl import logging -import cloudpickle -from grain._src.core import monitoring as grain_monitoring -from grain._src.core import parallel -from grain._src.core import tree_lib -from grain._src.core.config import config -import multiprocessing as mp -from grain._src.python import grain_logging -from grain._src.python import multiprocessing_common -from grain._src.python import record -from grain._src.python import shared_memory_array -from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member - - -T = TypeVar("T") - -# Maximum number of threads for starting and stopping processes. -_PROCESS_MANAGEMENT_MAX_THREADS = 64 -_PROCESS_JOIN_TIMEOUT = 25 -_QUEUE_WAIT_TIMEOUT = 1 -# Input queues contain small structures (record metadata), thus they are safe -# to have a big size. -_INPUT_QUEUE_MAX_SIZE = 10000 - - -@dataclasses.dataclass -class _ProcessingComplete: - """Indicates child process finished processing.""" - - -_PROCESSING_COMPLETE = _ProcessingComplete() - - -@dataclasses.dataclass(slots=True, frozen=True) -class GrainPoolElement: - """Wrapper for output records emited by Grain Pool.""" - - record: Any - worker_index: Any - - -@dataclasses.dataclass(slots=True, frozen=True) -class RemoteWorkerError: - """Grain worker exception that can be pickled and sent over a queue.""" - error_cls: Type[Exception] - error: str - worker_index: int - - @property - def original_error(self) -> Exception: - msg = ( - f"Grain worker {self.worker_index} failed with the following" - f" error:\n\n{self.error}" - ) - # Custom exception classes can have different c'tor arguments. - try: - return self.error_cls(msg) - except Exception: # pylint: disable=broad-except - return RuntimeError(msg) - - -def _print_profile(preamble: str, profile: cProfile.Profile): - """Prints output of cProfile, sorted by cumulative time.""" - print(preamble) - stats = pstats.Stats(profile).sort_stats(pstats.SortKey.CUMULATIVE) - stats.print_stats() - - -@runtime_checkable -class GetElementProducerFn(Protocol[T]): - """A callable class able to generate elements with serialization support.""" - - def __call__( - self, - *, - worker_index: int, - worker_count: int, - start_profiling_event: synchronize.Event | None = None, - stop_profiling_event: synchronize.Event | None = None, - stats_out_queue: queues.Queue | None = None, - ) -> Iterator[T]: - """Returns a generator of elements.""" - - def serialize(self) -> bytes: - """Serializes itself and the result will be used by `deserialize`. - - If a class inherits from this class, it should make sure `deserialize` - is compatible with this `serialize` function. - i.e. 
`GetElementProducerFn.deserialize(obj.serialize())` should return the - same object as `obj: GetElementProducerFn`. - - Returns: - a serialized string of myself. - """ - return cloudpickle.dumps(self) - - @classmethod - def deserialize(cls, serialized: bytes) -> GetElementProducerFn[T]: - """Deserializes the result from `serialize`.""" - del cls - - obj = cloudpickle.loads(serialized) - if not isinstance(obj, GetElementProducerFn): - raise ValueError( - "`serialized` should be deserialized into `GetElementProducerFn`." - ) - - return obj - - -def parse_debug_flags(debug_flags: dict[str, Any]): - """Parses debug flags.""" - - flags.FLAGS["grain_py_debug_mode"].present = True - flags.FLAGS["grain_py_dataset_visualization_output_dir"].present = True - config.update("py_debug_mode", debug_flags["grain_py_debug_mode"]) - config.update( - "py_dataset_visualization_output_dir", - debug_flags["grain_py_dataset_visualization_output_dir"], - ) - - -def _initialize_and_get_element_producer( - args_queue: queues.Queue, - *, - debug_flags: dict[str, Any], - worker_index: int, - worker_count: int, - start_profiling_event: synchronize.Event, - stop_profiling_event: synchronize.Event, - stats_out_queue: queues.Queue, -) -> Iterator[Any]: - """Unpickles the element producer from the args queue and closes the queue.""" - ( - serialized_flag_parse_fn, - serialized_init_fns, - serialized_element_producer_fn, - ) = args_queue.get() - flag_parse_fn: Callable[[Any], None] = cloudpickle.loads( - serialized_flag_parse_fn - ) - flag_parse_fn(debug_flags) - init_fns: list[Callable[[int, int], None]] = cloudpickle.loads( - serialized_init_fns - ) - for init_fn in init_fns: - init_fn(worker_index, worker_count) - element_producer_fn: GetElementProducerFn[Any] = ( - GetElementProducerFn.deserialize(serialized_element_producer_fn) - ) - - element_producer = element_producer_fn( - worker_index=worker_index, - worker_count=worker_count, - start_profiling_event=start_profiling_event, - stop_profiling_event=stop_profiling_event, - stats_out_queue=stats_out_queue, - ) - # args_queue has only a single argument and thus can be safely closed. - args_queue.close() - return element_producer - - -def _worker_loop( - *, - args_queue: queues.Queue, - errors_queue: queues.Queue, - output_queue: queues.Queue, - termination_event: synchronize.Event, - start_profiling_event: synchronize.Event, - stop_profiling_event: synchronize.Event, - worker_index: int, - worker_count: int, - enable_profiling: bool, - debug_flags: dict[str, Any], - stats_out_queue: queues.Queue, -): - """Code to be run on each child process.""" - out_of_elements = False - try: - worker_index_suffix = "" if worker_count == 1 else f" {worker_index}" - grain_logging.set_process_identifier_prefix( - f"PyGrain Worker{worker_index_suffix}" - ) - logging.info("Starting work.") - element_producer = _initialize_and_get_element_producer( - args_queue, - debug_flags=debug_flags, - worker_index=worker_index, - worker_count=worker_count, - start_profiling_event=start_profiling_event, - stop_profiling_event=stop_profiling_event, - stats_out_queue=stats_out_queue, - ) - profiling_enabled = enable_profiling and worker_index == 0 - if profiling_enabled: - profile = cProfile.Profile() - profile.enable() - # If termination event is set, we terminate and discard remaining elements. 
- while not termination_event.is_set(): - try: - next_element = next(element_producer) - if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types - next_element, output_queue, termination_event.is_set - ): - # We failed to put the element into the output queue because the - # termination event was set. The element may contain a shared memory - # block reference that has to be cleaned up. - _unlink_shm_in_structure(next_element) - except StopIteration: - out_of_elements = True - multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types - _ProcessingComplete(), output_queue, termination_event.is_set - ) - break - if profiling_enabled: - profile.disable() - _print_profile(f"PROFILE OF PROCESS WITH IDX {worker_index}.", profile) - - except Exception as e: # pylint: disable=broad-except - logging.exception( - "Error occurred in child process with worker_index: %i", worker_index - ) - remote_error = RemoteWorkerError( - error_cls=e.__class__, - error="".join( - traceback.format_exception(e.__class__, e, e.__traceback__) - ), - worker_index=worker_index, - ) - try: - errors_queue.put(remote_error, timeout=_QUEUE_WAIT_TIMEOUT) - except queue.Full: - logging.error("Couldn't send exception from child process. Queue full!") - - logging.info( - "Setting termination event in process with worker_index: %i", - worker_index, - ) - termination_event.set() - - if termination_event.is_set(): - if not out_of_elements: - # Since the termination event is set the consumer will not get any more - # elements from the output queue. The elements may contain reference to - # shared memory blocks that have to be cleaned up. - while not output_queue.empty(): - _unlink_shm_in_structure(output_queue.get_nowait()) - # When adding elements to the queue, element is put in a buffer and a - # background thread flushes the elements through the pipe. The process that - # writes to the queue joins that thread automatically on exit. We call - # cancel_join_thread when system terminates to prevent deadlocks. - output_queue.cancel_join_thread() - output_queue.close() - logging.info("Process %i exiting.", worker_index) - - -def _unlink_shm_if_metadata(obj: Any): - if isinstance(obj, shared_memory_array.SharedMemoryArrayMetadata): - obj.close_and_unlink_shm() - - -def _unlink_shm_in_structure(structure: Any): - if isinstance(structure, record.Record): - _unlink_shm_in_structure(structure.data) - else: - tree_lib.map_structure(_unlink_shm_if_metadata, structure) - - -class GrainPool(Iterator[T]): - """Pool to parallelize processing of Grain pipelines among a set of processes.""" - - def __init__( - self, - ctx: context.BaseContext, - *, - get_element_producer_fn: GetElementProducerFn[T], - worker_index_to_start_reading: int = 0, - termination_event: threading.Event | None = None, - start_profiling_event: synchronize.Event | None = None, - stop_profiling_event: synchronize.Event | None = None, - options: MultiprocessingOptions, - worker_init_fn: Callable[[int, int], None] | None = None, - stats_in_queues: tuple[queues.Queue, ...] | None = None, - ): - """Initialise a Grain Pool. - - Args: - ctx: Context to make multiprocessing primitives work. - get_element_producer_fn: Callable that returns an iterator over the - elements given the process index and process count. - worker_index_to_start_reading: index of worker to start reading output - batches from (needed for checkpointing support). - termination_event: Setting this event will terminate the pool. 
Otherwise, - the pool will terminate when either one of the workers failed or when - all workers are done processing data. GrainPool will not set this event. - start_profiling_event: Event to start prism profiling. - stop_profiling_event: Event to stop prism profiling. - options: Options for multiprocessing. See MultiprocessingOptions. - worker_init_fn: Function to run in each worker process before the element - producer. The function takes two arguments: the current worker index and - the total worker count. - stats_in_queues: Queue to propagate execution summary from child processes - to the parent. - """ - self.num_processes = options.num_workers - logging.info("Grain pool will use %i processes.", self.num_processes) - self.worker_args_queues = [] - self.worker_output_queues = [] - self.processes = [] - # Reader termination should always result in worker termination. However, - # worker termination should not shut down the reader: workers are terminated - # when they finished processing data, but the reader may still need to read - # the remaining output from the shared queues. That is why we use two - # separate events. - self._reader_termination_event = termination_event or threading.Event() - self._workers_termination_event = ctx.Event() - self._worker_init_fn = worker_init_fn - self.completed_processes = set() - # Queue to propagate errors from child processes to the parent. Note that - # this queue is shared by all child processes. - self.worker_error_queue = ctx.Queue(self.num_processes) - self.stats_in_queues = stats_in_queues - - try: - get_element_producer_fn = get_element_producer_fn.serialize() - except Exception as e: - if sys.version_info >= (3, 11): - e.add_note( - "\nCould not serialize transformation function passed to Grain " - "workers. This likely means that your data source, sampler or one " - "of your transformations cannot be serialized. Please make sure " - "that the objects work with cloudpickle.dumps()." - ) - raise e - - for worker_index in range(self.num_processes): - worker_args_queue = ctx.Queue(1) - worker_output_queue = ctx.Queue(options.per_worker_buffer_size) - process_kwargs = dict( - args_queue=worker_args_queue, - errors_queue=self.worker_error_queue, - output_queue=worker_output_queue, - stats_out_queue=( - self.stats_in_queues[worker_index] - if self.stats_in_queues - else None - ), - termination_event=self._workers_termination_event, - start_profiling_event=start_profiling_event, - stop_profiling_event=stop_profiling_event, - worker_index=worker_index, - worker_count=options.num_workers, - enable_profiling=options.enable_profiling, - debug_flags=dict( - grain_py_debug_mode=config.get_or_default("py_debug_mode"), - grain_py_dataset_visualization_output_dir=( - config.get_or_default("py_dataset_visualization_output_dir") - ), - ), - ) - # The process kwargs must all be pickable and will be unpickle before - # absl.app.run() is called. We send arguments via a queue to ensure that - # they are unpickled after absl.app.run() was called in the child - # processes. 
- worker_init_fns = [self._worker_init_fn] if self._worker_init_fn else [] - parse_debug_flags_fn = parse_debug_flags - worker_init_fns = cloudpickle.dumps(worker_init_fns) - parse_debug_flags_fn = cloudpickle.dumps(parse_debug_flags_fn) - worker_args_queue.put( - (parse_debug_flags_fn, worker_init_fns, get_element_producer_fn) - ) - process = ctx.Process( # pytype: disable=attribute-error # re-none - target=_worker_loop, kwargs=process_kwargs, daemon=True - ) - self.worker_args_queues.append(worker_args_queue) - self.worker_output_queues.append(worker_output_queue) - self.processes.append(process) - - logging.info("Grain pool will start child processes.") - parallel.run_in_parallel( - function=lambda child_process: child_process.start(), - list_of_kwargs_to_function=[ - {"child_process": p} for p in self.processes - ], - num_workers=min(_PROCESS_MANAGEMENT_MAX_THREADS, self.num_processes), - ) - logging.info("Grain pool started all child processes.") - self._next_worker_index = worker_index_to_start_reading - - def __iter__(self) -> GrainPool: - return self - - def _process_failed(self, worker_index: int) -> bool: - exit_code = self.processes[worker_index].exitcode - return exit_code is not None and exit_code != 0 - - def _processing_completed(self) -> bool: - return all(p.exitcode == 0 for p in self.processes) - - def _update_next_worker_index(self) -> None: - self._next_worker_index = (self._next_worker_index + 1) % self.num_processes - - def __next__(self) -> GrainPoolElement: - processing_failed = False - while ( - not self._workers_termination_event.is_set() - and len(self.completed_processes) < self.num_processes - ): - # If the reader was shut down, e.g. due to iterator deletion, we should - # shut down the workers. - if self._reader_termination_event.is_set(): - self._shutdown() - # Since the reader is shut down it doesn't matter what we return here. - # We should not raise an exception because it is common to iterate over - # infinite datasets and delete the iterator before processing is - # complete. - return GrainPoolElement( - "Grain worker pool reader was terminated, shutting down workers.", - -1, - ) - if self._next_worker_index in self.completed_processes: - self._update_next_worker_index() - continue - try: - element_worker_index = self._next_worker_index - element = self.worker_output_queues[self._next_worker_index].get( - timeout=_QUEUE_WAIT_TIMEOUT - ) - logging.debug("Read element from process: %s", self._next_worker_index) - if element == _PROCESSING_COMPLETE: - logging.info( - "Processing complete for process with worker_index %i", - self._next_worker_index, - ) - self.completed_processes.add(self._next_worker_index) - self._update_next_worker_index() - else: - self._update_next_worker_index() - return GrainPoolElement(element, element_worker_index) - except queue.Empty: - logging.debug("Got no element from process %s", self._next_worker_index) - if self._process_failed(self._next_worker_index): - processing_failed = True - logging.info( - "Process with idx %i Failed (Exitcode: %s).", - self._next_worker_index, - self.processes[self._next_worker_index].exitcode, - ) - break - - if processing_failed or self._workers_termination_event.is_set(): - logging.error("Processing Failed. Shutting down.") - self._shutdown() - - try: - remote_error = self.worker_error_queue.get(timeout=_QUEUE_WAIT_TIMEOUT) - raise remote_error.original_error - except queue.Empty: - # Worker did not report any error. 
This means that either an exception - # was raised outside of the worker loop (e.g. during flag parsing) or - # the worker process was forcefully terminated. Unfortunately, there is - # no debugging info available in the main process at this point apart - # from the exit code. The crash logs, however, should've been produced. - raise RuntimeError( - f"Grain worker process {self._next_worker_index} was terminated" - " unexpectedly with exit code " - f"{self.processes[self._next_worker_index].exitcode}. Search the " - "logs above for the source of the crash." - ) from None - - # Processing successfully completed. - raise StopIteration - - def __del__(self): - self._shutdown() - - def __enter__(self) -> GrainPool: - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - logging.info("Grain pool is exiting.") - self._shutdown() - - def _shutdown(self) -> None: - """Gracefully shutdown the multiprocessing system.""" - logging.info("Shutting down multiprocessing system.") - try: - self._workers_termination_event.set() - # There is a chance that shutdown was triggered before the worker - # processes fully initialized and read from the arg queues. The arg - # queues will block the main process until their elements are flushed - # through the pipes, which will never happen since the workers were shut - # down. Here we avoid blocking the main process, see - # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue.cancel_join_thread - for q in self.worker_args_queues: - q.cancel_join_thread() - q.close() - # Not joining here will cause the children to be zombie after they finish. - # Need to join or call active_children. - for process in self.processes: - process.join(timeout=_PROCESS_JOIN_TIMEOUT) - finally: - for process in self.processes: - # In case all our attempts to terminate the system fails, we forcefully - # kill the child processes. 
- if process.is_alive(): - logging.info("Killing worker process with pid %i", process.pid) - process.kill() - - -@dataclasses.dataclass(slots=True, frozen=True) -class _ReaderQueueElement: - """Element to be added to the reader queue.""" - - async_result: pool.AsyncResult[Any] - # index of worker producing the element in [0, worker_count] - worker_index: int - - -@dataclasses.dataclass(frozen=True) -class _GrainPoolProcessingComplete: - """Indicates processing of grain pool is complete.""" - - -_GRAIN_POOL_PROCESSING_COMPLETE = _GrainPoolProcessingComplete() -_QueueElement = Union[ - _ReaderQueueElement, _GrainPoolProcessingComplete, Exception -] - - -def _open_shared_memory_for_leaf(element: Any) -> Any: - if isinstance(element, shared_memory_array.SharedMemoryArrayMetadata): - element = shared_memory_array.SharedMemoryArray.from_metadata(element) - element.unlink_on_del() - return element - - -def _open_shared_memory_for_structure(structure: Any) -> Any: - if isinstance(structure, record.Record): - structure.data = tree_lib.map_structure( - _open_shared_memory_for_leaf, structure.data - ) - return structure - return tree_lib.map_structure(_open_shared_memory_for_leaf, structure) - - -def _process_elements_in_grain_pool( - *, - get_element_producer_fn: GetElementProducerFn, - multiprocessing_options: MultiprocessingOptions, - reader_queue: queue.Queue[_QueueElement], - thread_pool: pool.ThreadPool, - termination_event: threading.Event, - start_profiling_event: synchronize.Event | None, - stop_profiling_event: synchronize.Event | None, - worker_index_to_start_reading: int, - worker_init_fn: Callable[[int, int], None] | None, - stats_in_queues: tuple[queues.Queue, ...] | None, -) -> None: - """Processes elements in grain worker pool asynchronously.""" - - def read_thread_should_stop(): - return termination_event.is_set() or not threading.main_thread().is_alive() - - ctx = mp.get_context("spawn") - - try: - with GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - worker_index_to_start_reading=worker_index_to_start_reading, - termination_event=termination_event, - start_profiling_event=start_profiling_event, - stop_profiling_event=stop_profiling_event, - options=multiprocessing_options, - worker_init_fn=worker_init_fn, - stats_in_queues=stats_in_queues, - ) as g_pool: - for element in g_pool: - if read_thread_should_stop(): - break - # Note: We use a thread pool for opening the shared memory because - # in some cases the calls to `shm_open` can actually become the - # bottleneck for a single thread. - async_result = thread_pool.apply_async( - _open_shared_memory_for_structure, - args=(element.record,), - ) - multiprocessing_common.add_element_to_queue( - _ReaderQueueElement( - async_result, - element.worker_index, - ), - reader_queue, - read_thread_should_stop, - ) - # This exception could arise from user-provide code. Propagating it to - # the main thread to re-raise it as is. - except Exception as e: # pylint: disable=broad-except - multiprocessing_common.add_element_to_queue( - e, reader_queue, read_thread_should_stop - ) - return - multiprocessing_common.add_element_to_queue( - _GrainPoolProcessingComplete(), - reader_queue, - read_thread_should_stop, - ) - - -class MultiProcessIteratorInvalidStateError(Exception): - """Raised when iterator is an invalid state and can't be iterated on.""" - - -class MultiProcessIterator(Iterator[T]): - """Runs iterators returned by `get_element_producer_fn` in child processes. 
- - Note: MultiProcessIterator implements the Context Manager protocol to clean - resources. As such, it must be used within a "with" statement. - - Wraps `GrainPool` adding lifecycle management, multithreaded elements read and - recording the last worker index useful for checkpointing. - """ - - def __init__( - self, - get_element_producer_fn: GetElementProducerFn, - multiprocessing_options: MultiprocessingOptions, - worker_index_to_start_reading: int, - worker_init_fn: Callable[[int, int], None] | None = None, - start_profiling_event: synchronize.Event | None = None, - stop_profiling_event: synchronize.Event | None = None, - stats_in_queues: tuple[queues.Queue, ...] | None = None, - ): - """Initializes MultiProcessIterator. - - Args: - get_element_producer_fn: factory making record iterators for each child - process. - multiprocessing_options: options for distributing the record iterators. - worker_index_to_start_reading: Index of the next worker to read from. This - is useful for recovering from a checkpoint. - worker_init_fn: Function to run in each worker process before the element - producer. The function takes two arguments: the current worker index and - the total worker count. - start_profiling_event: Event to start prism profiling. - stop_profiling_event: Event to stop prism profiling. - stats_in_queues: Queues to send execution summaries from worker processes - to the main process. - """ - self._get_element_producer_fn = get_element_producer_fn - self._multiprocessing_options = multiprocessing_options - self._last_worker_index = worker_index_to_start_reading - 1 - self._worker_init_fn = worker_init_fn - self._reader_queue = None - self._reader_thread_pool = None - self._termination_event = None - self._reader_thread = None - self._stats_in_queues = stats_in_queues - self._start_profiling_event = start_profiling_event - self._stop_profiling_event = stop_profiling_event - - def __del__(self): - if self._reader_thread: - logging.info("Destroying multiprocess iterator.") - self.stop_prefetch() - - def start_prefetch(self) -> None: - """Starts the prefetching threads.""" - - if self._reader_thread: - return - - max_buffered_elements = ( - self._multiprocessing_options.num_workers - * self._multiprocessing_options.per_worker_buffer_size - ) - self._reader_queue = queue.Queue(maxsize=max_buffered_elements) - self._reader_thread_pool = pool.ThreadPool(max_buffered_elements) - self._termination_event = threading.Event() - self._reader_thread = threading.Thread( - target=_process_elements_in_grain_pool, - kwargs=dict( - get_element_producer_fn=self._get_element_producer_fn, - multiprocessing_options=self._multiprocessing_options, - reader_queue=self._reader_queue, - thread_pool=self._reader_thread_pool, - termination_event=self._termination_event, - start_profiling_event=self._start_profiling_event, - stop_profiling_event=self._stop_profiling_event, - worker_index_to_start_reading=self._last_worker_index + 1, - worker_init_fn=self._worker_init_fn, - stats_in_queues=self._stats_in_queues, - ), - ) - self._reader_thread.start() - shared_memory_array.SharedMemoryArray.enable_async_del( - self._multiprocessing_options.num_workers - ) - - def stop_prefetch(self) -> None: - """Cleans up prefetching threads.""" - - if not self._reader_thread: - return - - # pytype: disable=attribute-error - self._termination_event.set() - self._reader_thread_pool.close() - self._reader_thread.join() - self._reader_thread_pool.join() - # pytype: enable=attribute-error - self._termination_event = None - 
self._reader_thread_pool = None - self._reader_thread = None - self._reader_queue = None - - def __enter__(self): - self.start_prefetch() - return self - - def __exit__(self, exc_type, exc_value, tb): - self.stop_prefetch() - - def _can_iterate(self): - """Checks whether the object is in a state where it can be iterated on.""" - return ( - self._reader_queue is not None - and self._termination_event is not None - and self._reader_thread_pool is not None - and self._reader_thread is not None - ) - - def __iter__(self): - if not self._can_iterate(): - raise MultiProcessIteratorInvalidStateError( - "MultiProcessIterator is in an invalid state. Note that" - " MultiProcessIterator should be used with a 'with' statement." - ) - return self - - def get_last_worker_index(self): - return self._last_worker_index - - def __next__(self): - if not self._can_iterate(): - raise MultiProcessIteratorInvalidStateError( - "MultiProcessIterator is in an invalid state. Note that" - " MultiProcessIterator should be used with a 'with' statement." - ) - element = multiprocessing_common.get_element_from_queue( - self._reader_queue, self._termination_event.is_set # pytype: disable=attribute-error - ) - if isinstance(element, Exception): - raise element - if ( - element == _GRAIN_POOL_PROCESSING_COMPLETE - or element == multiprocessing_common.SYSTEM_TERMINATED - ): - raise StopIteration - - if not isinstance(element, _ReaderQueueElement): - raise ValueError( - f"Got invalid element type from GrainPool: {type(element)}" - ) - - result = multiprocessing_common.get_async_result( - element.async_result, self._termination_event.is_set - ) - if isinstance(result, multiprocessing_common._SystemTerminated): # pylint: disable=protected-access - raise StopIteration - self._last_worker_index = element.worker_index - return result diff --git a/grain/_src/python/grain_pool_test.py b/grain/_src/python/grain_pool_test.py deleted file mode 100644 index 5aa87795b..000000000 --- a/grain/_src/python/grain_pool_test.py +++ /dev/null @@ -1,471 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for GrainPool.""" - -from collections.abc import Iterator -import multiprocessing -import os -import platform -import signal -import sys -from typing import Any -from absl import flags -from absl.testing import absltest -from absl.testing import parameterized -from grain._src.core import config -from grain._src.core import monitoring as grain_monitoring -import multiprocessing as mp -from grain._src.python import data_sources -from grain._src.python import grain_pool as gp -from grain._src.python import record -from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member - - -class GrainPoolTest(absltest.TestCase): - - def _join_and_assert_process_exitcode(self, process: multiprocessing.Process): - # The process can be potentially terminated forcibly and needs a moment to - # finalize and update the exitcode. 
- process.join(timeout=gp._PROCESS_JOIN_TIMEOUT) - self.assertIn(process.exitcode, {0, -signal.SIGTERM}) - - def test_pool_with_flags_not_parsed(self): - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 14, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - # unparse the flags explicitly - flags.FLAGS.unparse_flags() - - _ = gp.GrainPool( - ctx=mp.get_context("spawn"), - get_element_producer_fn=get_element_producer_fn, - options=MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1), - ) - - def test_pool_equal_split_in_memory_data_source(self): - in_memory_ds = data_sources.SharedMemoryDataSource(range(12)) - - # 12 elements in the `in_memory_ds` are divided - # equally among 4 processes. - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 12, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - output_elements = [] - with gp.GrainPool( - ctx=mp.get_context("spawn"), - get_element_producer_fn=get_element_producer_fn, - options=MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1), - ) as grain_pool: - for element in grain_pool: - output_elements.append(element) - # turn each element in `in_memory_ds` to their negatives. - in_memory_ds[element.record] = -in_memory_ds[element.record] - - self.assertEqual( - output_elements, [gp.GrainPoolElement(x, x % 4) for x in range(12)] - ) - - self.assertEqual(list(iter(in_memory_ds)), [-x for x in range(12)]) - - def test_pool_equal_split(self): - ctx = mp.get_context("spawn") - - # 16 elements divide equally among 4 processes - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 16, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) - output_elements = [] - with gp.GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - options=options, - ) as grain_pool: - for element in grain_pool: - output_elements.append(element) - expected_elements = list( - map( - lambda x: gp.GrainPoolElement(x, x % options.num_workers), range(16) - ) - ) - self.assertEqual(expected_elements, output_elements) - # Make sure num_processes processes were launched. - self.assertLen(grain_pool.processes, options.num_workers) - # Make sure all child processes exited successfully. 
- for child_process in grain_pool.processes: - self._join_and_assert_process_exitcode(child_process) - - def test_pool_non_equal_split(self): - ctx = mp.get_context("spawn") - - # 14 elements do not divide equally among 4 processes - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 14, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) - output_elements = [] - with gp.GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - options=options, - ) as grain_pool: - for element in grain_pool: - output_elements.append(element) - expected_elements = list( - map( - lambda x: gp.GrainPoolElement(x, x % options.num_workers), range(14) - ) - ) - self.assertEqual(expected_elements, output_elements) - # Make sure all child processes exited successfully. - for child_process in grain_pool.processes: - self._join_and_assert_process_exitcode(child_process) - - @absltest.skipIf( - platform.system() == "Windows", "SIGKILL signal not available on Windows." - ) - def test_pool_kill_child(self): - ctx = mp.get_context("spawn") - - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 14, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) - with gp.GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - options=options, - ) as grain_pool: - child_pid = grain_pool.processes[0].pid - os.kill(child_pid, signal.SIGKILL) - - self.assertEqual( - grain_pool.processes[0].exitcode, -1 * signal.SIGKILL.value - ) - for child_process in grain_pool.processes[1:]: - self._join_and_assert_process_exitcode(child_process) - - def test_pool_object_deletion(self): - ctx = mp.get_context("spawn") - - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 14, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) - - # Users should generally use the with statement, here we test if GrainPool - # was created without the "with statement", that object deletion would - # have child processes gracefully exited. 
- grain_pool = gp.GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - options=options, - ) - - child_processes = grain_pool.processes - grain_pool.__del__() - - for child_process in child_processes: - self._join_and_assert_process_exitcode(child_process) - - -def _make_uniform_element_producer_fn( - last_seen_index: int = -1, -) -> gp.GetElementProducerFn: - - class _RoundrobinElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self - yield from range(10)[last_seen_index + 1 + worker_index :: worker_count] - - return _RoundrobinElementProducerFn() - - -class RoundrobinRecordElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[record.Record[int]]: - del self - for i in range(5)[worker_index::worker_count]: - yield record.Record(record.RecordMetadata(i), i) - - -class NonUniformElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self, worker_count - for _ in range(worker_index * 3): - yield worker_index - - -class MultiProcessIteratorTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name="two_workers", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=2), - worker_index_to_start_reading=0, - expected=list(range(10)), - ), - dict( - testcase_name="five_workers", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=5), - worker_index_to_start_reading=0, - expected=list(range(10)), - ), - dict( - testcase_name="from_checkpoint", - get_element_producer_fn=_make_uniform_element_producer_fn(5), - multiprocessing_options=MultiprocessingOptions(num_workers=2), - worker_index_to_start_reading=1, - expected=[7, 6, 9, 8], - ), - dict( - testcase_name="non_uniform", - get_element_producer_fn=NonUniformElementProducerFn(), - multiprocessing_options=MultiprocessingOptions(num_workers=3), - worker_index_to_start_reading=0, - expected=[1, 2, 1, 2, 1, 2, 2, 2, 2], - ), - dict( - testcase_name="record_producer_fn", - get_element_producer_fn=RoundrobinRecordElementProducerFn(), - multiprocessing_options=MultiprocessingOptions(num_workers=3), - worker_index_to_start_reading=0, - expected=[ - record.Record(record.RecordMetadata(i), i) for i in range(5) - ], - ), - ) - def test_produces_correct_data( - self, - get_element_producer_fn: gp.GetElementProducerFn, - multiprocessing_options: MultiprocessingOptions, - worker_index_to_start_reading: int, - expected: Any, - ): - with gp.MultiProcessIterator( - get_element_producer_fn, - multiprocessing_options, - worker_index_to_start_reading, - ) as iterator: - actual = list(iterator) - self.assertEqual(actual, expected) - - @parameterized.named_parameters( - dict( - testcase_name="two_workers", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=2), - worker_index_to_start_reading=1, - num_iters=5, - expected_last_worker_index=1, - ), - dict( - testcase_name="five_workers", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=5), - worker_index_to_start_reading=0, - num_iters=7, - expected_last_worker_index=1, - ), - dict( - 
testcase_name="five_workers_incomplete_round", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=5), - worker_index_to_start_reading=0, - num_iters=3, - expected_last_worker_index=2, - ), - dict( - testcase_name="from_checkpoint", - get_element_producer_fn=_make_uniform_element_producer_fn(5), - multiprocessing_options=MultiprocessingOptions(num_workers=2), - worker_index_to_start_reading=0, - num_iters=3, - expected_last_worker_index=0, - ), - dict( - testcase_name="non_uniform_record_producer_fn", - get_element_producer_fn=NonUniformElementProducerFn(), - multiprocessing_options=MultiprocessingOptions(num_workers=3), - worker_index_to_start_reading=0, - num_iters=6, - expected_last_worker_index=2, - ), - ) - def test_get_state( - self, - get_element_producer_fn: gp.GetElementProducerFn, - multiprocessing_options: MultiprocessingOptions, - worker_index_to_start_reading: int, - num_iters: int, - expected_last_worker_index: int, - ): - with gp.MultiProcessIterator( - get_element_producer_fn, - multiprocessing_options, - worker_index_to_start_reading, - ) as iterator: - for _ in range(num_iters): - _ = next(iterator) - actual_last_worker_index = iterator.get_last_worker_index() - self.assertEqual(actual_last_worker_index, expected_last_worker_index) - - def test_fails_with_zero_workers(self): - with self.assertRaisesRegex( - ValueError, "Number of processes must be at least 1" - ): - with gp.MultiProcessIterator( - _make_uniform_element_producer_fn(), - MultiprocessingOptions(num_workers=0), - 0, - ) as iterator: - list(iterator) - - def test_propagates_error(self): - error_msg = "very unique error" - - class FailingGetElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self, worker_index, worker_count - raise ValueError(error_msg) - - failing_get_element_producer_fn = FailingGetElementProducerFn() - - with gp.MultiProcessIterator( - failing_get_element_producer_fn, - MultiprocessingOptions(num_workers=2), - 0, - ) as iterator: - with self.assertRaisesRegex(ValueError, error_msg): - list(iterator) - - def test_reports_worker_crash(self): - - class FailingGetElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self, worker_index, worker_count - sys.exit(12) - - failing_get_element_producer_fn = FailingGetElementProducerFn() - - with gp.MultiProcessIterator( - failing_get_element_producer_fn, - MultiprocessingOptions(num_workers=2), - 0, - ) as iterator: - with self.assertRaisesRegex( - RuntimeError, "was terminated unexpectedly with exit code 12" - ): - list(iterator) - - def test_reports_unpicklable_element_producer_fn(self): - error_msg = "UnpicklableObject is not picklable" - - class UnpicklableObject: - - def __getstate__(self): - raise ValueError(error_msg) - - local_state = UnpicklableObject() - - class GetElementProducerFnWithUnpicklableClosure(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self, worker_index, worker_count - yield 1 if local_state is None else 2 - - get_element_producer_fn_with_unpicklable_closure = ( - GetElementProducerFnWithUnpicklableClosure() - ) - - with gp.MultiProcessIterator( - get_element_producer_fn_with_unpicklable_closure, - MultiprocessingOptions(num_workers=2), - 0, - ) as iterator: - with 
self.assertRaisesRegex(ValueError, error_msg): - list(iterator) - - def test_worker_init_fn(self): - - def _set_worker_index_and_count(worker_index: int, worker_count: int): - gp.monkey_patched_index_and_count = (worker_index, worker_count) - - class GetElementProducerFnReturningGlobal(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[tuple[int, int]]: - del self, worker_index, worker_count - yield gp.monkey_patched_index_and_count # pytype: disable=module-attr - - with gp.MultiProcessIterator( - GetElementProducerFnReturningGlobal(), - MultiprocessingOptions(num_workers=2), - 0, - worker_init_fn=_set_worker_index_and_count, - ) as iterator: - result = list(iterator) - self.assertEqual(result, [(0, 2), (1, 2)]) - - -if __name__ == "__main__": - absltest.main() diff --git a/grain/_src/python/shared_memory_array.py b/grain/_src/python/shared_memory_array.py index 37f37175e..a7ef243e2 100644 --- a/grain/_src/python/shared_memory_array.py +++ b/grain/_src/python/shared_memory_array.py @@ -222,3 +222,12 @@ def _open_leaf_from_shm(leaf: Any) -> Any: def open_from_shm(struct: Any) -> Any: """Recovers leaf ndarrays of the structure from shared memory.""" return tree_lib.map_structure(_open_leaf_from_shm, struct) + + +def _unlink_shm_if_metadata(obj: Any) -> None: + if isinstance(obj, SharedMemoryArrayMetadata): + obj.close_and_unlink_shm() + + +def unlink_shm(struct: Any) -> None: + tree_lib.map_structure(_unlink_shm_if_metadata, struct) diff --git a/grain/_src/python/shared_memory_array_test.py b/grain/_src/python/shared_memory_array_test.py index 4d4b8f2a2..65994882f 100644 --- a/grain/_src/python/shared_memory_array_test.py +++ b/grain/_src/python/shared_memory_array_test.py @@ -28,6 +28,7 @@ from grain._src.python.shared_memory_array import open_from_shm from grain._src.python.shared_memory_array import SharedMemoryArray from grain._src.python.shared_memory_array import SharedMemoryArrayMetadata +from grain._src.python.shared_memory_array import unlink_shm import jax import numpy as np @@ -226,6 +227,26 @@ def test_copy_and_open_shm_min_size(self): np.testing.assert_array_equal(opened_struct, arr) self.assertTrue(opened_struct._unlink_on_del) + def test_unlink_shm(self): + arr = np.arange(10).astype(np.int32) + arr2 = np.arange(5).astype(np.int32) + struct = {"a": arr, "b": [arr2, arr], "c": 123} + shm_struct = copy_to_shm(struct) + names = [ + shm_struct["a"].name, + shm_struct["b"][0].name, + shm_struct["b"][1].name, + ] + # Check that SHMs exist. + for name in names: + self.assertIsNotNone(shared_memory.SharedMemory(name=name, create=False)) + + unlink_shm(shm_struct) + + for name in names: + with self.assertRaises(FileNotFoundError): + shared_memory.SharedMemory(name=name, create=False) + if __name__ == "__main__": absltest.main()
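A minimal usage sketch of the two additions above, based only on the private `grain._src` APIs exercised in process_prefetch_test.py and shared_memory_array_test.py; module paths follow the internal layout shown in this patch, and the toy pipeline and variable names are illustrative, not part of the change:

import numpy as np

from grain._src.python import shared_memory_array
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import process_prefetch


def main():
  # Prefetch a small pipeline in a background process and checkpoint it.
  ds = dataset.MapDataset.source([1, 2, 3, 4, 5]).to_iter_dataset()
  ds = process_prefetch.ProcessPrefetchIterDataset(ds, buffer_size=2)
  it = ds.__iter__()
  state = it.get_state()  # Now simply the parent iterator's state; no
                          # worker_state / iterations_to_skip wrapper.
  first_two = [next(it), next(it)]
  it.set_state(state)     # Rewind to the checkpoint without restarting the worker.
  assert [next(it), next(it)] == first_two
  it.close()              # Joins the worker process; kills it after
                          # _PROCESS_KILL_TIMEOUT_S if it does not exit.

  # unlink_shm() releases every SharedMemoryArrayMetadata leaf in a structure,
  # e.g. elements copied to shared memory that will never be consumed.
  shm_struct = shared_memory_array.copy_to_shm(
      {"x": np.arange(10, dtype=np.int32)}
  )
  shared_memory_array.unlink_shm(shm_struct)


if __name__ == "__main__":
  main()  # Guard is needed because ProcessPrefetchIterDataset spawns a child process.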