From 1abb99a137cd20beb93a356d020f8289f74bd8a7 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <220467675+maxine-at-forecast@users.noreply.github.com> Date: Wed, 1 Apr 2026 13:03:18 -0700 Subject: [PATCH 1/2] refactor: extract Dataset._build_pipeline() to eliminate 4x pipeline duplication [CL-L3] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consolidate the duplicated pipeline construction from ordered() and shuffled() (4 code paths) into a single _build_pipeline() method. Fix bug where batched iteration silently dropped filter/map stages — filter and map are now always applied per-sample before batching. Add eager validation for batch_size < 1 (previously raised a cryptic IndexError from WDS internals). Co-Authored-By: Claude Opus 4.6 --- src/atdata/dataset.py | 103 ++++++++++++++++++++--------------- tests/test_dataset.py | 7 +-- tests/test_dev_experience.py | 43 +++++++++++++++ 3 files changed, 106 insertions(+), 47 deletions(-) diff --git a/src/atdata/dataset.py b/src/atdata/dataset.py index 3c77535..21aca2b 100644 --- a/src/atdata/dataset.py +++ b/src/atdata/dataset.py @@ -1089,6 +1089,61 @@ def _post_wrap_stages(self) -> list: stages.append(wds.filters.map(map_fn)) return stages + def _build_pipeline( + self, + *, + shuffle_shards: int | None = None, + shuffle_samples: int | None = None, + batch_size: int | None = None, + ) -> wds.pipeline.DataPipeline: + """Build the WebDataset pipeline with all configured stages. + + This is the single source of truth for pipeline construction. + ``ordered()`` and ``shuffled()`` delegate here instead of + duplicating the stage list. + + Args: + shuffle_shards: If set, shuffle shards with this buffer size. + shuffle_samples: If set, shuffle samples with this buffer size. + batch_size: If set, batch samples into groups of this size. + + Returns: + A fully-configured ``DataPipeline``. + """ + stages: list = [_ShardListStage(self._source)] + + if shuffle_shards is not None: + stages.append(wds.filters.shuffle(shuffle_shards)) + + stages += [ + wds.shardlists.split_by_worker, + _StreamOpenerStage(self._source), + wds.tariterators.tar_file_expander, + wds.tariterators.group_by_keys, + ] + + if shuffle_samples is not None: + stages.append(wds.filters.shuffle(shuffle_samples)) + + # Wrap raw WDS dicts into typed samples, then apply filter/map. + # Filter/map are always applied per-sample, before any batching. + stages.append(wds.filters.map(self.wrap)) + stages += self._post_wrap_stages() + + if batch_size is not None: + if batch_size < 1: + raise ValueError(f"batch_size must be >= 1, got {batch_size}") + # collation_fn=None because items are already typed samples, + # not raw WDS dicts — default collation expects dicts. + stages.append(wds.filters.batched(batch_size, collation_fn=None)) + stages.append( + wds.filters.map( + lambda samples: SampleBatch[self.sample_type](samples) + ) + ) + + return wds.pipeline.DataPipeline(*stages) + @overload def ordered( self, @@ -1124,26 +1179,7 @@ def ordered( >>> for batch in ds.ordered(batch_size=32): ... process(batch) # batch is SampleBatch[ST] """ - if batch_size is None: - return wds.pipeline.DataPipeline( - _ShardListStage(self._source), - wds.shardlists.split_by_worker, - _StreamOpenerStage(self._source), - wds.tariterators.tar_file_expander, - wds.tariterators.group_by_keys, - wds.filters.map(self.wrap), - *self._post_wrap_stages(), - ) - - return wds.pipeline.DataPipeline( - _ShardListStage(self._source), - wds.shardlists.split_by_worker, - _StreamOpenerStage(self._source), - wds.tariterators.tar_file_expander, - wds.tariterators.group_by_keys, - wds.filters.batched(batch_size), - wds.filters.map(self.wrap_batch), - ) + return self._build_pipeline(batch_size=batch_size) @overload def shuffled( @@ -1193,29 +1229,10 @@ def shuffled( >>> for batch in ds.shuffled(batch_size=32): ... process(batch) # batch is SampleBatch[ST] """ - if batch_size is None: - return wds.pipeline.DataPipeline( - _ShardListStage(self._source), - wds.filters.shuffle(buffer_shards), - wds.shardlists.split_by_worker, - _StreamOpenerStage(self._source), - wds.tariterators.tar_file_expander, - wds.tariterators.group_by_keys, - wds.filters.shuffle(buffer_samples), - wds.filters.map(self.wrap), - *self._post_wrap_stages(), - ) - - return wds.pipeline.DataPipeline( - _ShardListStage(self._source), - wds.filters.shuffle(buffer_shards), - wds.shardlists.split_by_worker, - _StreamOpenerStage(self._source), - wds.tariterators.tar_file_expander, - wds.tariterators.group_by_keys, - wds.filters.shuffle(buffer_samples), - wds.filters.batched(batch_size), - wds.filters.map(self.wrap_batch), + return self._build_pipeline( + shuffle_shards=buffer_shards, + shuffle_samples=buffer_samples, + batch_size=batch_size, ) # Design note: Uses pandas for parquet export. Could be replaced with diff --git a/tests/test_dataset.py b/tests/test_dataset.py index be789c7..130c22c 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -727,10 +727,9 @@ class BatchSizeSample: dataset = atdata.Dataset[BatchSizeSample](wds_filename) - # batch_size=0 produces empty batches, causing IndexError in webdataset's - # batched() when it tries to inspect the first element of an empty group - with pytest.raises(IndexError): - list(dataset.ordered(batch_size=0)) + # batch_size=0 is invalid — _build_pipeline validates eagerly + with pytest.raises(ValueError, match="batch_size must be >= 1"): + dataset.ordered(batch_size=0) ## diff --git a/tests/test_dev_experience.py b/tests/test_dev_experience.py index 5051d75..93361d8 100644 --- a/tests/test_dev_experience.py +++ b/tests/test_dev_experience.py @@ -279,6 +279,26 @@ def test_filter_chain(self, dev_tar): assert len(result) == 5 assert all(10 <= s.value < 15 for s in result) + def test_filter_with_batched_ordered(self, dev_tar): + """filter() is applied before batching in ordered iteration.""" + url, _ = dev_tar + ds = Dataset[DevSample](url) + filtered = ds.filter(lambda s: s.value >= 15) + batches = list(filtered.ordered(batch_size=3)) + all_samples = [s for batch in batches for s in batch.samples] + assert len(all_samples) == 5 + assert all(s.value >= 15 for s in all_samples) + + def test_filter_with_batched_shuffled(self, dev_tar): + """filter() is applied before batching in shuffled iteration.""" + url, _ = dev_tar + ds = Dataset[DevSample](url) + filtered = ds.filter(lambda s: s.value >= 15) + batches = list(filtered.shuffled(batch_size=3)) + all_samples = [s for batch in batches for s in batch.samples] + assert len(all_samples) == 5 + assert all(s.value >= 15 for s in all_samples) + class TestDatasetMap: def test_map(self, dev_tar): @@ -290,6 +310,29 @@ def test_map(self, dev_tar): assert all(isinstance(r, str) for r in result) assert result[0] == "s000" + def test_map_with_batched_ordered(self, dev_tar): + """map() is applied before batching in ordered iteration.""" + url, _ = dev_tar + ds = Dataset[DevSample](url) + mapped = ds.map(lambda s: DevSample(name=s.name, value=s.value * 10)) + batches = list(mapped.ordered(batch_size=5)) + all_samples = [s for batch in batches for s in batch.samples] + assert len(all_samples) == 20 + assert all_samples[0].value == 0 + assert all_samples[1].value == 10 + + def test_filter_and_map_with_batched(self, dev_tar): + """filter() + map() both applied before batching.""" + url, _ = dev_tar + ds = Dataset[DevSample](url) + result = ds.filter(lambda s: s.value >= 10).map( + lambda s: DevSample(name=s.name, value=s.value * 2) + ) + batches = list(result.ordered(batch_size=5)) + all_samples = [s for batch in batches for s in batch.samples] + assert len(all_samples) == 10 + assert all(s.value >= 20 for s in all_samples) + class TestDatasetSelect: def test_select(self, dev_tar): From da530914e781d54110471d41c09cef30d9bfd09d Mon Sep 17 00:00:00 2001 From: Maxine Levesque <220467675+maxine-at-forecast@users.noreply.github.com> Date: Wed, 1 Apr 2026 13:08:22 -0700 Subject: [PATCH 2/2] style: fix ruff format in _build_pipeline Co-Authored-By: Claude Opus 4.6 --- src/atdata/dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/atdata/dataset.py b/src/atdata/dataset.py index 21aca2b..fd4395d 100644 --- a/src/atdata/dataset.py +++ b/src/atdata/dataset.py @@ -1137,9 +1137,7 @@ def _build_pipeline( # not raw WDS dicts — default collation expects dicts. stages.append(wds.filters.batched(batch_size, collation_fn=None)) stages.append( - wds.filters.map( - lambda samples: SampleBatch[self.sample_type](samples) - ) + wds.filters.map(lambda samples: SampleBatch[self.sample_type](samples)) ) return wds.pipeline.DataPipeline(*stages)