Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 58 additions & 43 deletions src/atdata/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,59 @@ def _post_wrap_stages(self) -> list:
stages.append(wds.filters.map(map_fn))
return stages

def _build_pipeline(
    self,
    *,
    shuffle_shards: int | None = None,
    shuffle_samples: int | None = None,
    batch_size: int | None = None,
) -> wds.pipeline.DataPipeline:
    """Assemble the WebDataset pipeline from the configured stages.

    Single source of truth for pipeline construction: ``ordered()`` and
    ``shuffled()`` both delegate here instead of maintaining duplicate
    stage lists.

    Args:
        shuffle_shards: Shard-shuffle buffer size, or ``None`` to skip
            shard shuffling.
        shuffle_samples: Sample-shuffle buffer size, or ``None`` to skip
            sample shuffling.
        batch_size: Group samples into batches of this size, or ``None``
            for per-sample iteration.

    Returns:
        A fully-configured ``DataPipeline``.

    Raises:
        ValueError: If ``batch_size`` is given and is less than 1.
    """
    # Validate eagerly so callers fail at construction time rather than
    # mid-iteration. Raising here matches the original behavior exactly:
    # the pipeline is never built when batch_size is invalid.
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    pipeline_stages: list = [_ShardListStage(self._source)]

    if shuffle_shards is not None:
        pipeline_stages.append(wds.filters.shuffle(shuffle_shards))

    pipeline_stages.extend(
        (
            wds.shardlists.split_by_worker,
            _StreamOpenerStage(self._source),
            wds.tariterators.tar_file_expander,
            wds.tariterators.group_by_keys,
        )
    )

    if shuffle_samples is not None:
        pipeline_stages.append(wds.filters.shuffle(shuffle_samples))

    # Raw WDS dicts become typed samples here; any configured filter/map
    # stages then run per-sample, always before batching.
    pipeline_stages.append(wds.filters.map(self.wrap))
    pipeline_stages.extend(self._post_wrap_stages())

    if batch_size is not None:
        # collation_fn=None because items are already typed samples,
        # not raw WDS dicts — default collation expects dicts.
        pipeline_stages.append(wds.filters.batched(batch_size, collation_fn=None))
        pipeline_stages.append(
            wds.filters.map(lambda items: SampleBatch[self.sample_type](items))
        )

    return wds.pipeline.DataPipeline(*pipeline_stages)

