diff --git a/grain/_src/core/tree_lib.py b/grain/_src/core/tree_lib.py
index 1875187c..75e48674 100644
--- a/grain/_src/core/tree_lib.py
+++ b/grain/_src/core/tree_lib.py
@@ -314,3 +314,24 @@ def _shape(obj):
       structure,
       is_leaf=_is_leaf,
   )
+
+
+def estimate_byte_size(structure) -> int:
+  """Returns the estimated total byte size of the given tree."""
+  # Fast path for serialized data.
+  if isinstance(structure, (bytes, str)):
+    return len(structure)
+
+  result = 0
+
+  # This is intentionally very lightweight and only handles the most common
+  # types that have a non-trivial byte size, to avoid overhead.
+  def add_byte_size(x):
+    nonlocal result
+    if isinstance(x, (bytes, str)):
+      result += len(x)
+    elif isinstance(x, np.ndarray):
+      result += x.nbytes
+
+  map_structure(add_byte_size, structure)
+  return result
diff --git a/grain/_src/core/tree_lib_test.py b/grain/_src/core/tree_lib_test.py
index 4c8315d7..836bb880 100644
--- a/grain/_src/core/tree_lib_test.py
+++ b/grain/_src/core/tree_lib_test.py
@@ -52,6 +52,9 @@ def unflatten_as(self, structure, flat_sequence):
   def spec_like(self, structure):
     ...
 
+  def estimate_byte_size(self, structure):
+    ...
+
 
 # Static check that the module implements the necessary functions.
 tree_lib: TreeImpl = tree_lib
@@ -193,6 +196,49 @@ def test_spec_like_with_attrs(self):
         ],
     )
 
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="bytes",
+          structure=b"serialized_data",
+          expected_output=15,
+      ),
+      dict(
+          testcase_name="ndarray",
+          structure=np.asarray([1, 2, 3], dtype=np.int32),
+          expected_output=3 * 4,
+      ),
+      dict(
+          testcase_name="int",
+          structure=1,
+          expected_output=0,
+      ),
+      dict(
+          testcase_name="class",
+          structure=MyClass(1),
+          expected_output=0,
+      ),
+      dict(
+          testcase_name="simple",
+          structure={
+              "A": "v2",
+              "B": 1232.4,
+              "C": np.asarray([1, 2, 3], dtype=np.float32),
+          },
+          expected_output=2 + 3 * 4,
+      ),
+      dict(
+          testcase_name="nested",
+          structure={
+              "A": "v2",
+              "B": {"C": np.asarray([1, 2], dtype=np.float32)},
+              "c": MyClass(b"asdsad"),
+          },
+          expected_output=2 + 2 * 4,
+      ),
+  )
+  def test_estimate_byte_size(self, structure, expected_output):
+    self.assertEqual(tree_lib.estimate_byte_size(structure), expected_output)
+
 
 if __name__ == "__main__":
   absltest.main()
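Note: for intuition, here is roughly how the new `estimate_byte_size` helper behaves on a mixed structure. This is an illustrative sketch (the element and its values are hypothetical, not from the change), relying only on the rules above: `bytes`/`str` leaves count via `len()`, `np.ndarray` leaves via `.nbytes`, and every other leaf counts as 0.

    # Hypothetical element; only bytes/str and ndarray leaves contribute.
    import numpy as np
    from grain._src.core import tree_lib

    element = {
        "id": b"ex-01",                             # 5 bytes via len()
        "label": 7,                                 # non-handled leaf -> 0
        "image": np.zeros((2, 2), dtype=np.uint8),  # 4 bytes via .nbytes
    }
    assert tree_lib.estimate_byte_size(element) == 5 + 4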
diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD
index cb57e011..8ec40222 100644
--- a/grain/_src/python/BUILD
+++ b/grain/_src/python/BUILD
@@ -10,8 +10,8 @@ py_library(
     srcs_version = "PY3",
     deps = [
         "//grain/_src/core:monitoring",
-        "@pypi//etils:pkg",
         "@abseil-py//absl/logging",
+        "@pypi//etils:pkg",
         "//grain/_src/python/dataset:stats",
     ] + select({
         "@platforms//os:windows": [],
diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py
index fffd939b..2b4d641e 100644
--- a/grain/_src/python/data_loader.py
+++ b/grain/_src/python/data_loader.py
@@ -128,10 +128,9 @@ def __init__(
       sampler: Sampler,
       shard_options: sharding.ShardOptions,
   ):
-    super().__init__()
+    super().__init__(dataset.MapDataset.source(data_source))
     self._sampler = sampler
     self._shard_options = shard_options
-    self._data_source = data_source
     self.length = self._sampler_size() // self._shard_options.shard_count
 
   def _sampler_size(self) -> int:
@@ -163,8 +162,9 @@ def __getitem__(self, index):
     )
     with self._stats.record_self_time():
       metadata = self._sampler[index]
-      data = self._data_source[metadata.record_key]
-      return record.Record(metadata=metadata, data=data)
+      return record.Record(
+          metadata=metadata, data=self._parent[metadata.record_key]
+      )
 
 
 class _OperationIterDataset(dataset.IterDataset[_T]):
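Note: the data_loader change above swaps a raw `self._data_source` reference for the standard `MapDataset` parent chain: the source is wrapped via `dataset.MapDataset.source(...)`, becomes the node's `_parent`, and record reads go through it, so source instrumentation applies in one place. A minimal toy sketch of that parent-chaining pattern, with made-up classes rather than Grain's real API:

    # Toy nodes; `Source` stands in for an instrumented SourceMapDataset.
    class Node:
      def __init__(self, parent=None):
        self._parent = parent

    class Source(Node):
      def __init__(self, data):
        super().__init__()
        self._data = data

      def __getitem__(self, index):
        # A real source node would record read metrics/specs here.
        return self._data[index]

    class SamplerWrapper(Node):
      def __init__(self, data):
        # Wrap the raw data in a source node and read through it.
        super().__init__(Source(data))

      def __getitem__(self, index):
        return self._parent[index]

    assert SamplerWrapper(["a", "b", "c"])[1] == "b"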
diff --git a/grain/_src/python/data_sources.py b/grain/_src/python/data_sources.py
index 6ec74217..ca15bf33 100644
--- a/grain/_src/python/data_sources.py
+++ b/grain/_src/python/data_sources.py
@@ -59,16 +59,6 @@ def __init__(self, *args, **kwargs):
     root=grain_monitoring.get_monitoring_root(),
     fields=[("name", str)],
 )
-_bytes_read_counter = monitoring.Counter(
-    "/grain/python/data_sources/bytes_read",
-    monitoring.Metadata(
-        description=(
-            "Number of bytes produced by a data source via random access."
-        ),
-    ),
-    root=grain_monitoring.get_monitoring_root(),
-    fields=[("source", str)],
-)
 
 T = TypeVar("T")
 ArrayRecordDataSourcePaths = Union[
@@ -112,9 +102,7 @@ def __init__(
 
   @dataset_stats.trace_input_pipeline(stage_category=dataset_stats.IPL_CAT_READ)
   def __getitem__(self, record_key: SupportsIndex) -> bytes:
-    data = super().__getitem__(record_key)
-    _bytes_read_counter.IncrementBy(len(data), "ArrayRecordDataSource")
-    return data
+    return super().__getitem__(record_key)
 
   @property
   def paths(self) -> ArrayRecordDataSourcePaths:
diff --git a/grain/_src/python/dataset/transformations/source.py b/grain/_src/python/dataset/transformations/source.py
index f7da76bd..4ba0b3ed 100644
--- a/grain/_src/python/dataset/transformations/source.py
+++ b/grain/_src/python/dataset/transformations/source.py
@@ -16,12 +16,12 @@
 from __future__ import annotations
 
 import contextlib  # pylint: disable=unused-import
-import threading
+import functools
 import time
 from typing import Any, Sequence, Union
 
 from absl import logging
-from grain._src.core import monitoring as grain_monitoring
+
 from grain._src.core import sharding
 from grain._src.python import options
 from grain._src.python.dataset import base
@@ -29,39 +29,17 @@
 from grain._src.python.dataset import stats as dataset_stats
 import numpy as np
 
-from grain._src.core import monitoring
-
-
-_source_read_time_ns_histogram = monitoring.EventMetric(
-    "/grain/python/dataset/source_read_time_ns",
-    metadata=monitoring.Metadata(
-        description="Histogram of source read time in nanoseconds.",
-        units=monitoring.Units.NANOSECONDS,
-    ),
-    root=grain_monitoring.get_monitoring_root(),
-    fields=[("source", str)],
-    bucketer=monitoring.Bucketer.PowersOf(2.0),
-)
-
-_metric_lock = threading.Lock()
-
-
-def _maybe_record_source_read_time(
-    elapsed_time_ns: int, source_name: str
-) -> None:
-  """Records the source read time in nanoseconds if metric lock is available.
-
-  To avoid contention and potential slowness, we only record the time if the
-  lock is immediately available.
-
-  Args:
-    elapsed_time_ns: The elapsed time in nanoseconds.
-    source_name: The name of the source.
-  """
-
-  if _metric_lock.acquire(blocking=False):
-    _source_read_time_ns_histogram.Record(elapsed_time_ns, source_name)
-    _metric_lock.release()
+
+def _has_source_parent(ds: dataset.MapDataset) -> bool:
+  to_check = [ds]
+  while to_check:
+    next_ds = to_check.pop()
+    if isinstance(next_ds, SourceMapDataset):
+      return True
+    # Custom user Dataset implementations do not always call super().__init__(),
+    # which leaves `_parents` unset.
+    to_check.extend(getattr(next_ds, "_parents", ()))
+  return False
 
 
 class SourceMapDataset(dataset.MapDataset):
@@ -71,30 +49,44 @@ def __init__(self, source: base.RandomAccessDataSource):
     super().__init__()
     self._source = source
     self._original_source_map_dataset = None
+    # Sometimes users wrap the source in a `MapDataset`; we don't want to
+    # double-count metrics in these cases.
+    if isinstance(source, dataset.MapDataset):
+      self._record_metrics = not _has_source_parent(source)
+    else:
+      self._record_metrics = True
 
   def __len__(self) -> int:
     return len(self._source)
 
+  def _index_mod_len(self, index: int) -> int:
+    # Legacy source implementations sometimes return None as their length. In
+    # that case we let `__len__` fail if it is ever requested, but still let
+    # the pipeline read from the source by index, without the mod.
+    try:
+      return index % len(self._source)
+    except TypeError:
+      return index
+
   def __str__(self) -> str:
-    return f"SourceMapDataset(source={self._source.__class__.__name__})"
+    return f"SourceMapDataset(source={self._source_name})"
+
+  @functools.cached_property
+  def _source_name(self) -> str:
+    return self._source.__class__.__name__
 
   @dataset_stats.trace_input_pipeline(stage_category=dataset_stats.IPL_CAT_READ)
   def __getitem__(self, index):
     if isinstance(index, slice):
       return self.slice(index)
-    return self._instrumented_getitem(index)
-
-  def _instrumented_getitem(self, index):
-    """Instrumented __getitem__ private implementation."""
     tagging_ctx = contextlib.nullcontext()
     with tagging_ctx:
       with self._stats.record_self_time():
         start_time = time.perf_counter_ns()
-        result = self._stats.record_output_spec(self._source[index % len(self)])
-        stop_time = time.perf_counter_ns()
-        _maybe_record_source_read_time(
-            stop_time - start_time, self._source.__class__.__name__
+        result = self._stats.record_output_spec(
+            self._source[self._index_mod_len(index)]
         )
+        stop_time = time.perf_counter_ns()
     return result
 
   def _getitems(self, indices: Sequence[int]):
@@ -103,9 +95,11 @@
     ):
       return super()._getitems(indices)
     with self._stats.record_self_time(num_elements=len(indices)):
+      start_time = time.perf_counter_ns()
       elements = self._source._getitems(  # pylint: disable=protected-access
-          [index % len(self) for index in indices]
+          [self._index_mod_len(index) for index in indices]
       )
+      stop_time = time.perf_counter_ns()
       return self._stats.record_output_spec_for_batch(elements)
 
   def _get_sequential_slice(self, sl: slice) -> slice:
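Note: `_has_source_parent` exists so that a source wrapped in another `MapDataset` is not double-counted: the outer `SourceMapDataset` only records metrics if no source node already sits somewhere in its parent graph. A standalone sketch of the same traversal over toy nodes (hypothetical classes; the `getattr(..., "_parents", ())` fallback mirrors the guard for user datasets that skip `super().__init__()`):

    # Toy nodes mimicking MapDataset's `_parents` attribute.
    class Node:
      def __init__(self, *parents):
        self._parents = parents

    class SourceNode(Node):
      pass

    def has_source_parent(node) -> bool:
      to_check = [node]
      while to_check:
        nxt = to_check.pop()
        if isinstance(nxt, SourceNode):
          return True
        # Nodes that never set `_parents` are treated as leaves.
        to_check.extend(getattr(nxt, "_parents", ()))
      return False

    assert has_source_parent(Node(Node(SourceNode())))
    assert not has_source_parent(Node(Node()))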