37 changes: 20 additions & 17 deletions torchdata/nodes/map.py
@@ -402,28 +402,31 @@ def __del__(self):

  class ParallelMapper(BaseNode[T]):
  """ParallelMapper executes map_fn in parallel either in num_workers threads or
- processes. For processes, multiprocessing_context can be spawn, forkserver, fork,
- or None (chooses OS default). At most max_concurrent items will be either processed
- or in the iterator's output queue, to limit CPU and Memory utilization. If None
- (default) the value will be 2 * num_workers.
+ processes. For processes, multiprocessing_context can be spawn, forkserver, fork,
+ or None (chooses OS default). At most max_concurrent items will be either processed
+ or in the iterator's output queue, to limit CPU and Memory utilization. If None
+ (default) the value will be 2 * num_workers.

- At most one iter() is created from source, and at most one thread will call
- next() on it at once.
+ At most one iter() is created from source, and at most one thread will call
+ next() on it at once.

- If in_order is true, the iterator will return items in the order from which they arrive
+ If in_order is true, the iterator will return items in the order from which they arrive
  from source's iterator, potentially blocking even if other items are available.

+ .. warning::
+ When ``in_order=False``, ParallelMapper does not guarantee reproducible
+ ordering or state across runs, even with identical inputs.
+
  Args:
- source (BaseNode[X]): The source node to map over.
- map_fn (Callable[[X], T]): The function to apply to each item from the source node.
- num_workers (int): The number of workers to use for parallel processing.
- in_order (bool): Whether to return items in the order from which they arrive from. Default is True.
- method (Literal["thread", "process"]): The method to use for parallel processing. Default is "thread".
- multiprocessing_context (Optional[str]): The multiprocessing context to use for parallel processing. Default is None.
- max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
- snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
- prebatch (Optional[int]): Optionally perform pre-batching of items from source before mapping.
- For small items, this may improve throughput at the expense of peak memory.
+ source (BaseNode[X]): The source node to map over.
+ map_fn (Callable[[X], T]): The function to apply to each item from the source node.
+ num_workers (int): The number of workers to use for parallel processing.
+ in_order (bool): Whether to return items in order. Default is True.
+ method (Literal["thread", "process"]): The method to use for parallel processing. Default is "thread".
+ multiprocessing_context (Optional[str]): The multiprocessing context to use. Default is None.
+ max_concurrent (Optional[int]): The maximum number of concurrent items. Default is None (2 * num_workers).
+ snapshot_frequency (int): The frequency at which to snapshot state. Default is 1.
+ prebatch (Optional[int]): Optionally batch items before mapping. Default is None.
  """

  IT_STATE_KEY = "it_state"
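
The ordering behaviour described in the docstring can be illustrated with a minimal sketch that uses only the Python standard library. It is not torchdata's implementation and does not call ParallelMapper; the names `slow_square`, `map_in_order`, and `map_unordered` are hypothetical and exist only for this example. The sketch assumes a thread pool as the stand-in for `method="thread"`.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def slow_square(x: int) -> int:
    # Hypothetical map function: smaller inputs sleep longer,
    # so early items tend to finish after later ones.
    time.sleep(0.01 * (5 - x))
    return x * x


def map_in_order(fn, items, num_workers):
    # Executor.map yields results in submission order, so a slow early
    # item holds back later items that have already finished.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(fn, items))


def map_unordered(fn, items, num_workers):
    # as_completed yields futures as they finish, so the output order
    # depends on timing and may differ from run to run.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(fn, item) for item in items]
        return [future.result() for future in as_completed(futures)]


items = list(range(5))
ordered = map_in_order(slow_square, items, num_workers=4)
unordered = map_unordered(slow_square, items, num_workers=4)

print("in order: ", ordered)
print("unordered:", unordered)
```

Both variants produce the same set of results. Only the in-order variant guarantees that the output sequence matches the input sequence, which is the trade-off the added warning calls out for ``in_order=False``.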