From 30dff331c475767821d7deb77e7ec7e1cc2b6e5a Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Thu, 8 Jan 2026 16:56:57 -0500 Subject: [PATCH] Fix type annotations for `ArrayRecordDataSource`. --- grain/_src/python/data_sources.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/grain/_src/python/data_sources.py b/grain/_src/python/data_sources.py index 6ec74217..2ff89ec5 100644 --- a/grain/_src/python/data_sources.py +++ b/grain/_src/python/data_sources.py @@ -46,6 +46,12 @@ class ARDataSource: def __init__(self, *args, **kwargs): raise RuntimeError("array_record isn't supported on Windows") + + def __len__(self) -> int: + raise RuntimeError("array_record isn't supported on Windows") + + def __getitem__(self, index: int) -> bytes: + raise RuntimeError("array_record isn't supported on Windows") else: from array_record.python.array_record_data_source import ( ArrayRecordDataSource as ARDataSource, @@ -111,8 +117,8 @@ def __init__( _api_usage_counter.Increment("ArrayRecordDataSource") @dataset_stats.trace_input_pipeline(stage_category=dataset_stats.IPL_CAT_READ) - def __getitem__(self, record_key: SupportsIndex) -> bytes: - data = super().__getitem__(record_key) + def __getitem__(self, index: int) -> bytes: + data = super().__getitem__(index) _bytes_read_counter.IncrementBy(len(data), "ArrayRecordDataSource") return data