From bc8225699ce5a205fae214e4cb5fa438474ac7f8 Mon Sep 17 00:00:00 2001 From: Ihor Indyk Date: Mon, 15 Dec 2025 14:23:21 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 844919525 --- grain/_src/python/dataset/dataset.py | 30 ++++++++++++++++++++++++++++ grain/_src/python/dataset/stats.py | 25 ++--------------------- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index ba04c95b..790f2470 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -50,6 +50,7 @@ from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence import functools import json +import time from typing import Any, Generic, TypeVar, Union, cast, overload import warnings @@ -75,6 +76,20 @@ fields=[("name", str)], ) +_next_duration_ns_histogram = monitoring.EventMetric( + "/grain/python/dataset/next_duration_ns", + metadata=monitoring.Metadata( + description=( + "Histogram of durations of every `__next__` call on the output" + " iterator. Each data point is the duration value of `__next__`" + " call." + ), + units=monitoring.Units.NANOSECONDS, + ), + root=grain_monitoring.get_monitoring_root(), + bucketer=monitoring.Bucketer.PowersOf(2.0), +) + T = TypeVar("T") S = TypeVar("S") @@ -1683,6 +1698,19 @@ def is_thread_prefetch_injection_enabled() -> bool: return False +def _record_next_duration(next_fn): + """Records the duration of the `__next__` call on the output iterator node.""" + + @functools.wraps(next_fn) + def wrapper(): + start_time = time.perf_counter_ns() + result = next_fn() + _next_duration_ns_histogram.Record(time.perf_counter_ns() - start_time) + return result + + return wrapper + + class _OutputIterDataset(IterDataset[T]): """Dataset that is injected at the end of every pipeline.""" @@ -1700,6 +1728,8 @@ def __iter__(self) -> DatasetIterator[T]: ): if not prefetch.is_prefetch_iterator(iterator): iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1) + # Wrap the __next__ function to record the duration of the call. + iterator.__next__ = _record_next_duration(iterator.__next__) return iterator diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py index 269cefe5..3f2731d9 100644 --- a/grain/_src/python/dataset/stats.py +++ b/grain/_src/python/dataset/stats.py @@ -62,20 +62,6 @@ bucketer=monitoring.Bucketer.PowersOf(2.0), ) -_next_duration_ns_histogram = monitoring.EventMetric( - "/grain/python/dataset/next_duration_ns", - metadata=monitoring.Metadata( - description=( - "Histogram of durations of every `__next__` call on the output" - " iterator. Each data point is the duration value of `__next__`" - " call." - ), - units=monitoring.Units.NANOSECONDS, - ), - root=grain_monitoring.get_monitoring_root(), - bucketer=monitoring.Bucketer.PowersOf(2.0), -) - T = TypeVar("T") # Time between two consecutive monitoring reports. _REPORTING_PERIOD_SEC = 5 @@ -318,16 +304,9 @@ def wrapper(iterator): _ipl_stage_name=str(iterator), _ipl_stage_id=id(iterator), ): - start_time = time.perf_counter_ns() - result = next_fn(iterator) + return next_fn(iterator) else: - start_time = time.perf_counter_ns() - result = next_fn(iterator) - - if iterator._stats._is_output: # pylint:disable=protected-access - next_duration_ns = time.perf_counter_ns() - start_time - _next_duration_ns_histogram.Record(next_duration_ns) - return result + return next_fn(iterator) return wrapper