43 changes: 0 additions & 43 deletions grain/_src/python/BUILD
@@ -208,49 +208,6 @@ py_test(
],
)

py_library(
name = "grain_pool",
srcs = ["grain_pool.py"],
srcs_version = "PY3",
target_compatible_with = select({
"@platforms//os:windows": ["@platforms//:incompatible"],
"//conditions:default": [],
}),
deps = [
":grain_logging",
":multiprocessing_common",
":options",
":record",
":shared_memory_array",
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"//grain/_src/core:parallel",
"//grain/_src/core:tree_lib",
"@abseil-py//absl/flags",
"@abseil-py//absl/logging",
"@pypi//cloudpickle:pkg",
],
)

py_test(
name = "grain_pool_test",
srcs = ["grain_pool_test.py"],
shard_count = 20,
srcs_version = "PY3",
tags = ["not_run:arm"],
deps = [
":data_sources",
":grain_pool",
":options",
":record",
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"@abseil-py//absl/flags",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
],
)

py_library(
name = "checkpoint_handlers",
srcs = ["checkpoint_handlers.py"],
84 changes: 33 additions & 51 deletions grain/_src/python/data_loader.py
@@ -24,13 +24,11 @@
import sys
from typing import Any, Awaitable, Optional, Sequence, TypeVar

from absl import logging
from etils import epath
from grain._src.core import monitoring as grain_monitoring
from grain._src.core import sharding
from grain._src.core import transforms
from grain._src.core import tree_lib
import multiprocessing as mp
from grain._src.python import checkpointing
from grain._src.python import operations as ops
from grain._src.python import options
@@ -39,8 +37,6 @@
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import batch as batch_ds
from grain._src.python.dataset.transformations import flatmap
from grain._src.python.dataset.transformations import prefetch
from grain._src.python.operations import BatchOperation
from grain._src.python.operations import Operation
from grain._src.python.samplers import Sampler
from grain._src.python.shared_memory_array import SharedMemoryArray
@@ -260,14 +256,15 @@ def get_state(self):
# structure and switch to native dataset checkpointing. This class can be
# removed afterwards.
dataset_state = self._parent.get_state()
if "workers_state" not in dataset_state:
dataset_state = {
"workers_state": {"0": dataset_state},
"last_worker_index": -1,
}
workers_state = dataset_state["workers_state"]
last_worker_index = dataset_state["last_worker_index"]
worker_count = len(dataset_state["workers_state"])
if "iterators_in_use_states" not in dataset_state:
next_index_in_cycle = 0
workers_state = [dataset_state]
else:
next_index_in_cycle = dataset_state["next_index_in_cycle"]
workers_state = dataset_state["iterators_in_use_states"]

last_worker_index = next_index_in_cycle - 1
worker_count = len(workers_state)

shard_index = self._shard_options.shard_index if self._shard_options else 0
shard_count = self._shard_options.shard_count if self._shard_options else 1
@@ -278,7 +275,7 @@ def get_state(self):
str(i): (
local_offset
+ i * shard_count
+ workers_state[str(i)]["next_index"] * global_worker_count
+ workers_state[i]["next_index"] * global_worker_count
)
for i in range(worker_count)
}
@@ -311,25 +308,28 @@ def set_state(self, state):
self._parent.set_state(dataset_state)
return

iterations_to_skip = {str(i): 0 for i in range(worker_count)}
workers_state = {
str(i): {
iterators_in_use_indices = list(range(worker_count))
iterators_in_use_states = [
{
"next_index": (
(
last_seen_indices[str(i)]
last_seen_indices[str(worker_index)]
+ global_worker_count
- shard_index
- i * shard_count
- worker_index * shard_count
)
// global_worker_count
)
}
for i in range(worker_count)
}
for worker_index in range(worker_count)
]

dataset_state = {
"workers_state": workers_state,
"iterations_to_skip": iterations_to_skip,
"last_worker_index": last_worker_index,
"next_index_in_cycle": last_worker_index + 1 % worker_count,
"next_index_in_datasets": worker_count,
"iterators_in_use_indices": iterators_in_use_indices,
"iterators_in_use_states": iterators_in_use_states,
"exhausted": [False] * worker_count,
}
self._parent.set_state(dataset_state)
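
For context on the translation above: the legacy DataLoader checkpoint keyed per-worker state by `workers_state` and `last_worker_index`, while the new multiprocess prefetch iterator restores from the `iterators_in_use_*` layout that `set_state` builds. Below is a minimal sketch of that key mapping, assuming only the key names visible in this hunk; the helper itself is hypothetical and skips the sharding arithmetic that `set_state` applies to `next_index`.

```python
# Hypothetical helper: maps the legacy DataLoader checkpoint layout onto the
# state layout expected by the new multiprocess prefetch iterator.
def legacy_to_prefetch_state(legacy_state: dict) -> dict:
  workers_state = legacy_state["workers_state"]          # e.g. {"0": {...}, "1": {...}}
  worker_count = len(workers_state)
  last_worker_index = legacy_state["last_worker_index"]  # -1 for a fresh iterator
  return {
      # Mirrors the hunk above; note that `1 % worker_count` is 1 whenever
      # worker_count > 1, so this is `last_worker_index + 1`, not a modular
      # increment.
      "next_index_in_cycle": last_worker_index + 1 % worker_count,
      "next_index_in_datasets": worker_count,
      "iterators_in_use_indices": list(range(worker_count)),
      "iterators_in_use_states": [
          workers_state[str(i)] for i in range(worker_count)
      ],
      "exhausted": [False] * worker_count,
  }
```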

@@ -389,31 +389,17 @@ def __init__(
f"Current worker_buffer_size is {worker_buffer_size}."
)

worker_count = _determine_worker_count(worker_count)
if worker_count > 0:

# Shared memory should be enabled iff worker_count > 0.
# This replaces Batch Transform with a BatchOperation in operations list
# if shared memory is enabled.
if operations and isinstance(
(last_op := operations[-1]), transforms.Batch
):
logging.info("Creating BatchOperation to enable SharedMemoryArray.")
batch_operation = BatchOperation(
batch_size=last_op.batch_size,
drop_remainder=last_op.drop_remainder,
batch_fn=last_op.batch_fn,
operations = list(operations)
for i in range(len(operations)):
op = operations[i]
if type(op) is ops.BatchOperation: # pylint: disable=unidiomatic-typecheck
operations[i] = transforms.Batch(
batch_size=op.batch_size,
drop_remainder=op.drop_remainder,
batch_fn=op.batch_fn,
)
batch_operation.disable_deprecation_message()
operations = list(operations)
operations[-1] = batch_operation

if operations and isinstance(operations[-1], BatchOperation):
logging.info("Enabling SharedMemoryArray for BatchOperation.")
operations[-1]._enable_shared_memory()
else:
logging.info("Adding CopyNumPyArrayToSharedMemory MapTransform.")
operations = list(operations) + [CopyNumPyArrayToSharedMemory()]
worker_count = _determine_worker_count(worker_count)

self._data_source = data_source
self._sampler = sampler
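
The rewrite loop above is what keeps pipelines that still pass the deprecated `BatchOperation` working: each such operation is replaced by an equivalent `transforms.Batch` before the dataset is built. The sketch below illustrates that equivalence, reusing only modules and classes that already appear in this diff and its tests; the exact import paths and the behavioral comparison are assumptions, not part of this change.

```python
# Sketch only: both spellings of batching should produce identical output,
# because DataLoader rewrites BatchOperation into transforms.Batch.
import numpy as np
from grain._src.core import sharding
from grain._src.core import transforms
from grain._src.python import data_loader as data_loader_lib
from grain._src.python import samplers
from grain._src.python.data_sources import RangeDataSource
from grain._src.python.operations import BatchOperation

source = RangeDataSource(start=0, stop=8, step=1)

def make_loader(batch_op):
  sampler = samplers.SequentialSampler(
      num_records=len(source), shard_options=sharding.NoSharding()
  )
  return data_loader_lib.DataLoader(
      data_source=source,
      sampler=sampler,
      operations=[batch_op],
      worker_count=0,
  )

old_style = list(make_loader(BatchOperation(batch_size=2)))
new_style = list(make_loader(transforms.Batch(batch_size=2)))
np.testing.assert_equal(old_style, new_style)
```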
@@ -471,11 +457,7 @@ def _create_dataset(self) -> dataset.IterDataset:
ds = _apply_transform_to_dataset(operation, ds)
ds = ds.map(lambda r: r.data)
if self.multiprocessing_options.num_workers > 0:
ds = prefetch.MultiprocessPrefetchIterDataset(
ds,
self.multiprocessing_options,
always_report_worker_state=True,
)
ds = ds.mp_prefetch(self.multiprocessing_options)
if not self._use_native_dataset_checkpointing:
ds = _DataLoaderStateIterDataset(
ds,
76 changes: 0 additions & 76 deletions grain/_src/python/data_loader_test.py
@@ -705,82 +705,6 @@ def test_batch_transform_mapped_to_batch_operation(self):
actual = list(data_loader)
np.testing.assert_equal(actual, expected)

@mock.patch.object(data_loader_lib, "CopyNumPyArrayToSharedMemory")
def test_shared_memory_for_batch_operation(
self, mock_copy_numpy_array_to_shared_memory
):
range_data_source = RangeDataSource(start=0, stop=8, step=1)
sampler = samplers.SequentialSampler(
num_records=len(range_data_source), shard_options=sharding.NoSharding()
)

operations = [
PlusOne(),
FilterEven(),
]

batch_operation = mock.MagicMock(BatchOperation(batch_size=2))

data_loader = data_loader_lib.DataLoader(
data_source=range_data_source,
sampler=sampler,
operations=operations,
worker_count=0,
read_options=self.read_options,
)
batch_operation._enable_shared_memory.assert_not_called()
self.assertTrue(
data_loader._operations[-1], mock_copy_numpy_array_to_shared_memory
)

data_loader = data_loader_lib.DataLoader(
data_source=range_data_source,
sampler=sampler,
operations=operations + [batch_operation],
worker_count=2,
read_options=self.read_options,
)
batch_operation._enable_shared_memory.assert_called_once()
self.assertTrue(data_loader._operations[-1], batch_operation)

@mock.patch.object(BatchOperation, "_enable_shared_memory", autospec=True)
def test_shared_memory_for_batch_transform(self, mock_enable_shared_memory):
range_data_source = RangeDataSource(start=0, stop=8, step=1)
sampler = samplers.SequentialSampler(
num_records=len(range_data_source), shard_options=sharding.NoSharding()
)
operations = [
PlusOne(),
FilterEven(),
]

data_loader = data_loader_lib.DataLoader(
data_source=range_data_source,
sampler=sampler,
operations=operations,
worker_count=2,
read_options=self.read_options,
)
mock_enable_shared_memory.assert_not_called()
self.assertIsInstance(
data_loader._operations[-1],
data_loader_lib.CopyNumPyArrayToSharedMemory,
)

batch_transform = transforms.Batch(batch_size=2)

data_loader = data_loader_lib.DataLoader(
data_source=range_data_source,
sampler=sampler,
operations=operations + [batch_transform],
worker_count=2,
read_options=self.read_options,
)
mock_enable_shared_memory.assert_called_once_with(
data_loader._operations[-1]
)
self.assertIsInstance(data_loader._operations[-1], BatchOperation)

def test_data_loader_with_batch_fn(self):
# Map transforms elements to be [1, 2, 3, 4, 5, 6, 7, 8]
# Filter keeps only even elements [2, 4, 6, 8]
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/BUILD
@@ -53,7 +53,7 @@ py_library(
"//grain/_src/core:tree_lib",
"//grain/_src/python:checkpointing",
"//grain/_src/python:grain_logging",
"//grain/_src/python:grain_pool",
"//grain/_src/python:multiprocessing_common",
"//grain/_src/python:options",
"//grain/_src/python:shared_memory_array",
"//grain/proto:execution_summary_py_pb2",
9 changes: 5 additions & 4 deletions grain/_src/python/dataset/dataset.py
@@ -1353,13 +1353,14 @@ def mp_prefetch(
A dataset prefetching input elements in separate processes.
"""
options = options or grain_options.MultiprocessingOptions(num_workers=10)
# Loaded lazily due to a circular dependency (dataset <-> prefetch).
# Loaded lazily due to a circular dependency (dataset <-> process_prefetch).
# pylint: disable=g-import-not-at-top
from grain._src.python.dataset.transformations import prefetch
from grain._src.python.dataset.transformations import process_prefetch
# pylint: enable=g-import-not-at-top
return prefetch.MultiprocessPrefetchIterDataset(
return process_prefetch.multiprocess_prefetch(
self,
multiprocessing_options=options,
num_workers=options.num_workers,
buffer_size=options.per_worker_buffer_size,
worker_init_fn=worker_init_fn,
sequential_slice=sequential_slice,
)
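
`mp_prefetch` keeps its public signature; only the backing implementation moves from `prefetch.MultiprocessPrefetchIterDataset` to `process_prefetch.multiprocess_prefetch`, with the option fields forwarded as shown above. A minimal usage sketch, assuming only APIs that already appear in the tests touched by this change:

```python
# Sketch only: builds a small pipeline and prefetches it in worker processes.
from grain._src.python import options as grain_options
from grain._src.python.dataset import dataset

def main() -> None:
  ds = dataset.MapDataset.range(1_000).to_iter_dataset()
  ds = ds.mp_prefetch(grain_options.MultiprocessingOptions(num_workers=2))
  # Elements are produced and serialized in child processes, then consumed here.
  print(sum(ds))

if __name__ == "__main__":  # guard needed for spawn-based multiprocessing
  main()
```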
34 changes: 0 additions & 34 deletions grain/_src/python/dataset/dataset_test.py
@@ -1393,40 +1393,6 @@ def test_execution_summary_with_no_logging(self):
log_value = "Grain Dataset Execution Summary"
self.assertNotIn(log_value, "".join(logs.output))

@flagsaver.flagsaver(grain_py_debug_mode=True)
@mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05)
def test_execution_summary_with_mp_prefetch(self):
def worker_init_fn_wrapper(worker_index, worker_count):
del worker_index, worker_count
dataset_stats._REPORTING_PERIOD_SEC = 0.05

ds = dataset.MapDataset.range(10000).map(MapTransformAddingOne())
ds = ds.to_iter_dataset()
ds = ds.mp_prefetch(
options.MultiprocessingOptions(num_workers=1),
worker_init_fn=worker_init_fn_wrapper,
)
it = ds.__iter__()
_ = list(it)
all_nodes_present = False
while not all_nodes_present:
time.sleep(1)
all_nodes_present = True
summary = dataset.get_execution_summary(it)
node_names = {node.name for node in summary.nodes.values()}
all_nodes_present = all_nodes_present and any(
"RangeMapDataset" in name for name in node_names
)
all_nodes_present = all_nodes_present and any(
"MapMapDataset" in name for name in node_names
)
all_nodes_present = all_nodes_present and any(
"PrefetchDatasetIterator" in name for name in node_names
)
all_nodes_present = all_nodes_present and any(
"MultiprocessPrefetchDatasetIterator" in name for name in node_names
)


class GetElementSpecTest(parameterized.TestCase):

8 changes: 4 additions & 4 deletions grain/_src/python/dataset/transformations/interleave.py
@@ -28,7 +28,7 @@
T = TypeVar("T")


class _InterleaveDatasetIterator(dataset.DatasetIterator[T]):
class InterleaveDatasetIterator(dataset.DatasetIterator[T]):
"""Iterates over the interleaved datasets."""

def __init__(
@@ -250,7 +250,7 @@ def __str__(self) -> str:

def _add_prefetch_and_make_iterator(
ds: dataset.IterDataset[T] | dataset.MapDataset[T],
interleave_iterator: weakref.ref[_InterleaveDatasetIterator[T]],
interleave_iterator: weakref.ref[InterleaveDatasetIterator[T]],
start_prefetch: bool,
) -> dataset.DatasetIterator[T]:
"""Adds prefetching to an IterDataset and returns an iterator.
@@ -351,8 +351,8 @@ def __init__(
self._make_iter_buffer_size = make_iter_buffer_size
self._iter_buffer_size = iter_buffer_size

def __iter__(self) -> _InterleaveDatasetIterator[T]:
return _InterleaveDatasetIterator(
def __iter__(self) -> InterleaveDatasetIterator[T]:
return InterleaveDatasetIterator(
self._datasets,
cycle_length=self._cycle_length,
num_make_iter_threads=self._num_make_iter_threads,
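
With the leading underscore dropped, `InterleaveDatasetIterator` becomes part of the module's public surface, so downstream code can reference the type directly, e.g. in annotations or `isinstance` checks. A small hedged sketch (the wrapper function is illustrative, not part of this change):

```python
# Sketch only: the class name comes from this diff; the helper is hypothetical.
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave

def is_interleave_iterator(it: dataset.DatasetIterator) -> bool:
  """Returns True if `it` is the iterator of an interleaved dataset."""
  return isinstance(it, interleave.InterleaveDatasetIterator)
```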