21 changes: 21 additions & 0 deletions grain/_src/core/tree_lib.py
@@ -314,3 +314,24 @@ def _shape(obj):
structure,
is_leaf=_is_leaf,
)


def estimate_byte_size(structure) -> int:
"""Returns estimated total byte size of the given tree."""
# Fast path for serialized data.
if isinstance(structure, (bytes, str)):
return len(structure)

result = 0

  # This is intentionally very light and only handles the most common types
  # that have a non-trivial byte size, to avoid overhead.
def add_byte_size(x):
nonlocal result
if isinstance(x, (bytes, str)):
result += len(x)
elif isinstance(x, np.ndarray):
result += x.nbytes

map_structure(add_byte_size, structure)
return result
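
A minimal usage sketch of the new helper, assuming the module is importable as `grain._src.core.tree_lib` (its path in this PR) and that `map_structure` descends into nested dicts as the tests below exercise: only `bytes`/`str` leaves and NumPy arrays contribute to the estimate, and every other leaf counts as zero.

import numpy as np
from grain._src.core import tree_lib

example = {
    "tokens": np.zeros((4, 128), dtype=np.int32),  # 4 * 128 * 4 = 2048 bytes
    "label": b"cat",                               # 3 bytes
    "weight": 0.5,                                 # ignored: not bytes/str/ndarray
}
assert tree_lib.estimate_byte_size(example) == 2048 + 3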
46 changes: 46 additions & 0 deletions grain/_src/core/tree_lib_test.py
@@ -52,6 +52,9 @@ def unflatten_as(self, structure, flat_sequence):
def spec_like(self, structure):
...

def estimate_byte_size(self, structure):
...


# Static check that the module implements the necessary functions.
tree_lib: TreeImpl = tree_lib
@@ -193,6 +196,49 @@ def test_spec_like_with_attrs(self):
],
)

@parameterized.named_parameters(
dict(
testcase_name="bytes",
structure=b"serialized_data",
expected_output=15,
),
dict(
testcase_name="ndarray",
structure=np.asarray([1, 2, 3], dtype=np.int32),
expected_output=3 * 4,
),
dict(
testcase_name="int",
structure=1,
expected_output=0,
),
dict(
testcase_name="class",
structure=MyClass(1),
expected_output=0,
),
dict(
testcase_name="simple",
structure={
"A": "v2",
"B": 1232.4,
"C": np.asarray([1, 2, 3], dtype=np.float32),
},
expected_output=2 + 3 * 4,
),
dict(
testcase_name="nested",
structure={
"A": "v2",
"B": {"C": np.asarray([1, 2], dtype=np.float32)},
"c": MyClass(b"asdsad"),
},
expected_output=2 + 2 * 4,
),
)
def test_estimate_byte_size(self, structure, expected_output):
self.assertEqual(tree_lib.estimate_byte_size(structure), expected_output)


if __name__ == "__main__":
absltest.main()
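
Reading off the expected values above: in the "nested" case, "v2" contributes 2 bytes and the two-element float32 array contributes 2 * 4 = 8 bytes, while MyClass(b"asdsad") contributes nothing because the traversal hands the instance to add_byte_size as a single opaque leaf (neither bytes, str, nor ndarray) rather than descending into it, giving 2 + 8 = 10.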
2 changes: 1 addition & 1 deletion grain/_src/python/BUILD
@@ -10,8 +10,8 @@ py_library(
srcs_version = "PY3",
deps = [
"//grain/_src/core:monitoring",
"@pypi//etils:pkg",
"@abseil-py//absl/logging",
"@pypi//etils:pkg",
"//grain/_src/python/dataset:stats",
] + select({
"@platforms//os:windows": [],
8 changes: 4 additions & 4 deletions grain/_src/python/data_loader.py
@@ -128,10 +128,9 @@ def __init__(
sampler: Sampler,
shard_options: sharding.ShardOptions,
):
super().__init__()
super().__init__(dataset.MapDataset.source(data_source))
self._sampler = sampler
self._shard_options = shard_options
self._data_source = data_source
self.length = self._sampler_size() // self._shard_options.shard_count

def _sampler_size(self) -> int:
@@ -163,8 +162,9 @@ def __getitem__(self, index):
)
with self._stats.record_self_time():
metadata = self._sampler[index]
data = self._data_source[metadata.record_key]
return record.Record(metadata=metadata, data=data)
return record.Record(
metadata=metadata, data=self._parent[metadata.record_key]
)


class _OperationIterDataset(dataset.IterDataset[_T]):
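
The shape of the data_loader.py change above, as a schematic sketch: the raw data source is wrapped once via dataset.MapDataset.source(...), so it becomes the dataset's parent and record reads flow through the instrumented SourceMapDataset instead of touching the source directly. The class name below is hypothetical and the body is trimmed to the lines visible in this hunk; it is not the actual Grain class.

from grain._src.python import record
from grain._src.python.dataset import dataset


class _SamplerBackedDataset(dataset.MapDataset):  # hypothetical name, for illustration
  """Sketch of the new wiring, not the real implementation."""

  def __init__(self, data_source, sampler, shard_options):
    # The source becomes the parent MapDataset, so reads below are traced there.
    super().__init__(dataset.MapDataset.source(data_source))
    self._sampler = sampler
    self._shard_options = shard_options

  def __getitem__(self, index):
    metadata = self._sampler[index]
    # `self._parent` is the SourceMapDataset wrapping the raw data source.
    return record.Record(metadata=metadata, data=self._parent[metadata.record_key])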
14 changes: 1 addition & 13 deletions grain/_src/python/data_sources.py
@@ -59,16 +59,6 @@ def __init__(self, *args, **kwargs):
root=grain_monitoring.get_monitoring_root(),
fields=[("name", str)],
)
_bytes_read_counter = monitoring.Counter(
"/grain/python/data_sources/bytes_read",
monitoring.Metadata(
description=(
"Number of bytes produced by a data source via random access."
),
),
root=grain_monitoring.get_monitoring_root(),
fields=[("source", str)],
)

T = TypeVar("T")
ArrayRecordDataSourcePaths = Union[
@@ -112,9 +102,7 @@ def __init__(

@dataset_stats.trace_input_pipeline(stage_category=dataset_stats.IPL_CAT_READ)
def __getitem__(self, record_key: SupportsIndex) -> bytes:
data = super().__getitem__(record_key)
_bytes_read_counter.IncrementBy(len(data), "ArrayRecordDataSource")
return data
return super().__getitem__(record_key)

@property
def paths(self) -> ArrayRecordDataSourcePaths:
82 changes: 38 additions & 44 deletions grain/_src/python/dataset/transformations/source.py
@@ -16,52 +16,30 @@
from __future__ import annotations

import contextlib # pylint: disable=unused-import
import threading
import functools
import time
from typing import Any, Sequence, Union

from absl import logging
from grain._src.core import monitoring as grain_monitoring

from grain._src.core import sharding
from grain._src.python import options
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
import numpy as np

from grain._src.core import monitoring


_source_read_time_ns_histogram = monitoring.EventMetric(
"/grain/python/dataset/source_read_time_ns",
metadata=monitoring.Metadata(
description="Histogram of source read time in nanoseconds.",
units=monitoring.Units.NANOSECONDS,
),
root=grain_monitoring.get_monitoring_root(),
fields=[("source", str)],
bucketer=monitoring.Bucketer.PowersOf(2.0),
)

_metric_lock = threading.Lock()


def _maybe_record_source_read_time(
elapsed_time_ns: int, source_name: str
) -> None:
"""Records the source read time in nanoseconds if metric lock is available.

To avoid contention and potential slowness, we only record the time if the
lock is immediately available.

Args:
elapsed_time_ns: The elapsed time in nanoseconds.
source_name: The name of the source.
"""

if _metric_lock.acquire(blocking=False):
_source_read_time_ns_histogram.Record(elapsed_time_ns, source_name)
_metric_lock.release()
def _has_source_parent(ds: dataset.MapDataset) -> bool:
to_check = [ds]
while to_check:
next_ds = to_check.pop()
if isinstance(next_ds, SourceMapDataset):
return True
    # Custom user Dataset implementations do not always call super().__init__(),
    # which leaves `_parents` unset.
to_check.extend(getattr(next_ds, "_parents", ()))
return False


class SourceMapDataset(dataset.MapDataset):
@@ -71,30 +49,44 @@ def __init__(self, source: base.RandomAccessDataSource):
super().__init__()
self._source = source
self._original_source_map_dataset = None
    # Sometimes users wrap the source into a `MapDataset`; we don't want to
    # double-count metrics in these cases.
if isinstance(source, dataset.MapDataset):
self._record_metrics = not _has_source_parent(source)
else:
self._record_metrics = True

def __len__(self) -> int:
return len(self._source)

def _index_mod_len(self, index: int) -> int:
# Legacy source implementations sometimes return None for source length.
# We'll let this case fail if the length is requested, but let the pipeline
# read from the source by index without mod.
try:
return index % len(self._source)
except TypeError:
return index

def __str__(self) -> str:
return f"SourceMapDataset(source={self._source.__class__.__name__})"
return f"SourceMapDataset(source={self._source_name})"

@functools.cached_property
def _source_name(self) -> str:
return self._source.__class__.__name__

@dataset_stats.trace_input_pipeline(stage_category=dataset_stats.IPL_CAT_READ)
def __getitem__(self, index):
if isinstance(index, slice):
return self.slice(index)
return self._instrumented_getitem(index)

def _instrumented_getitem(self, index):
"""Instrumented __getitem__ private implementation."""
tagging_ctx = contextlib.nullcontext()
with tagging_ctx:
with self._stats.record_self_time():
start_time = time.perf_counter_ns()
result = self._stats.record_output_spec(self._source[index % len(self)])
stop_time = time.perf_counter_ns()
_maybe_record_source_read_time(
stop_time - start_time, self._source.__class__.__name__
result = self._stats.record_output_spec(
self._source[self._index_mod_len(index)]
)
stop_time = time.perf_counter_ns()
return result

def _getitems(self, indices: Sequence[int]):
@@ -103,9 +95,11 @@ def _getitems(self, indices: Sequence[int]):
):
return super()._getitems(indices)
with self._stats.record_self_time(num_elements=len(indices)):
start_time = time.perf_counter_ns()
elements = self._source._getitems( # pylint: disable=protected-access
[index % len(self) for index in indices]
[self._index_mod_len(index) for index in indices]
)
stop_time = time.perf_counter_ns()
return self._stats.record_output_spec_for_batch(elements)

def _get_sequential_slice(self, sl: slice) -> slice:
Expand Down