@overload
def ordered(
self,
Expand Down Expand Up @@ -1124,26 +1177,7 @@ def ordered(
>>> for batch in ds.ordered(batch_size=32):
... process(batch) # batch is SampleBatch[ST]
"""
if batch_size is None:
return wds.pipeline.DataPipeline(
_ShardListStage(self._source),
wds.shardlists.split_by_worker,
_StreamOpenerStage(self._source),
wds.tariterators.tar_file_expander,
wds.tariterators.group_by_keys,
wds.filters.map(self.wrap),
*self._post_wrap_stages(),
)

return wds.pipeline.DataPipeline(
_ShardListStage(self._source),
wds.shardlists.split_by_worker,
_StreamOpenerStage(self._source),
wds.tariterators.tar_file_expander,
wds.tariterators.group_by_keys,
wds.filters.batched(batch_size),
wds.filters.map(self.wrap_batch),
)
return self._build_pipeline(batch_size=batch_size)

@overload
def shuffled(
Expand Down Expand Up @@ -1193,29 +1227,10 @@ def shuffled(
>>> for batch in ds.shuffled(batch_size=32):
... process(batch) # batch is SampleBatch[ST]
"""
if batch_size is None:
return wds.pipeline.DataPipeline(
_ShardListStage(self._source),
wds.filters.shuffle(buffer_shards),
wds.shardlists.split_by_worker,
_StreamOpenerStage(self._source),
wds.tariterators.tar_file_expander,
wds.tariterators.group_by_keys,
wds.filters.shuffle(buffer_samples),
wds.filters.map(self.wrap),
*self._post_wrap_stages(),
)

return wds.pipeline.DataPipeline(
_ShardListStage(self._source),
wds.filters.shuffle(buffer_shards),
wds.shardlists.split_by_worker,
_StreamOpenerStage(self._source),
wds.tariterators.tar_file_expander,
wds.tariterators.group_by_keys,
wds.filters.shuffle(buffer_samples),
wds.filters.batched(batch_size),
wds.filters.map(self.wrap_batch),
return self._build_pipeline(
shuffle_shards=buffer_shards,
shuffle_samples=buffer_samples,
batch_size=batch_size,
)

# Design note: Uses pandas for parquet export. Could be replaced with
Expand Down
7 changes: 3 additions & 4 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,10 +727,9 @@ class BatchSizeSample:

dataset = atdata.Dataset[BatchSizeSample](wds_filename)

# batch_size=0 produces empty batches, causing IndexError in webdataset's
# batched() when it tries to inspect the first element of an empty group
with pytest.raises(IndexError):
list(dataset.ordered(batch_size=0))
# batch_size=0 is invalid — _build_pipeline validates eagerly
with pytest.raises(ValueError, match="batch_size must be >= 1"):
dataset.ordered(batch_size=0)


##
Expand Down
43 changes: 43 additions & 0 deletions tests/test_dev_experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,26 @@ def test_filter_chain(self, dev_tar):
assert len(result) == 5
assert all(10 <= s.value < 15 for s in result)

def test_filter_with_batched_ordered(self, dev_tar):
    """filter() runs per-sample, before batching, in ordered iteration."""
    url, _ = dev_tar
    dataset = Dataset[DevSample](url)
    kept = dataset.filter(lambda sample: sample.value >= 15)
    flattened = [
        sample
        for batch in kept.ordered(batch_size=3)
        for sample in batch.samples
    ]
    assert len(flattened) == 5
    assert all(sample.value >= 15 for sample in flattened)

def test_filter_with_batched_shuffled(self, dev_tar):
    """filter() runs per-sample, before batching, in shuffled iteration."""
    url, _ = dev_tar
    dataset = Dataset[DevSample](url)
    kept = dataset.filter(lambda sample: sample.value >= 15)
    flattened = [
        sample
        for batch in kept.shuffled(batch_size=3)
        for sample in batch.samples
    ]
    assert len(flattened) == 5
    assert all(sample.value >= 15 for sample in flattened)


class TestDatasetMap:
def test_map(self, dev_tar):
Expand All @@ -290,6 +310,29 @@ def test_map(self, dev_tar):
assert all(isinstance(r, str) for r in result)
assert result[0] == "s000"

def test_map_with_batched_ordered(self, dev_tar):
    """map() runs per-sample, before batching, in ordered iteration."""
    url, _ = dev_tar
    dataset = Dataset[DevSample](url)
    scaled = dataset.map(
        lambda sample: DevSample(name=sample.name, value=sample.value * 10)
    )
    flattened = [
        sample
        for batch in scaled.ordered(batch_size=5)
        for sample in batch.samples
    ]
    assert len(flattened) == 20
    # Ordered iteration preserves source order, so the first two mapped
    # values are deterministic.
    assert flattened[0].value == 0
    assert flattened[1].value == 10

def test_filter_and_map_with_batched(self, dev_tar):
    """Chained filter() + map() both apply per-sample, before batching."""
    url, _ = dev_tar
    dataset = Dataset[DevSample](url)
    transformed = dataset.filter(lambda sample: sample.value >= 10).map(
        lambda sample: DevSample(name=sample.name, value=sample.value * 2)
    )
    flattened = [
        sample
        for batch in transformed.ordered(batch_size=5)
        for sample in batch.samples
    ]
    assert len(flattened) == 10
    # value >= 10 survives the filter, then doubles under the map.
    assert all(sample.value >= 20 for sample in flattened)


class TestDatasetSelect:
def test_select(self, dev_tar):
Expand Down
Loading