diff --git a/.chainlink/issues.db b/.chainlink/issues.db
index 12b36ae..edbde07 100644
Binary files a/.chainlink/issues.db and b/.chainlink/issues.db differ
diff --git a/.github/workflows/uv-test.yml b/.github/workflows/uv-test.yml
index cc8b74d..2e18832 100644
--- a/.github/workflows/uv-test.yml
+++ b/.github/workflows/uv-test.yml
@@ -4,13 +4,11 @@ on:
   push:
     branches:
       - main
-      - release/*
   pull_request:
-    branches:
-      - main
 
 permissions:
   contents: read
+  actions: read
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
@@ -77,3 +75,60 @@ jobs:
         with:
           fail_ci_if_error: false
           token: ${{ secrets.CODECOV_TOKEN }}
+
+  benchmark:
+    name: Benchmarks
+    runs-on: ubuntu-latest
+    needs: [lint]
+    permissions:
+      contents: write
+      actions: write
+    steps:
+      - uses: actions/checkout@v5
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.14"
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
+        with:
+          enable-cache: true
+
+      - name: Install just
+        uses: extractions/setup-just@v2
+
+      - name: Install the project
+        run: uv sync --locked --all-extras --dev
+
+      - name: Start Redis
+        uses: supercharge/redis-github-action@1.8.1
+        with:
+          redis-version: 7
+
+      - name: Run benchmarks
+        run: just bench
+
+      - name: Copy report to docs
+        run: |
+          mkdir -p docs/benchmarks
+          cp .bench/report.html docs/benchmarks/index.html
+
+      - name: Commit updated benchmark docs
+        if: github.event_name == 'push'
+        run: |
+          git config user.name "github-actions[bot]"
+          git config user.email "github-actions[bot]@users.noreply.github.com"
+          git add docs/benchmarks/index.html
+          git diff --cached --quiet || git commit -m "docs: update benchmark report [skip ci]"
+          git push
+
+      - name: Upload benchmark report
+        uses: actions/upload-artifact@v4
+        if: always()
+        with:
+          name: benchmark-report
+          path: |
+            .bench/report.html
+            .bench/*.json
diff --git a/.gitignore b/.gitignore
index 75ac226..97301a8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -52,6 +52,9 @@ MANIFEST
 pip-log.txt
 pip-delete-this-directory.txt
 
+# Benchmark results
+.bench/
+
 # Unit test / coverage reports
 htmlcov/
 .tox/
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 3e03efa..56064c5 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -5,15 +5,19 @@
     "atproto",
     "creds",
     "dtype",
+    "fastparquet",
     "getattr",
     "hgetall",
     "hset",
+    "libipld",
     "maxcount",
     "minioadmin",
     "msgpack",
     "ndarray",
     "NSID",
     "ormsgpack",
+    "psycopg",
+    "pydantic",
     "pypi",
     "pyproject",
     "pytest",
@@ -24,6 +28,8 @@
     "schemamodels",
     "shardlists",
     "tariterators",
+    "tqdm",
+    "typer",
     "unpackb",
     "webdataset"
   ],
diff --git a/CHANGELOG.md b/CHANGELOG.md
index feeea7e..312972d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,111 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 - **Comprehensive integration test suite**: 593 tests covering E2E flows, error handling, edge cases
 
 ### Changed
+- Investigate upload-artifact not finding benchmark output (#512)
+- Fix duplicate CI runs for push+PR overlap (#511)
+- Scope contents:write permission to benchmark job only (#510)
+- Add benchmark docs auto-commit to CI workflow (#509)
+- Submit PR for v0.3.0b1 release to upstream/main (#508)
+- Implement GH#39: Production hardening (observability, error handling, testing infra) (#504)
+- Add pluggable structured logging via atdata.configure_logging (#507)
+- Add PartialFailureError and shard-level error handling to Dataset.map (#506)
+- Add atdata.testing module with mock clients, fixtures, and helpers (#505)
+- Fix CI linting failures (20 ruff errors) (#503)
+- Adversarial review: Post-benchmark suite assessment (#494)
+- Remove redundant protocol docstrings that restate signatures (#500)
+- Add missing unit tests for _type_utils.py (#499)
+- Strengthen weak assertions (assert X is not None → value checks) (#498)
+- Trim verbose exception constructor docstrings (#501)
+- Analyze benchmark results for performance improvement opportunities (#502)
+-
Consolidate remaining duplicate sample types in test files (#497)
+- Remove dead code: _repo_legacy.py legacy UUID field, unused imports (#496)
+- Trim verbose docstrings in dataset.py and _index.py (#495)
+- Benchmark report: replace mean/stddev with median/IQR, add per-sample columns (#492)
+- Add parameter descriptions to benchmark suite with automatic report introspection (#491)
+- HTML benchmark reports with CI integration (#487)
+- Add bench + render step to CI on highest Python version only (#490)
+- Update justfile bench commands to export JSON and render (#489)
+- Create render_report.py script to convert JSON to HTML (#488)
+- Increase test coverage for low-coverage modules (#480)
+- Add providers/_postgres.py tests (mock-based) (#485)
+- Add _stub_manager.py tests (#484)
+- Add manifest/_query.py tests (#483)
+- Add repository.py tests (#482)
+- Add CLI tests (cli/__init__, diagnose, local, preview, schema) (#481)
+- Check test coverage for CLI utils (#479)
+- Add performance benchmark suite for atdata (#471)
+- Verify benchmarks run (#478)
+- Update pyproject.toml and justfile (#477)
+- Create bench_atmosphere.py (#476)
+- Create bench_query.py (#475)
+- Create bench_dataset_io.py (#474)
+- Create bench_index_providers.py (#473)
+- Create benchmarks/conftest.py with shared fixtures (#472)
+- Add per-shard manifest and query system (GH #35) (#462)
+- Write unit and integration tests (#470)
+- Integrate manifest into write path and Dataset.query() (#469)
+- Implement QueryExecutor and SampleLocation (#468)
+- Implement ManifestWriter (JSON + parquet) (#467)
+- Implement ManifestBuilder (#465)
+- Implement ShardManifest data model (#466)
+- Implement aggregate collectors (categorical, numeric, set) (#464)
+- Implement ManifestField annotation and resolve_manifest_fields() (#463)
+- Migrate type annotations from PackableSample to Packable protocol (#461)
+- Remove LocalIndex factory — consolidate to Index (#460)
+- Split local.py monolith into local/
package (#452)
+- Verify tests and lint pass (#459)
+- Create __init__.py re-export facade and delete local.py (#458)
+- Create _repo_legacy.py with deprecated Repo class (#457)
+- Create _index.py with Index class and LocalIndex factory (#456)
+- Create _s3.py with S3DataStore and S3 helpers (#455)
+- Create _schema.py with schema models and helpers (#454)
+- Create _entry.py with LocalDatasetEntry and constants (#453)
+- Migrate CLI from argparse to typer (#449)
+- Investigate test failures (#450)
+- Fix ensure_stub receiving LocalSchemaRecord instead of dict (#451)
+- GH#38: Developer experience improvements (#437)
+- CLI: atdata preview command (#440)
+- CLI: atdata schema show/diff commands (#439)
+- CLI: atdata inspect command (#438)
+- Dataset.__len__ and Dataset.select() for sample count and indexed access (#447)
+- Dataset.to_pandas() and Dataset.to_dict() export methods (#446)
+- Dataset.filter() and Dataset.map() streaming transforms (#445)
+- Dataset.get(key) for keyed sample access (#442)
+- Dataset.describe() summary statistics (#444)
+- Dataset.schema property and column_names (#443)
+- Dataset.head(n) and Dataset.__iter__ convenience methods (#441)
+- Custom exception hierarchy with actionable error messages (#448)
+- Adversarial review: Post-Repository consolidation assessment (#430)
+- Remove backwards-compat dict-access methods from SchemaField and LocalSchemaRecord (#436)
+- Add missing test coverage for Repository prefix routing edge cases and error paths (#435)
+- Trim over-verbose docstrings in local.py module/class level (#434)
+- Fix formally incorrect test assertions (batch_size, CID, brace notation) (#433)
+- Consolidate duplicate test sample types across test files into conftest.py (#432)
+- Consolidate duplicate entry-creation logic in Index (add_entry vs _insert_dataset_to_provider) (#431)
+- Switch default Index provider from Redis to SQLite (#429)
+- Consolidated Index with Repository system (#424)
+- Phase 4: Deprecate
AtmosphereIndex, update exports (#428)
+- Phase 3: Default Index singleton and load_dataset integration (#427)
+- Phase 2: Extend Index with repos/atmosphere params and prefix routing (#426)
+- Phase 1: Create Repository dataclass and _AtmosphereBackend in repository.py (#425)
+- Adversarial review: Post-IndexProvider pluggable storage assessment (#417)
+- Convert TODO comments to tracked issues or remove (#422)
+- Remove deprecated shard_list property references from docstrings (#421)
+- Replace bare except in _stub_manager.py and cli/local.py with specific exceptions (#423)
+- Tighten generic pytest.raises(Exception) to specific exception types in tests (#420)
+- Replace assert statements with ValueError in production code (#419)
+- Consolidate duplicated _parse_semver into _type_utils.py (#418)
+- feat: Add SQLite/PostgreSQL index providers (GH #42) (#409)
+- Update documentation and public API exports (#416)
+- Add tests for all providers (#415)
+- Refactor Index class to accept provider parameter (#414)
+- Implement PostgresIndexProvider (#413)
+- Implement SqliteIndexProvider (#412)
+- Implement RedisIndexProvider (extract from Index class) (#411)
+- Define IndexProvider protocol in _protocols.py (#410)
+- Add just lint command to justfile (#408)
+- Add SQLite/PostgreSQL providers for LocalIndex (in addition to Redis) (#407)
+- Fix type hints for @atdata.packable decorator to show PackableSample methods (#406)
 - Review GitHub workflows and recommend CI improvements (#405)
 - Fix type signatures for Dataset.ordered and Dataset.shuffled (GH#28) (#404)
 - Investigate quartodoc Example section rendering - missing CSS classes on pre/code tags (#401)
diff --git a/CLAUDE.md b/CLAUDE.md
index 349c268..6b096b4 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -46,8 +46,10 @@ uv build
 
 Development tasks are managed with [just](https://github.com/casey/just), a command runner.
 Available commands:
 
 ```bash
-# Build documentation (runs quartodoc + quarto)
-just docs
+just test # Run all tests with coverage
+just test tests/test_dataset.py # Run specific test file
+just lint # Run ruff check + format check
+just docs # Build documentation (runs quartodoc + quarto)
 ```
 
 The `justfile` is in the project root. Add new dev tasks there rather than creating shell scripts.
diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/benchmarks/bench_atmosphere.py b/benchmarks/bench_atmosphere.py
new file mode 100644
index 0000000..cab470e
--- /dev/null
+++ b/benchmarks/bench_atmosphere.py
@@ -0,0 +1,220 @@
+"""Performance benchmarks for remote storage backends.
+
+Covers S3DataStore (via moto mock) and Atmosphere/ATProto (network-gated).
+S3 benchmarks use moto for reproducible local measurement.
+Atmosphere benchmarks are marked ``network`` and skip unless a live PDS is available.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+from moto import mock_aws
+
+import atdata
+
+from .conftest import (
+    IMAGE_SHAPE,
+    BenchBasicSample,
+    BenchManifestSample,
+    BenchNumpySample,
+    generate_basic_samples,
+    generate_manifest_samples,
+    generate_numpy_samples,
+    write_tar,
+)
+
+
+# =============================================================================
+# S3 Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def mock_s3():
+    """Provide mock S3 environment using moto."""
+    with mock_aws():
+        import boto3
+
+        creds = {
+            "AWS_ACCESS_KEY_ID": "testing",
+            "AWS_SECRET_ACCESS_KEY": "testing",
+        }
+        s3_client = boto3.client(
+            "s3",
+            aws_access_key_id=creds["AWS_ACCESS_KEY_ID"],
+            aws_secret_access_key=creds["AWS_SECRET_ACCESS_KEY"],
+            region_name="us-east-1",
+        )
+        bucket_name = "bench-bucket"
+        s3_client.create_bucket(Bucket=bucket_name)
+        yield {
+            "credentials": creds,
"bucket": bucket_name, + } + + +def _make_s3_store(mock_s3_env): + from atdata.local._s3 import S3DataStore + + return S3DataStore( + credentials=mock_s3_env["credentials"], + bucket=mock_s3_env["bucket"], + ) + + +def _make_source_dataset(tmp_path, samples): + """Create a local dataset from samples for use as S3 write source.""" + tar_path = write_tar(tmp_path / "source-000000.tar", samples) + sample_type = type(samples[0]) + return atdata.Dataset[sample_type](url=str(tar_path)) + + +# ============================================================================= +# S3 Write Benchmarks +# ============================================================================= + + +@pytest.mark.bench_s3 +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") +@pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") +class TestS3WriteBenchmarks: + """S3 shard writing benchmarks via moto mock.""" + + PARAM_LABELS = {"n": "samples per shard"} + + @pytest.mark.parametrize("n", [100, 500], ids=["100", "500"]) + def test_s3_write_shards(self, benchmark, tmp_path, mock_s3, n): + benchmark.extra_info["n_samples"] = n + samples = generate_basic_samples(n) + ds = _make_source_dataset(tmp_path, samples) + store = _make_s3_store(mock_s3) + counter = [0] + + def _write(): + idx = counter[0] + counter[0] += 1 + store.write_shards(ds, prefix=f"bench/basic-{n}-{idx}") + + benchmark(_write) + + def test_s3_write_with_manifest(self, benchmark, tmp_path, mock_s3): + benchmark.extra_info["n_samples"] = 200 + samples = generate_manifest_samples(200) + ds = _make_source_dataset(tmp_path, samples) + store = _make_s3_store(mock_s3) + counter = [0] + + def _write(): + idx = counter[0] + counter[0] += 1 + store.write_shards( + ds, prefix=f"bench/manifest-{idx}", manifest=True, + cache_local=True, + ) + + benchmark(_write) + + def test_s3_write_cache_local(self, benchmark, tmp_path, mock_s3): + benchmark.extra_info["n_samples"] = 200 + samples = 
generate_basic_samples(200) + ds = _make_source_dataset(tmp_path, samples) + store = _make_s3_store(mock_s3) + counter = [0] + + def _write(): + idx = counter[0] + counter[0] += 1 + store.write_shards( + ds, prefix=f"bench/cache-{idx}", cache_local=True + ) + + benchmark(_write) + + def test_s3_write_direct(self, benchmark, tmp_path, mock_s3): + benchmark.extra_info["n_samples"] = 200 + samples = generate_basic_samples(200) + ds = _make_source_dataset(tmp_path, samples) + store = _make_s3_store(mock_s3) + counter = [0] + + def _write(): + idx = counter[0] + counter[0] += 1 + store.write_shards( + ds, prefix=f"bench/direct-{idx}", cache_local=False + ) + + benchmark(_write) + + def test_s3_write_numpy(self, benchmark, tmp_path, mock_s3): + benchmark.extra_info["n_samples"] = 100 + samples = generate_numpy_samples(100) + ds = _make_source_dataset(tmp_path, samples) + store = _make_s3_store(mock_s3) + counter = [0] + + def _write(): + idx = counter[0] + counter[0] += 1 + store.write_shards(ds, prefix=f"bench/numpy-{idx}") + + benchmark(_write) + + +# ============================================================================= +# Atmosphere Benchmarks (network-gated) +# ============================================================================= + + +@pytest.mark.network +class TestAtmosphereBenchmarks: + """Atmosphere/ATProto benchmarks. Require live PDS access. 
+ + Run with: just bench -m network + """ + + def test_atmosphere_publish_dataset(self, benchmark, tmp_path): + """End-to-end dataset publish to Atmosphere.""" + import os + + handle = os.environ.get("ATDATA_BENCH_ATP_HANDLE") + password = os.environ.get("ATDATA_BENCH_ATP_PASSWORD") + if not handle or not password: + pytest.skip("ATDATA_BENCH_ATP_HANDLE/PASSWORD not set") + + from atdata.atmosphere.client import AtmosphereClient + + client = AtmosphereClient(handle=handle, password=password) + + samples = generate_basic_samples(10) + tar_path = write_tar(tmp_path / "atmo-000000.tar", samples) + ds = atdata.Dataset[BenchBasicSample](url=str(tar_path)) + + counter = [0] + + def _publish(): + idx = counter[0] + counter[0] += 1 + from atdata.atmosphere.records import DatasetPublisher + + publisher = DatasetPublisher(client) + publisher.publish(ds, name=f"bench-atmo-{idx}") + + benchmark(_publish) + + def test_atmosphere_resolve_dataset(self, benchmark): + """Resolve a dataset record from Atmosphere (read-only, anonymous).""" + import os + + ref = os.environ.get("ATDATA_BENCH_ATP_DATASET_REF") + if not ref: + pytest.skip("ATDATA_BENCH_ATP_DATASET_REF not set") + + from atdata.local._index import Index + + index = Index() + + benchmark(index.get_dataset, ref) diff --git a/benchmarks/bench_dataset_io.py b/benchmarks/bench_dataset_io.py new file mode 100644 index 0000000..8668b09 --- /dev/null +++ b/benchmarks/bench_dataset_io.py @@ -0,0 +1,293 @@ +"""Performance benchmarks for dataset read/write operations. + +Measures shard writing throughput, read iteration speed, serialization +overhead, and round-trip performance for basic and numpy sample types. 
+""" + +from __future__ import annotations + +import numpy as np +import pytest +import webdataset as wds + +import atdata + +from .conftest import ( + IMAGE_DTYPE, + IMAGE_SHAPE, + TSERIES_DTYPE, + TSERIES_SHAPE, + BenchBasicSample, + BenchManifestSample, + BenchNumpySample, + generate_basic_samples, + generate_manifest_samples, + generate_numpy_samples, + write_tar, + write_tar_with_manifest, +) + + +# ============================================================================= +# Write Benchmarks +# ============================================================================= + + +@pytest.mark.bench_io +class TestShardWriteBenchmarks: + """Shard writing throughput benchmarks.""" + + PARAM_LABELS = {"n": "samples per shard"} + + @pytest.mark.parametrize("n", [100, 1000, 10000], ids=["100", "1k", "10k"]) + def test_write_basic_shard(self, benchmark, tmp_path, n): + benchmark.extra_info["n_samples"] = n + samples = generate_basic_samples(n) + + def _write(): + tar_path = tmp_path / f"basic-{n}.tar" + write_tar(tar_path, samples) + + benchmark(_write) + + @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"]) + def test_write_numpy_shard(self, benchmark, tmp_path, n): + benchmark.extra_info["n_samples"] = n + samples = generate_numpy_samples(n) + + def _write(): + tar_path = tmp_path / f"numpy-{n}.tar" + write_tar(tar_path, samples) + + benchmark(_write) + + def test_write_large_numpy_shard(self, benchmark, tmp_path): + benchmark.extra_info["n_samples"] = 10 + samples = generate_numpy_samples(10, shape=TSERIES_SHAPE, dtype=TSERIES_DTYPE) + + def _write(): + tar_path = tmp_path / "numpy-large.tar" + write_tar(tar_path, samples) + + benchmark(_write) + + @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"]) + def test_write_with_manifest(self, benchmark, tmp_path, n): + benchmark.extra_info["n_samples"] = n + samples = generate_manifest_samples(n) + counter = [0] + + def _write(): + idx = counter[0] + counter[0] += 1 + tar_path = tmp_path / 
f"manifest-{n}-{idx}.tar" + write_tar_with_manifest(tar_path, samples, BenchManifestSample) + + benchmark(_write) + + def test_write_multi_shard(self, benchmark, tmp_path): + benchmark.extra_info["n_samples"] = 10000 + samples = generate_basic_samples(10000) + counter = [0] + + def _write(): + idx = counter[0] + counter[0] += 1 + out_dir = tmp_path / f"multi-{idx}" + out_dir.mkdir(exist_ok=True) + pattern = str(out_dir / "shard-%06d.tar") + with wds.writer.ShardWriter(pattern, maxcount=1000) as sink: + for sample in samples: + sink.write(sample.as_wds) + + benchmark(_write) + + +# ============================================================================= +# Read Benchmarks +# ============================================================================= + + +@pytest.mark.bench_io +class TestShardReadBenchmarks: + """Shard reading and iteration benchmarks.""" + + PARAM_LABELS = {"n": "samples in dataset", "batch_size": "samples per batch"} + + @pytest.mark.parametrize("n", [100, 1000, 10000], ids=["100", "1k", "10k"]) + def test_read_ordered(self, benchmark, tmp_path, n): + benchmark.extra_info["n_samples"] = n + samples = generate_basic_samples(n) + tar_path = write_tar(tmp_path / f"read-ordered-{n}.tar", samples) + ds = atdata.Dataset[BenchBasicSample](url=str(tar_path)) + + def _read(): + count = 0 + for _ in ds.ordered(): + count += 1 + return count + + result = benchmark(_read) + assert result == n + + @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"]) + def test_read_shuffled(self, benchmark, tmp_path, n): + benchmark.extra_info["n_samples"] = n + samples = generate_basic_samples(n) + tar_path = write_tar(tmp_path / f"read-shuffled-{n}.tar", samples) + ds = atdata.Dataset[BenchBasicSample](url=str(tar_path)) + + def _read(): + count = 0 + for _ in ds.shuffled(): + count += 1 + return count + + result = benchmark(_read) + assert result == n + + @pytest.mark.parametrize( + "batch_size", [32, 128], ids=["batch32", "batch128"] + ) + def 
test_read_batched(self, benchmark, tmp_path, batch_size): + n = 1000 + benchmark.extra_info["n_samples"] = n + samples = generate_basic_samples(n) + tar_path = write_tar(tmp_path / f"read-batched-{batch_size}.tar", samples) + ds = atdata.Dataset[BenchBasicSample](url=str(tar_path)) + + def _read(): + count = 0 + for batch in ds.ordered(batch_size=batch_size): + count += 1 + return count + + benchmark(_read) + + @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"]) + def test_read_numpy_ordered(self, benchmark, tmp_path, n): + benchmark.extra_info["n_samples"] = n + samples = generate_numpy_samples(n) + tar_path = write_tar(tmp_path / f"read-numpy-{n}.tar", samples) + ds = atdata.Dataset[BenchNumpySample](url=str(tar_path)) + + def _read(): + count = 0 + for _ in ds.ordered(): + count += 1 + return count + + result = benchmark(_read) + assert result == n + + +# ============================================================================= +# Serialization Benchmarks (No I/O) +# ============================================================================= + + +@pytest.mark.bench_serial +class TestSerializationBenchmarks: + """Pure serialization/deserialization without disk I/O.""" + + def test_serialize_basic_sample(self, benchmark): + sample = BenchBasicSample(name="bench_sample", value=42) + benchmark(lambda: sample.packed) + + def test_deserialize_basic_sample(self, benchmark): + sample = BenchBasicSample(name="bench_sample", value=42) + packed = sample.packed + benchmark(BenchBasicSample.from_bytes, packed) + + def test_serialize_numpy_sample(self, benchmark): + sample = BenchNumpySample( + data=np.random.randint(0, 256, size=IMAGE_SHAPE, dtype=IMAGE_DTYPE), + label="bench", + ) + benchmark(lambda: sample.packed) + + def test_deserialize_numpy_sample(self, benchmark): + sample = BenchNumpySample( + data=np.random.randint(0, 256, size=IMAGE_SHAPE, dtype=IMAGE_DTYPE), + label="bench", + ) + packed = sample.packed + benchmark(BenchNumpySample.from_bytes, 
packed) + + def test_serialize_large_numpy(self, benchmark): + sample = BenchNumpySample( + data=np.random.randn(*TSERIES_SHAPE).astype(TSERIES_DTYPE), + label="large", + ) + benchmark(lambda: sample.packed) + + def test_deserialize_large_numpy(self, benchmark): + sample = BenchNumpySample( + data=np.random.randn(*TSERIES_SHAPE).astype(TSERIES_DTYPE), + label="large", + ) + packed = sample.packed + benchmark(BenchNumpySample.from_bytes, packed) + + def test_as_wds_basic(self, benchmark): + sample = BenchBasicSample(name="bench_sample", value=42) + benchmark(lambda: sample.as_wds) + + def test_as_wds_numpy(self, benchmark): + sample = BenchNumpySample( + data=np.random.randint(0, 256, size=IMAGE_SHAPE, dtype=IMAGE_DTYPE), + label="bench", + ) + benchmark(lambda: sample.as_wds) + + +# ============================================================================= +# Round-Trip Benchmarks +# ============================================================================= + + +@pytest.mark.bench_io +class TestRoundTripBenchmarks: + """Full write-then-read round-trip benchmarks.""" + + PARAM_LABELS = {"n": "samples round-tripped"} + + @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"]) + def test_roundtrip_basic(self, benchmark, tmp_path, n): + benchmark.extra_info["n_samples"] = n + samples = generate_basic_samples(n) + counter = [0] + + def _roundtrip(): + idx = counter[0] + counter[0] += 1 + tar_path = tmp_path / f"rt-basic-{n}-{idx}.tar" + write_tar(tar_path, samples) + ds = atdata.Dataset[BenchBasicSample](url=str(tar_path)) + count = 0 + for _ in ds.ordered(): + count += 1 + return count + + result = benchmark(_roundtrip) + assert result == n + + @pytest.mark.parametrize("n", [100, 500], ids=["100", "500"]) + def test_roundtrip_numpy(self, benchmark, tmp_path, n): + benchmark.extra_info["n_samples"] = n + samples = generate_numpy_samples(n) + counter = [0] + + def _roundtrip(): + idx = counter[0] + counter[0] += 1 + tar_path = tmp_path / 
f"rt-numpy-{n}-{idx}.tar"
+            write_tar(tar_path, samples)
+            ds = atdata.Dataset[BenchNumpySample](url=str(tar_path))
+            count = 0
+            for _ in ds.ordered():
+                count += 1
+            return count
+
+        result = benchmark(_roundtrip)
+        assert result == n
diff --git a/benchmarks/bench_index_providers.py b/benchmarks/bench_index_providers.py
new file mode 100644
index 0000000..4d0aeea
--- /dev/null
+++ b/benchmarks/bench_index_providers.py
@@ -0,0 +1,215 @@
+"""Performance benchmarks for index provider operations.
+
+Measures read/write latency and throughput for SQLite, Redis, and PostgreSQL
+providers. Each benchmark skips gracefully if the backend is unavailable.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from atdata.local._entry import LocalDatasetEntry
+
+from .conftest import BenchBasicSample, generate_basic_samples
+
+
+# =============================================================================
+# Helpers
+# =============================================================================
+
+
+def _make_entry(i: int) -> LocalDatasetEntry:
+    return LocalDatasetEntry(
+        name=f"bench_dataset_{i:06d}",
+        schema_ref="atdata://local/sampleSchema/BenchBasicSample@1.0.0",
+        data_urls=[f"/tmp/bench/data-{i:06d}.tar"],
+        metadata={"index": i, "split": "train"},
+    )
+
+
+def _make_schema_json(name: str, version: str) -> str:
+    return (
+        f'{{"name": "{name}", "version": "{version}", '
+        f'"fields": [{{"name": "x", "type": "int"}}]}}'
+    )
+
+
+def _prepopulate_entries(provider, n: int) -> list[LocalDatasetEntry]:
+    entries = [_make_entry(i) for i in range(n)]
+    for entry in entries:
+        provider.store_entry(entry)
+    return entries
+
+
+def _prepopulate_schemas(provider, name: str, n: int) -> list[str]:
+    versions = []
+    for i in range(n):
+        version = f"1.0.{i}"
+        provider.store_schema(name, version, _make_schema_json(name, version))
+        versions.append(version)
+    return versions
+
+
+# =============================================================================
+# Write
Benchmarks +# ============================================================================= + + +@pytest.mark.bench_index +class TestProviderWriteBenchmarks: + """Write operation benchmarks across all providers.""" + + PARAM_LABELS = {"n": "entries to store", "any_provider": "storage backend"} + + def test_store_single_entry(self, benchmark, any_provider): + entry = _make_entry(0) + benchmark(any_provider.store_entry, entry) + + @pytest.mark.parametrize("n", [10, 100, 1000], ids=["10", "100", "1k"]) + def test_store_entries_bulk(self, benchmark, any_provider, n): + entries = [_make_entry(i) for i in range(n)] + + def _store_all(): + for entry in entries: + any_provider.store_entry(entry) + + benchmark(_store_all) + + def test_store_schema(self, benchmark, any_provider): + benchmark( + any_provider.store_schema, + "BenchSample", + "1.0.0", + _make_schema_json("BenchSample", "1.0.0"), + ) + + @pytest.mark.parametrize("n", [10, 50], ids=["10v", "50v"]) + def test_store_schema_versions(self, benchmark, any_provider, n): + def _store_versions(): + for i in range(n): + v = f"1.0.{i}" + any_provider.store_schema( + "BenchVersioned", v, _make_schema_json("BenchVersioned", v) + ) + + benchmark(_store_versions) + + +# ============================================================================= +# Read Benchmarks +# ============================================================================= + + +@pytest.mark.bench_index +class TestProviderReadBenchmarks: + """Read operation benchmarks across all providers.""" + + PARAM_LABELS = {"n": "entries in index", "any_provider": "storage backend"} + + def test_get_entry_by_name(self, benchmark, any_provider): + entries = _prepopulate_entries(any_provider, 100) + target = entries[50] + benchmark(any_provider.get_entry_by_name, target.name) + + def test_get_entry_by_cid(self, benchmark, any_provider): + entries = _prepopulate_entries(any_provider, 100) + target = entries[50] + benchmark(any_provider.get_entry_by_cid, target.cid) + + 
@pytest.mark.parametrize("n", [10, 100, 1000], ids=["10", "100", "1k"]) + def test_iter_entries(self, benchmark, any_provider, n): + _prepopulate_entries(any_provider, n) + benchmark(lambda: list(any_provider.iter_entries())) + + def test_get_schema_json(self, benchmark, any_provider): + any_provider.store_schema( + "BenchRead", "1.0.0", _make_schema_json("BenchRead", "1.0.0") + ) + benchmark(any_provider.get_schema_json, "BenchRead", "1.0.0") + + @pytest.mark.parametrize("n", [5, 20, 50], ids=["5v", "20v", "50v"]) + def test_find_latest_version(self, benchmark, any_provider, n): + _prepopulate_schemas(any_provider, "BenchLatest", n) + benchmark(any_provider.find_latest_version, "BenchLatest") + + def test_iter_schemas(self, benchmark, any_provider): + _prepopulate_schemas(any_provider, "BenchIterSchema", 20) + benchmark(lambda: list(any_provider.iter_schemas())) + + +# ============================================================================= +# Index-Level Benchmarks +# ============================================================================= + + +@pytest.mark.bench_index +class TestIndexBenchmarks: + """Benchmarks through the full Index API.""" + + def test_index_insert_dataset(self, benchmark, tmp_path, sqlite_provider): + import atdata + from atdata.local._index import Index + + samples = generate_basic_samples(10) + tar_path = tmp_path / "idx-bench-000000.tar" + from .conftest import write_tar + + write_tar(tar_path, samples) + + index = Index(provider=sqlite_provider, atmosphere=None) + ds = atdata.Dataset[BenchBasicSample](url=str(tar_path)) + + counter = [0] + + def _insert(): + name = f"bench_ds_{counter[0]:06d}" + counter[0] += 1 + index.insert_dataset(ds, name=name) + + benchmark(_insert) + + def test_index_get_dataset(self, benchmark, tmp_path, sqlite_provider): + import atdata + from atdata.local._index import Index + + samples = generate_basic_samples(10) + tar_path = tmp_path / "idx-get-000000.tar" + from .conftest import write_tar + + 
write_tar(tar_path, samples) + + index = Index(provider=sqlite_provider, atmosphere=None) + ds = atdata.Dataset[BenchBasicSample](url=str(tar_path)) + index.insert_dataset(ds, name="bench_lookup_target") + + benchmark(index.get_dataset, "bench_lookup_target") + + def test_index_list_datasets(self, benchmark, tmp_path, sqlite_provider): + import atdata + from atdata.local._index import Index + + samples = generate_basic_samples(5) + tar_path = tmp_path / "idx-list-000000.tar" + from .conftest import write_tar + + write_tar(tar_path, samples) + + index = Index(provider=sqlite_provider, atmosphere=None) + ds = atdata.Dataset[BenchBasicSample](url=str(tar_path)) + for i in range(100): + index.insert_dataset(ds, name=f"bench_list_{i:04d}") + + benchmark(index.list_datasets) + + def test_index_publish_schema(self, benchmark, sqlite_provider): + from atdata.local._index import Index + + index = Index(provider=sqlite_provider, atmosphere=None) + counter = [0] + + def _publish(): + v = f"1.0.{counter[0]}" + counter[0] += 1 + index.publish_schema(BenchBasicSample, version=v) + + benchmark(_publish) diff --git a/benchmarks/bench_query.py b/benchmarks/bench_query.py new file mode 100644 index 0000000..a5eccec --- /dev/null +++ b/benchmarks/bench_query.py @@ -0,0 +1,278 @@ +"""Performance benchmarks for the manifest query system. + +Measures query execution speed across different predicate types, +manifest loading performance, and scaling behavior with increasing +shard counts and sample sizes. 
+""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +from atdata.manifest import ManifestBuilder, ManifestWriter, QueryExecutor, ShardManifest + +from .conftest import ( + BenchManifestSample, + generate_manifest_samples, + create_sharded_dataset, + write_tar_with_manifest, +) + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def query_dataset_small(tmp_path): + """Small query dataset: 2 shards x 50 samples = 100 total.""" + samples = generate_manifest_samples(100) + create_sharded_dataset( + tmp_path, samples, 50, BenchManifestSample, with_manifests=True + ) + executor = QueryExecutor.from_directory(tmp_path) + return executor, tmp_path + + +@pytest.fixture +def query_dataset_medium(tmp_path): + """Medium query dataset: 10 shards x 100 samples = 1000 total.""" + samples = generate_manifest_samples(1000) + create_sharded_dataset( + tmp_path, samples, 100, BenchManifestSample, with_manifests=True + ) + executor = QueryExecutor.from_directory(tmp_path) + return executor, tmp_path + + +@pytest.fixture +def query_dataset_large(tmp_path): + """Large query dataset: 10 shards x 1000 samples = 10000 total.""" + samples = generate_manifest_samples(10000) + create_sharded_dataset( + tmp_path, samples, 1000, BenchManifestSample, with_manifests=True + ) + executor = QueryExecutor.from_directory(tmp_path) + return executor, tmp_path + + +# ============================================================================= +# Query Predicate Benchmarks +# ============================================================================= + + +@pytest.mark.bench_query +class TestQueryPredicateBenchmarks: + """Benchmark different query predicate types on a medium dataset.""" + + def test_query_simple_equality(self, benchmark, query_dataset_medium): + executor, _ = query_dataset_medium + 
benchmark(executor.query, where=lambda df: df["label"] == "dog") + + def test_query_numeric_range(self, benchmark, query_dataset_medium): + executor, _ = query_dataset_medium + benchmark(executor.query, where=lambda df: df["confidence"] > 0.8) + + def test_query_combined(self, benchmark, query_dataset_medium): + executor, _ = query_dataset_medium + benchmark( + executor.query, + where=lambda df: (df["label"] == "dog") & (df["confidence"] > 0.8), + ) + + def test_query_isin(self, benchmark, query_dataset_medium): + executor, _ = query_dataset_medium + benchmark( + executor.query, + where=lambda df: df["label"].isin(["dog", "cat"]), + ) + + def test_query_no_results(self, benchmark, query_dataset_medium): + executor, _ = query_dataset_medium + benchmark( + executor.query, + where=lambda df: df["confidence"] > 999.0, + ) + + def test_query_all_results(self, benchmark, query_dataset_medium): + executor, _ = query_dataset_medium + benchmark( + executor.query, + where=lambda df: df["confidence"] >= 0.0, + ) + + +# ============================================================================= +# Query Result Iteration Benchmarks +# ============================================================================= + + +@pytest.mark.bench_query +class TestQueryIterationBenchmarks: + """Benchmark iterating through query results to access sample locations.""" + + def test_iterate_equality_results(self, benchmark, query_dataset_medium): + executor, _ = query_dataset_medium + # Pre-run to capture result count for per-result normalization + n = len(executor.query(where=lambda df: df["label"] == "dog")) + benchmark.extra_info["n_samples"] = n + + def _query_and_iterate(): + results = executor.query(where=lambda df: df["label"] == "dog") + keys = [loc.key for loc in results] + return len(keys) + + count = benchmark(_query_and_iterate) + assert count == n + + def test_iterate_range_results(self, benchmark, query_dataset_medium): + executor, _ = query_dataset_medium + n = 
len(executor.query(where=lambda df: df["confidence"] > 0.5)) + benchmark.extra_info["n_samples"] = n + + def _query_and_iterate(): + results = executor.query(where=lambda df: df["confidence"] > 0.5) + by_shard: dict[str, list[int]] = {} + for loc in results: + by_shard.setdefault(loc.shard, []).append(loc.offset) + return sum(len(v) for v in by_shard.values()) + + count = benchmark(_query_and_iterate) + assert count == n + + def test_iterate_large_result_set(self, benchmark, query_dataset_large): + executor, _ = query_dataset_large + benchmark.extra_info["n_samples"] = 10000 + + def _query_and_iterate(): + results = executor.query(where=lambda df: df["confidence"] >= 0.0) + keys = [loc.key for loc in results] + return len(keys) + + count = benchmark(_query_and_iterate) + assert count == 10000 + + +# ============================================================================= +# Scale Benchmarks +# ============================================================================= + + +@pytest.mark.bench_query +class TestQueryScaleBenchmarks: + """Benchmark query performance at different scales.""" + + def test_query_small(self, benchmark, query_dataset_small): + executor, _ = query_dataset_small + benchmark(executor.query, where=lambda df: df["confidence"] > 0.5) + + def test_query_medium(self, benchmark, query_dataset_medium): + executor, _ = query_dataset_medium + benchmark(executor.query, where=lambda df: df["confidence"] > 0.5) + + def test_query_large(self, benchmark, query_dataset_large): + executor, _ = query_dataset_large + benchmark(executor.query, where=lambda df: df["confidence"] > 0.5) + + +# ============================================================================= +# Manifest Loading Benchmarks +# ============================================================================= + + +@pytest.mark.bench_query +class TestManifestLoadBenchmarks: + """Benchmark manifest loading from disk.""" + + PARAM_LABELS = {"n_shards": "number of shards (100 samples each)"} 
+ + @pytest.mark.parametrize("n_shards", [2, 5, 10, 20], ids=["2s", "5s", "10s", "20s"]) + def test_load_from_directory(self, benchmark, tmp_path, n_shards): + total = n_shards * 100 + samples = generate_manifest_samples(total) + create_sharded_dataset( + tmp_path, samples, 100, BenchManifestSample, with_manifests=True + ) + benchmark(QueryExecutor.from_directory, tmp_path) + + @pytest.mark.parametrize("n_shards", [2, 5, 10], ids=["2s", "5s", "10s"]) + def test_load_from_shard_urls(self, benchmark, tmp_path, n_shards): + total = n_shards * 100 + samples = generate_manifest_samples(total) + tar_paths = create_sharded_dataset( + tmp_path, samples, 100, BenchManifestSample, with_manifests=True + ) + shard_urls = [str(p) for p in tar_paths] + benchmark(QueryExecutor.from_shard_urls, shard_urls) + + +# ============================================================================= +# Manifest Build Benchmarks +# ============================================================================= + + +@pytest.mark.bench_query +class TestManifestBuildBenchmarks: + """Benchmark manifest construction from samples.""" + + PARAM_LABELS = {"n": "samples in manifest"} + + @pytest.mark.parametrize("n", [100, 1000, 5000], ids=["100", "1k", "5k"]) + def test_manifest_build(self, benchmark, n): + benchmark.extra_info["n_samples"] = n + samples = generate_manifest_samples(n) + + def _build(): + builder = ManifestBuilder( + sample_type=BenchManifestSample, + shard_id="bench-shard-000000", + ) + offset = 0 + for i, sample in enumerate(samples): + packed_size = len(sample.packed) + builder.add_sample( + key=f"sample_{i:06d}", + offset=offset, + size=packed_size, + sample=sample, + ) + offset += 512 + packed_size + return builder.build() + + manifest = benchmark(_build) + assert manifest.num_samples == n + + @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"]) + def test_manifest_write(self, benchmark, tmp_path, n): + benchmark.extra_info["n_samples"] = n + samples = 
generate_manifest_samples(n) + + builder = ManifestBuilder( + sample_type=BenchManifestSample, + shard_id="bench-write-000000", + ) + offset = 0 + for i, sample in enumerate(samples): + packed_size = len(sample.packed) + builder.add_sample( + key=f"sample_{i:06d}", + offset=offset, + size=packed_size, + sample=sample, + ) + offset += 512 + packed_size + manifest = builder.build() + + counter = [0] + + def _write(): + idx = counter[0] + counter[0] += 1 + writer = ManifestWriter(tmp_path / f"mw-{idx:06d}") + writer.write(manifest) + + benchmark(_write) diff --git a/benchmarks/conftest.py b/benchmarks/conftest.py new file mode 100644 index 0000000..2c7f84a --- /dev/null +++ b/benchmarks/conftest.py @@ -0,0 +1,345 @@ +"""Shared fixtures and helpers for performance benchmarks.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Annotated, Any + +import numpy as np +import pytest +import webdataset as wds +from numpy.typing import NDArray, DTypeLike + +import atdata +from atdata.manifest import ManifestBuilder, ManifestField, ManifestWriter + + +# ============================================================================= +# Benchmark Sample Types +# ============================================================================= + + +@atdata.packable +class BenchBasicSample: + """Lightweight sample for throughput benchmarks.""" + + name: str + value: int + + +@atdata.packable +class BenchNumpySample: + """Sample with NDArray for array serialization benchmarks.""" + + data: NDArray + label: str + + +@atdata.packable +class BenchManifestSample: + """Sample with manifest-annotated fields for query benchmarks.""" + + data: NDArray + label: Annotated[str, ManifestField("categorical")] + confidence: Annotated[float, ManifestField("numeric")] + tags: Annotated[list[str], ManifestField("set")] + + +# ============================================================================= +# Benchmark Constants +# 
============================================================================= + +# Standard image: 3-channel 224x224 uint8 (ImageNet-style) +IMAGE_SHAPE = (3, 224, 224) +IMAGE_DTYPE = np.uint8 + +# Large biological timeseries: 1024x1024 spatial x 60 frames, float32 +TSERIES_SHAPE = (1024, 1024, 60) +TSERIES_DTYPE = np.float32 + +# Small array for manifest/overhead benchmarks (keeps manifests fast) +MANIFEST_ARRAY_SHAPE = (4, 4) +MANIFEST_ARRAY_DTYPE = np.float32 + + +# ============================================================================= +# Sample Generators +# ============================================================================= + +LABELS = ["dog", "cat", "bird", "fish", "horse"] +TAG_POOLS = [ + ["outdoor", "day"], + ["indoor"], + ["outdoor", "night"], + ["underwater"], + ["field", "day"], +] + + +def generate_basic_samples(n: int) -> list[BenchBasicSample]: + return [BenchBasicSample(name=f"sample_{i:06d}", value=i) for i in range(n)] + + +def generate_numpy_samples( + n: int, + shape: tuple[int, ...] = IMAGE_SHAPE, + dtype: DTypeLike = IMAGE_DTYPE, +) -> list[BenchNumpySample]: + return [ + BenchNumpySample( + data=np.random.randint(0, 256, size=shape, dtype=dtype) + if np.issubdtype(dtype, np.integer) + else np.random.randn(*shape).astype(dtype), + label=f"array_{i:06d}", + ) + for i in range(n) + ] + + +def generate_manifest_samples( + n: int, shape: tuple[int, ...] 
= (4, 4) +) -> list[BenchManifestSample]: + return [ + BenchManifestSample( + data=np.random.randn(*shape).astype(np.float32), + label=LABELS[i % len(LABELS)], + confidence=0.1 + 0.9 * (i % 100) / 100.0, + tags=TAG_POOLS[i % len(TAG_POOLS)], + ) + for i in range(n) + ] + + +# ============================================================================= +# Tar/Dataset Helpers +# ============================================================================= + + +def write_tar(tar_path: Path, samples: list) -> Path: + """Write samples to a WebDataset tar file.""" + tar_path.parent.mkdir(parents=True, exist_ok=True) + with wds.writer.TarWriter(str(tar_path)) as writer: + for sample in samples: + writer.write(sample.as_wds) + return tar_path + + +def write_tar_with_manifest( + tar_path: Path, + samples: list, + sample_type: type, +) -> tuple[Path, Path, Path]: + """Write samples to a tar file and generate companion manifest files. + + Returns: + Tuple of (tar_path, json_path, parquet_path). + """ + tar_path.parent.mkdir(parents=True, exist_ok=True) + shard_name = tar_path.stem + shard_id = str(tar_path.parent / shard_name) + + builder = ManifestBuilder(sample_type=sample_type, shard_id=shard_id) + + offset = 0 + with wds.writer.TarWriter(str(tar_path)) as writer: + for sample in samples: + wds_dict = sample.as_wds + writer.write(wds_dict) + packed_size = len(wds_dict.get("msgpack", b"")) + builder.add_sample( + key=wds_dict["__key__"], + offset=offset, + size=packed_size, + sample=sample, + ) + offset += 512 + packed_size + (512 - packed_size % 512) % 512 + + manifest = builder.build() + manifest_writer = ManifestWriter(tar_path.parent / shard_name) + json_path, parquet_path = manifest_writer.write(manifest) + + return tar_path, json_path, parquet_path + + +def create_sharded_dataset( + base_dir: Path, + samples: list, + samples_per_shard: int, + sample_type: type, + with_manifests: bool = False, +) -> list[Path]: + """Split samples across multiple shards. 
Returns list of tar paths.""" + tar_paths: list[Path] = [] + for shard_idx in range(0, len(samples), samples_per_shard): + chunk = samples[shard_idx : shard_idx + samples_per_shard] + shard_name = f"data-{shard_idx // samples_per_shard:06d}" + tar_path = base_dir / f"{shard_name}.tar" + + if with_manifests: + write_tar_with_manifest(tar_path, chunk, sample_type) + else: + write_tar(tar_path, chunk) + + tar_paths.append(tar_path) + + return tar_paths + + +# ============================================================================= +# Provider Fixtures +# ============================================================================= + + +@pytest.fixture +def sqlite_provider(tmp_path): + """Fresh SQLite provider in a temp directory.""" + from atdata.providers._sqlite import SqliteProvider + + provider = SqliteProvider(path=tmp_path / "bench.db") + yield provider + provider.close() + + +@pytest.fixture +def redis_provider(): + """Real Redis provider, skip if unavailable.""" + from redis import Redis + + try: + conn = Redis() + conn.ping() + except Exception: + pytest.skip("Redis server not available") + + from atdata.providers._redis import RedisProvider + + provider = RedisProvider(conn) + + # Clean up benchmark keys before/after + def _cleanup(): + for pattern in ("LocalDatasetEntry:bench_*", "LocalSchema:Bench*"): + for key in conn.scan_iter(match=pattern): + conn.delete(key) + + _cleanup() + yield provider + _cleanup() + provider.close() + + +@pytest.fixture +def postgres_provider(): + """Real PostgreSQL provider, skip if unavailable.""" + try: + import psycopg # noqa: F401 + except ImportError: + pytest.skip("psycopg not installed") + + import os + + dsn = os.environ.get("ATDATA_BENCH_POSTGRES_DSN") + if not dsn: + pytest.skip("ATDATA_BENCH_POSTGRES_DSN not set") + + from atdata.providers._postgres import PostgresProvider + + provider = PostgresProvider(dsn=dsn) + yield provider + provider.close() + + +@pytest.fixture( + params=["sqlite", "redis", "postgres"], + 
ids=["sqlite", "redis", "postgres"], +) +def any_provider(request, tmp_path): + """Parametrized fixture that yields each available provider.""" + backend = request.param + + if backend == "sqlite": + from atdata.providers._sqlite import SqliteProvider + + provider = SqliteProvider(path=tmp_path / "bench.db") + yield provider + provider.close() + + elif backend == "redis": + from redis import Redis + + try: + conn = Redis() + conn.ping() + except Exception: + pytest.skip("Redis server not available") + + from atdata.providers._redis import RedisProvider + + provider = RedisProvider(conn) + + def _cleanup(): + for pattern in ("LocalDatasetEntry:bench_*", "LocalSchema:Bench*"): + for key in conn.scan_iter(match=pattern): + conn.delete(key) + + _cleanup() + yield provider + _cleanup() + provider.close() + + elif backend == "postgres": + try: + import psycopg # noqa: F401 + except ImportError: + pytest.skip("psycopg not installed") + + import os + + dsn = os.environ.get("ATDATA_BENCH_POSTGRES_DSN") + if not dsn: + pytest.skip("ATDATA_BENCH_POSTGRES_DSN not set") + + from atdata.providers._postgres import PostgresProvider + + provider = PostgresProvider(dsn=dsn) + yield provider + provider.close() + + +# ============================================================================= +# Dataset Fixtures +# ============================================================================= + + +@pytest.fixture +def small_basic_dataset(tmp_path): + """100-sample basic dataset (1 shard).""" + samples = generate_basic_samples(100) + tar_path = write_tar(tmp_path / "small-000000.tar", samples) + return atdata.Dataset[BenchBasicSample](url=str(tar_path)), samples + + +@pytest.fixture +def medium_basic_dataset(tmp_path): + """1000-sample basic dataset (1 shard).""" + samples = generate_basic_samples(1000) + tar_path = write_tar(tmp_path / "medium-000000.tar", samples) + return atdata.Dataset[BenchBasicSample](url=str(tar_path)), samples + + +@pytest.fixture +def 
small_numpy_dataset(tmp_path): + """100-sample numpy dataset (1 shard, 3x224x224 uint8 arrays).""" + samples = generate_numpy_samples(100) + tar_path = write_tar(tmp_path / "numpy-000000.tar", samples) + return atdata.Dataset[BenchNumpySample](url=str(tar_path)), samples + + +@pytest.fixture +def manifest_dataset_small(tmp_path): + """100-sample manifest dataset with 2 shards.""" + samples = generate_manifest_samples(100) + create_sharded_dataset( + tmp_path, samples, 50, BenchManifestSample, with_manifests=True + ) + return tmp_path, samples diff --git a/benchmarks/render_report.py b/benchmarks/render_report.py new file mode 100644 index 0000000..9cc1cd4 --- /dev/null +++ b/benchmarks/render_report.py @@ -0,0 +1,462 @@ +"""Render pytest-benchmark JSON results into a standalone HTML report. + +Usage: + uv run python -m benchmarks.render_report results/*.json -o bench-report.html + +Reads one or more pytest-benchmark JSON files (one per group) and produces +a single HTML page with grouped tables and test descriptions extracted from +the benchmark source files. +""" + +from __future__ import annotations + +import argparse +import ast +import json +import sys +from dataclasses import dataclass, field +from pathlib import Path + +import jinja2 + + +# ============================================================================= +# Docstring extraction +# ============================================================================= + + +@dataclass +class _SourceMeta: + """Metadata extracted from benchmark source files via AST.""" + + docstrings: dict[str, str] = field(default_factory=dict) + param_labels: dict[str, dict[str, str]] = field(default_factory=dict) + + +def _extract_source_meta(bench_dir: Path) -> _SourceMeta: + """Walk benchmark .py files and extract class/method docstrings and PARAM_LABELS.
+ + Returns a ``_SourceMeta`` with: + - ``docstrings``: mapping qualified names like + ``"TestSerializationBenchmarks"`` or + ``"TestSerializationBenchmarks.test_serialize_basic_sample"`` + to their docstring text. + - ``param_labels``: mapping class names to their ``PARAM_LABELS`` dict + (e.g. ``{"n": "samples per shard"}``). + """ + meta = _SourceMeta() + for py_file in sorted(bench_dir.glob("bench_*.py")): + tree = ast.parse(py_file.read_text()) + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + cls_doc = ast.get_docstring(node) + if cls_doc: + meta.docstrings[node.name] = cls_doc + for item in node.body: + if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + fn_doc = ast.get_docstring(item) + if fn_doc: + meta.docstrings[f"{node.name}.{item.name}"] = fn_doc + # Extract PARAM_LABELS = {...} class variable + if isinstance(item, ast.Assign): + for target in item.targets: + if ( + isinstance(target, ast.Name) + and target.id == "PARAM_LABELS" + ): + try: + value = ast.literal_eval(item.value) + if isinstance(value, dict): + meta.param_labels[node.name] = value + except (ValueError, TypeError): + pass + return meta + + +# ============================================================================= +# Data model +# ============================================================================= + +# Friendly names for each marker group +GROUP_TITLES: dict[str, str] = { + "bench_serial": "Serialization (μs scale)", + "bench_index": "Index Providers (μs–ms scale)", + "bench_io": "Dataset I/O (ms scale)", + "bench_query": "Query System (ms scale)", + "bench_s3": "S3 Storage (ms+ scale)", +} + + +@dataclass +class BenchRow: + name: str + fullname: str + description: str + raw_params: str + params_desc: str + median: float + iqr: float + ops: float + n_samples: int | None = None + + +@dataclass +class BenchGroup: + key: str + title: str + class_description: str + param_labels: dict[str, str] = field(default_factory=dict) + has_samples: bool = 
False + rows: list[BenchRow] = field(default_factory=list) + + +def _format_time(seconds: float) -> str: + """Format seconds into a human-readable string with appropriate unit.""" + if seconds < 1e-6: + return f"{seconds * 1e9:.1f} ns" + if seconds < 1e-3: + return f"{seconds * 1e6:.2f} μs" + if seconds < 1.0: + return f"{seconds * 1e3:.2f} ms" + return f"{seconds:.3f} s" + + +def _format_ops(ops: float) -> str: + """Format operations per second with SI prefix.""" + if ops >= 1e6: + return f"{ops / 1e6:.2f} Mops/s" + if ops >= 1e3: + return f"{ops / 1e3:.2f} Kops/s" + return f"{ops:.1f} ops/s" + + +def _class_from_fullname(fullname: str) -> str: + """Extract class name from a pytest fullname like + ``benchmarks/bench_dataset_io.py::TestSerializationBenchmarks::test_foo``. + """ + parts = fullname.split("::") + if len(parts) >= 2: + return parts[1] + return "" + + +def _method_from_fullname(fullname: str) -> str: + """Extract method name from a pytest fullname.""" + parts = fullname.split("::") + if len(parts) >= 3: + return parts[2] + return parts[-1] + + +# ============================================================================= +# Build groups from JSON +# ============================================================================= + + +def _format_params( + params_dict: dict | None, + param_labels: dict[str, str], +) -> str: + """Format a benchmark's params dict using human-readable labels. + + Given ``{"n": 1000}`` and labels ``{"n": "samples per shard"}``, + returns ``"1000 samples per shard"``. + """ + if not params_dict: + return "" + parts: list[str] = [] + for key, value in params_dict.items(): + label = param_labels.get(key) + if label: + parts.append(f"{value} {label}") + else: + parts.append(f"{key}={value}") + return ", ".join(parts) + + +def _build_groups( + json_paths: list[Path], + meta: _SourceMeta, +) -> list[BenchGroup]: + """Load JSON files and assemble BenchGroup objects.""" + # Each JSON file corresponds to one marker group. 
We infer the group key + # from the filename produced by the justfile (e.g. ``serial.json``). + groups: dict[str, BenchGroup] = {} + + for path in sorted(json_paths): + data = json.loads(path.read_text()) + benchmarks = data.get("benchmarks", []) + if not benchmarks: + continue + + group_key = path.stem # e.g. "serial", "index", "io", "query", "s3" + marker_key = f"bench_{group_key}" + title = GROUP_TITLES.get(marker_key, group_key.title()) + + # Collect unique class descriptions and param labels for the group + class_names_seen: set[str] = set() + class_descs: list[str] = [] + merged_param_labels: dict[str, str] = {} + rows: list[BenchRow] = [] + + for bench in benchmarks: + fullname = bench["fullname"] + stats = bench["stats"] + cls_name = _class_from_fullname(fullname) + method_name = _method_from_fullname(fullname) + + if cls_name and cls_name not in class_names_seen: + class_names_seen.add(cls_name) + cls_doc = meta.docstrings.get(cls_name, "") + if cls_doc: + class_descs.append(cls_doc) + # Merge param labels from this class + cls_labels = meta.param_labels.get(cls_name, {}) + for k, v in cls_labels.items(): + if k not in merged_param_labels: + merged_param_labels[k] = v + + # Build description: prefer method docstring, fall back to + # readable version of test name (strip bracket suffix) + base_method = method_name.split("[")[0] + qualified = f"{cls_name}.{base_method}" if cls_name else base_method + desc = meta.docstrings.get(qualified, "") + if not desc: + # Convert test_serialize_basic_sample -> Serialize basic sample + readable = base_method.removeprefix("test_").replace("_", " ").capitalize() + desc = readable + + # Raw param ID for the top line (e.g. 
"sqlite-5v") + raw_param = bench.get("param") or "" + # Human-readable param description for the subtitle + params_dict = bench.get("params") + params_desc = _format_params(params_dict, merged_param_labels) + + extra = bench.get("extra_info", {}) + n_samples = extra.get("n_samples") + if n_samples is not None: + n_samples = int(n_samples) + + rows.append( + BenchRow( + name=bench["name"], + fullname=fullname, + description=desc, + raw_params=str(raw_param) if raw_param else "", + params_desc=params_desc, + median=stats["median"], + iqr=stats["iqr"], + ops=stats["ops"], + n_samples=n_samples, + ) + ) + + groups[group_key] = BenchGroup( + key=group_key, + title=title, + class_description="; ".join(class_descs) if class_descs else "", + param_labels=merged_param_labels, + has_samples=any(r.n_samples is not None for r in rows), + rows=rows, + ) + + return list(groups.values()) + + +# ============================================================================= +# HTML template +# ============================================================================= + +TEMPLATE = jinja2.Template( + """\ + + + + + +atdata benchmark report + + + +

atdata benchmark report

+

Generated from pytest-benchmark JSON output

+ +{% if machine %} +
+ {{ machine.node }} — + {{ machine.cpu.brand_raw }} ({{ machine.cpu.count }} cores) · + Python {{ machine.python_version }} · + {{ machine.system }} {{ machine.release }} + {% if commit %} + · {{ commit.branch }}@{{ commit.id[:8] }}{% if commit.dirty %} (dirty){% endif %} + {% endif %} +
+{% endif %} + +{% for group in groups %} +
+

{{ group.title }}

+ {% if group.class_description %} +

{{ group.class_description }}

+ {% endif %} + {% if group.param_labels %} +

Parameters: {% for key, label in group.param_labels.items() %}{{ key }} = {{ label }}{% if not loop.last %}, {% endif %}{% endfor %}

+ {% endif %} + + + + + + + + {% if group.has_samples %} + {% endif %} + + + + {% for row in group.rows %} + + + + + + {% if group.has_samples %} + {% endif %} + + {% endfor %} + +
TestMedianIQROPSMed/sampleSamples/s
+ {% if "[" in row.name %}{{ row.name.split("[")[0] }}[{{ row.name.split("[")[1] }}{% else %}{{ row.name }}{% endif %} +
{{ row.description }}{% if row.params_desc %} [{{ row.params_desc }}]{% endif %} +
{{ fmt_time(row.median) }}{{ fmt_time(row.iqr) }}{{ fmt_ops(row.ops) }}{% if row.n_samples %}{{ fmt_time(row.median / row.n_samples) }}{% else %}—{% endif %}{% if row.n_samples %}{{ fmt_ops(row.n_samples / row.median) }}{% else %}—{% endif %}
+
+{% endfor %} + + + + +""", + undefined=jinja2.StrictUndefined, +) + + +# ============================================================================= +# Main +# ============================================================================= + + +def render_html(json_paths: list[Path], bench_dir: Path) -> str: + """Render benchmark JSON files into an HTML string.""" + meta = _extract_source_meta(bench_dir) + groups = _build_groups(json_paths, meta) + + # Extract machine/commit info from the first JSON file + machine = None + commit = None + for path in json_paths: + data = json.loads(path.read_text()) + if "machine_info" in data: + machine = data["machine_info"] + commit = data.get("commit_info") + break + + total = sum(len(g.rows) for g in groups) + + return TEMPLATE.render( + groups=groups, + machine=machine, + commit=commit, + total_benchmarks=total, + fmt_time=_format_time, + fmt_ops=_format_ops, + ) + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Render pytest-benchmark JSON into HTML report", + ) + parser.add_argument( + "json_files", + nargs="+", + type=Path, + help="One or more pytest-benchmark JSON files", + ) + parser.add_argument( + "-o", + "--output", + type=Path, + default=Path("bench-report.html"), + help="Output HTML file (default: bench-report.html)", + ) + parser.add_argument( + "--bench-dir", + type=Path, + default=Path(__file__).parent, + help="Directory containing bench_*.py files for docstring extraction", + ) + args = parser.parse_args(argv) + + existing = [p for p in args.json_files if p.exists()] + if not existing: + print(f"Error: no JSON files found: {args.json_files}", file=sys.stderr) + sys.exit(1) + + html = render_html(existing, args.bench_dir) + args.output.write_text(html) + print(f"Wrote {args.output} ({len(html):,} bytes, from {len(existing)} JSON files)") + + +if __name__ == "__main__": + main() diff --git a/docs/api/AbstractDataStore.html b/docs/api/AbstractDataStore.html 
index b4b9393..4ceb07c 100644 --- a/docs/api/AbstractDataStore.html +++ b/docs/api/AbstractDataStore.html @@ -71,14 +71,10 @@ - - - + - - - + + + - +
@@ -290,10 +145,6 @@
-
@@ -411,7 +261,7 @@

On this page

- +
@@ -419,7 +269,6 @@

On this page

-

AbstractDataStore

AbstractDataStore()
@@ -602,19 +451,6 @@

- - - + - - - + + + - +
@@ -290,10 +145,6 @@
-
@@ -417,7 +267,7 @@

On this page

- +
@@ -425,7 +275,6 @@

On this page

-

AbstractIndex

AbstractIndex()
@@ -925,19 +774,6 @@

- - - + - - - + + + - +
@@ -290,10 +145,6 @@
-
@@ -410,7 +260,7 @@

On this page

- +
@@ -418,7 +268,6 @@

On this page

-

AtUri

atmosphere.AtUri(authority, collection, rkey)
@@ -547,19 +396,6 @@

Rais

- - - + - - - + + + - +
@@ -290,10 +145,6 @@
-
@@ -424,7 +274,7 @@

On this page

- +
@@ -432,7 +282,6 @@

On this page

-

AtmosphereClient

atmosphere.AtmosphereClient(base_url=None, *, _client=None)
@@ -1460,19 +1309,6 @@

Ra

- - - + - - - + + + - +
@@ -290,10 +145,6 @@
-
@@ -416,7 +266,7 @@

On this page

- +
@@ -424,7 +274,6 @@

On this page

-

AtmosphereIndex

atmosphere.AtmosphereIndex(client, *, data_store=None)
@@ -930,19 +779,6 @@

- - - + - - - + + + - +
@@ -290,10 +145,6 @@
-
@@ -405,7 +255,7 @@

On this page

  • Attributes
  • - +
    @@ -413,7 +263,6 @@

    On this page

    -

    AtmosphereIndexEntry

    atmosphere.AtmosphereIndexEntry(uri, record)
    @@ -449,19 +298,6 @@

    window.document.addEventListener("DOMContentLoaded", function (event) { - // Ensure there is a toggle, if there isn't float one in the top right - if (window.document.querySelector('.quarto-color-scheme-toggle') === null) { - const a = window.document.createElement('a'); - a.classList.add('top-right'); - a.classList.add('quarto-color-scheme-toggle'); - a.href = ""; - a.onclick = function() { try { window.quartoToggleColorScheme(); } catch {} return false; }; - const i = window.document.createElement("i"); - i.classList.add('bi'); - a.appendChild(i); - window.document.body.appendChild(a); - } - setColorSchemeToggle(hasAlternateSentinel()) const icon = ""; const anchorJS = new window.AnchorJS(); anchorJS.options = { @@ -532,7 +368,7 @@

    { return filterRegex.test(href) || localhostRegex.test(href) || mailtoRegex.test(href); } @@ -867,7 +703,7 @@

      - + diff --git a/docs/api/BlobSource.html b/docs/api/BlobSource.html index 5b15b69..6797e32 100644 --- a/docs/api/BlobSource.html +++ b/docs/api/BlobSource.html @@ -71,14 +71,10 @@ - - - + - - - + + + - +
    @@ -290,10 +145,6 @@
    -
    @@ -412,7 +262,7 @@

    On this page

    - +
    @@ -420,7 +270,6 @@

    On this page

    -

    BlobSource

    BlobSource(blob_refs, pds_endpoint=None, _endpoint_cache=dict())

    DataSource

    DataSource()

    Dataset

    Dataset(source=None, metadata_url=None, *, url=None)

    Methods

ordered: Iterate over the dataset in order.
shuffled: Iterate over the dataset in random order.

    ordered

    Dataset.ordered(batch_size=None)

    Iterate over the dataset in order.

    Parameters

batch_size : int | None
    The size of iterated batches. Default: None (unbatched). If None, iterates over one sample at a time with no batch dimension.
Returns

Iterable[ST] | Iterable[SampleBatch[ST]]
    A data pipeline that iterates over the dataset in its original sample order. When batch_size is None, yields individual samples of type ST. When batch_size is an integer, yields SampleBatch[ST] instances containing that many samples.

Examples

>>> for sample in ds.ordered():
...     process(sample)  # sample is ST
>>> for batch in ds.ordered(batch_size=32):
...     process(batch)  # batch is SampleBatch[ST]

    shuffled

Dataset.shuffled(buffer_shards=100, buffer_samples=10000, batch_size=None)

    Iterate over the dataset in random order.

    Parameters

Returns

Iterable[ST] | Iterable[SampleBatch[ST]]
    A data pipeline that iterates over the dataset in randomized order. When batch_size is None, yields individual samples of type ST. When batch_size is an integer, yields SampleBatch[ST] instances containing that many samples.

Examples

>>> for sample in ds.shuffled():
...     process(sample)  # sample is ST
>>> for batch in ds.shuffled(batch_size=32):
...     process(batch)  # batch is SampleBatch[ST]

    to_parquet

Dataset.to_parquet(path, sample_map=None, maxcount=None, **kwargs)

    Export dataset contents to parquet format.

    Converts all samples to a pandas DataFrame and saves to parquet file(s). Useful for interoperability with data analysis tools.

ds.to_parquet("output.parquet", maxcount=10000)

    This creates multiple parquet files: output-000000.parquet, output-000001.parquet, etc.
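The naming rule can be made concrete with a small sketch; `chunk_names` is a hypothetical helper used only to illustrate the pattern, not part of atdata:

```python
def chunk_names(base: str, n_samples: int, maxcount: int) -> list[str]:
    # Derive the chunked parquet file names implied by maxcount:
    # one file per maxcount samples, numbered with a six-digit suffix.
    stem, _, ext = base.rpartition(".")
    n_files = -(-n_samples // maxcount)  # ceiling division
    return [f"{stem}-{i:06d}.{ext}" for i in range(n_files)]

names = chunk_names("output.parquet", n_samples=25000, maxcount=10000)
assert names == ["output-000000.parquet", "output-000001.parquet", "output-000002.parquet"]
```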

Examples

>>> ds = Dataset[MySample]("data.tar")
>>> # Small dataset - load all at once
>>> ds.to_parquet("output.parquet")
>>>
>>> # Large dataset - process in chunks
>>> ds.to_parquet("output.parquet", maxcount=50000)

    wrap

Dataset.wrap(sample)

    Wrap a raw msgpack sample into the appropriate dataset-specific type.

    Parameters


    wrap_batch

Dataset.wrap_batch(batch)

    Wrap a batch of raw msgpack samples into a typed SampleBatch.

    Parameters


    DatasetDict

    DatasetDict(splits=None, sample_type=None, streaming=False)

    DatasetLoader

    atmosphere.DatasetLoader(client)

diff --git a/docs/api/DatasetPublisher.html b/docs/api/DatasetPublisher.html
index 2e69763..da5a1aa 100644
--- a/docs/api/DatasetPublisher.html
+++ b/docs/api/DatasetPublisher.html

    DatasetPublisher

    atmosphere.DatasetPublisher(client)

    DictSample

    DictSample(_data=None, **kwargs)

    IndexEntry

    IndexEntry()

    lens

    lens


diff --git a/docs/api/LensLoader.html b/docs/api/LensLoader.html
index b3a3073..adf9921 100644
--- a/docs/api/LensLoader.html
+++ b/docs/api/LensLoader.html

    LensLoader

    atmosphere.LensLoader(client)

    LensPublisher

    atmosphere.LensPublisher(client)

    PDSBlobStore

    atmosphere.PDSBlobStore(client)

    Packable

    Packable()

    PackableSample

    PackableSample()

    S3Source

    S3Source(

    SampleBatch

    SampleBatch(samples)

    SchemaLoader

    atmosphere.SchemaLoader(client)

    SchemaPublisher

    atmosphere.SchemaPublisher(client)

    URLSource

    URLSource(url)

    API Reference


    load_dataset

    load_dataset(

    local.Index

    local.Index(

    local.LocalDatasetEntry

    local.LocalDatasetEntry(

diff --git a/docs/api/local.S3DataStore.html b/docs/api/local.S3DataStore.html
index 95026d1..4f89741 100644
--- a/docs/api/local.S3DataStore.html
+++ b/docs/api/local.S3DataStore.html

    local.S3DataStore

    local.S3DataStore(credentials, *, bucket)

    packable

    packable(cls)

    Decorator to convert a regular class into a PackableSample.

    This decorator transforms a class into a dataclass that inherits from PackableSample, enabling automatic msgpack serialization/deserialization with special handling for NDArray fields.

    The resulting class satisfies the Packable protocol, making it compatible with all atdata APIs that accept packable types (e.g., publish_schema, lens transformations, etc.).

Type Checking

The return type is annotated as type[PackableSample] so that IDEs and type checkers recognize the PackableSample methods (packed, as_wds, from_bytes, etc.). The @dataclass_transform() decorator ensures that field access from the original class is also preserved for type checking.

    Parameters


    promote_to_atmosphere

    promote.promote_to_atmosphere(

atdata

A loose federation of distributed, typed datasets built on WebDataset.

    The Challenge

    Machine learning datasets are everywhere—training data, validation sets, embeddings, features, model outputs. Yet working with them often means:


    Quick Example

    1. Define a Sample Type

    The @packable decorator creates a serializable dataclass:

import numpy as np
from numpy.typing import NDArray
import atdata

    2. Create and Write Samples

    Use WebDataset’s standard TarWriter:

import webdataset as wds

samples = [

    3. Load and Iterate with Type Safety

    The generic Dataset[T] provides typed access:

dataset = atdata.Dataset[ImageSample]("data-000000.tar")

for batch in dataset.shuffled(batch_size=32):

    Scaling Up

    Team Storage with Redis + S3

    When you’re ready to share with your team:

from atdata.local import LocalIndex, S3DataStore

# Connect to team infrastructure

    Federation with ATProto

    For public or cross-organization sharing:

from atdata.atmosphere import AtmosphereClient, AtmosphereIndex, PDSBlobStore
from atdata.promote import promote_to_atmosphere

    HuggingFace-Style Loading

    For convenient access to datasets:

from atdata import load_dataset

# Load from local files

    Next Steps


    Architecture Overview


    Core Components

    PackableSample: The Foundation

    Everything in atdata starts with PackableSample—a base class that makes Python dataclasses serializable with msgpack:

@atdata.packable
class ImageSample:
    image: NDArray       # Automatically converted to/from bytes
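The array-to-bytes conversion this relies on can be illustrated with plain numpy; the wire layout below (raw buffer plus dtype and shape metadata) is an assumption for illustration, not atdata's actual msgpack format:

```python
import numpy as np

# Round-trip an array through bytes: the raw buffer alone is not enough,
# so dtype and shape must travel alongside it.
arr = np.arange(12, dtype=np.float32).reshape(3, 4)
packed = {"data": arr.tobytes(), "dtype": str(arr.dtype), "shape": arr.shape}

restored = np.frombuffer(packed["data"], dtype=packed["dtype"]).reshape(packed["shape"])
assert (restored == arr).all()
```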

    Dataset: Typed Iteration

    The Dataset[T] class wraps WebDataset tar archives with type information:

dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")

for batch in dataset.shuffled(batch_size=32):

    SampleBatch: Automatic Aggregation

    When iterating with batch_size, atdata returns SampleBatch[T] objects that aggregate sample attributes:

batch = SampleBatch[ImageSample](samples)

# NDArray fields → stacked numpy array with batch dimension
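The aggregation rule can be sketched without atdata; the dict-based samples below are a stand-in for real sample objects:

```python
import numpy as np

# Array-valued fields are stacked along a new leading batch axis;
# other fields are collected into lists (rule assumed from the text above).
samples = [
    {"image": np.zeros((4, 4)), "label": "cat"},
    {"image": np.ones((4, 4)), "label": "dog"},
]
images = np.stack([s["image"] for s in samples])  # shape (2, 4, 4)
labels = [s["label"] for s in samples]

assert images.shape == (2, 4, 4)
assert labels == ["cat", "dog"]
```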

    Lens: Schema Transformations

    Lenses enable viewing datasets through different schemas without duplicating data:

@atdata.packable
class SimplifiedSample:
    label: str

Local Index (Redis + S3)
  • WebDataset tar shards
  • Any S3-compatible storage (AWS, MinIO, Cloudflare R2)
  • -
    +
    store = S3DataStore(credentials=creds, bucket="datasets")
     index = LocalIndex(data_store=store)
     
    @@ -783,7 +632,7 @@ 

    Atmosphere Index
  • Store actual data shards as ATProto blobs
  • Fully decentralized—no external dependencies
  • -
    +
    client = AtmosphereClient()
     client.login("handle.bsky.social", "app-password")
     
    @@ -801,7 +650,7 @@ 

    Protocol Abstraction

    AbstractIndex

    Common interface for both LocalIndex and AtmosphereIndex:

def process_dataset(index: AbstractIndex, name: str):
    entry = index.get_dataset(name)
    schema = index.decode_schema(entry.schema_ref)

    AbstractDataStore

    Common interface for S3DataStore and PDSBlobStore:

def write_to_store(store: AbstractDataStore, dataset: Dataset):
    urls = store.write_shards(dataset, prefix="data/v1")
    # Works with S3 or PDS blob storage

Data Flow

    A typical workflow progresses through three stages:

    Stage 1: Local Development

# Define type and create samples
@atdata.packable
class MySample:

    Stage 2: Team Storage

# Set up team storage
store = S3DataStore(credentials=team_creds, bucket="team-datasets")
index = LocalIndex(data_store=store)

    Stage 3: Federation

# Promote to atmosphere
client = AtmosphereClient()
client.login("handle.bsky.social", "app-password")

    Extension Points

    Custom DataSources

    Implement the DataSource protocol to add new storage backends:

class MyCustomSource:
    def list_shards(self) -> list[str]: ...
    def open_shard(self, shard_id: str) -> IO[bytes]: ...
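A minimal in-memory source satisfying the two methods shown above might look like this (the full protocol may require more methods than the snippet shows; InMemorySource is a made-up illustration):

```python
import io
from typing import IO

class InMemorySource:
    """Toy source backed by a dict of shard name -> tar bytes."""

    def __init__(self, shards: dict[str, bytes]):
        self._shards = shards

    def list_shards(self) -> list[str]:
        return sorted(self._shards)

    def open_shard(self, shard_id: str) -> IO[bytes]:
        return io.BytesIO(self._shards[shard_id])

src = InMemorySource({"shard-000001.tar": b"tar2", "shard-000000.tar": b"tar1"})
assert src.list_shards() == ["shard-000000.tar", "shard-000001.tar"]
assert src.open_shard("shard-000000.tar").read() == b"tar1"
```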

    Custom Lenses

    Register transformations between any PackableSample types:

@atdata.lens
def my_transform(src: SourceType) -> TargetType:
    return TargetType(...)

    Schema Extensions

    The schema format supports custom metadata for domain-specific needs:

index.publish_schema(
    MySample,
    version="1.0.0",

    Related


    Atmosphere (ATProto Integration)


    Overview

    AtmosphereClient

    The client handles authentication and record operations:

from atdata.atmosphere import AtmosphereClient

client = AtmosphereClient()

    Session Management

    Save and restore sessions to avoid re-authentication:

# Export session for later
session_string = client.export_session()

    Custom PDS

    Connect to a custom PDS instead of bsky.social:

client = AtmosphereClient(base_url="https://pds.example.com")

    PDSBlobStore

    Store dataset shards as ATProto blobs for fully decentralized storage:

from atdata.atmosphere import AtmosphereClient, PDSBlobStore

client = AtmosphereClient()

    Size Limits

    PDS blobs typically have size limits (often 50MB-5GB depending on the PDS). Use maxcount and maxsize parameters to control shard sizes:

urls = store.write_shards(
    dataset,
    prefix="large-data/v1",
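The splitting behavior these parameters suggest can be sketched as follows; `plan_shards` is a hypothetical helper, not part of atdata, and the actual cutoff logic is assumed:

```python
def plan_shards(sample_sizes: list[int], maxcount: int, maxsize: int) -> list[list[int]]:
    # Start a new shard whenever adding a sample would exceed
    # either the sample-count limit or the byte-size limit.
    shards, current, current_bytes = [], [], 0
    for size in sample_sizes:
        if current and (len(current) >= maxcount or current_bytes + size > maxsize):
            shards.append(current)
            current, current_bytes = [], 0
        current.append(size)
        current_bytes += size
    if current:
        shards.append(current)
    return shards

# Five 40-byte samples, at most 2 samples / 100 bytes per shard:
assert plan_shards([40] * 5, maxcount=2, maxsize=100) == [[40, 40], [40, 40], [40]]
```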

    BlobSource

    Read datasets stored as PDS blobs:

from atdata import BlobSource

# From blob references

    AtmosphereIndex

    The unified interface for ATProto operations, implementing the AbstractIndex protocol:

from atdata.atmosphere import AtmosphereClient, AtmosphereIndex, PDSBlobStore

client = AtmosphereClient()

    Publishing Schemas

import atdata
from numpy.typing import NDArray

    Publishing Datasets

dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")

entry = index.insert_dataset(

    Listing and Retrieving

# List your datasets
for entry in index.list_datasets():
    print(f"{entry.name}: {entry.schema_ref}")

Lower-Level Publishers

    For more control, use the individual publisher classes:

    SchemaPublisher

from atdata.atmosphere import SchemaPublisher

publisher = SchemaPublisher(client)

    DatasetPublisher

from atdata.atmosphere import DatasetPublisher

publisher = DatasetPublisher(client)

    Blob Storage

    There are two approaches to storing data as ATProto blobs:

    Approach 1: PDSBlobStore (Recommended)

    Use PDSBlobStore with AtmosphereIndex for automatic shard management:

from atdata.atmosphere import PDSBlobStore, AtmosphereIndex

store = PDSBlobStore(client)

    Approach 2: Manual Blob Publishing

    For more control, use DatasetPublisher.publish_with_blobs() directly:

import io
import webdataset as wds

    Loading Blob-Stored Datasets

from atdata.atmosphere import DatasetLoader
from atdata import BlobSource

    LensPublisher

from atdata.atmosphere import LensPublisher

publisher = LensPublisher(client)

Lower-Level Loaders

For direct access to records, use the loader classes:

    SchemaLoader

from atdata.atmosphere import SchemaLoader

loader = SchemaLoader(client)

    DatasetLoader

from atdata.atmosphere import DatasetLoader

loader = DatasetLoader(client)

    LensLoader

from atdata.atmosphere import LensLoader

loader = LensLoader(client)

    AT URIs

    ATProto records are identified by AT URIs:

from atdata.atmosphere import AtUri

# Parse an AT URI
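An AT URI follows the shape `at://<authority>/<collection>/<rkey>`; that structure can be sketched with plain string handling (the DID, collection NSID, and record key below are made-up examples, and AtUri's own API is not reproduced here):

```python
# Split a sample at:// URI into its three components.
uri = "at://did:plc:abc123/com.example.dataset/3jwdwj2ctlk26"
assert uri.startswith("at://")

authority, collection, rkey = uri.removeprefix("at://").split("/")
assert authority == "did:plc:abc123"       # who owns the record
assert collection == "com.example.dataset" # record type (NSID)
assert rkey == "3jwdwj2ctlk26"             # record key
```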

Supported Field Types

    Complete Example

    This example shows the full workflow using PDSBlobStore for decentralized storage:

import numpy as np
from numpy.typing import NDArray
import atdata

    For external URL storage (without PDSBlobStore):

# Use AtmosphereIndex without data_store
index = AtmosphereIndex(client)

    Related


    Datasets


    The Dataset class provides typed iteration over WebDataset tar files with automatic batching and lens transformations.

    Creating a Dataset

import atdata
from numpy.typing import NDArray

    Data Sources

    URL Source (default)

    When you pass a string to Dataset, it automatically wraps it in a URLSource:

# These are equivalent:
dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")
dataset = atdata.Dataset[ImageSample](atdata.URLSource("data-{000000..000009}.tar"))

    S3 Source

    For private S3 buckets or S3-compatible storage (Cloudflare R2, MinIO), use S3Source:

# From explicit credentials
source = atdata.S3Source(
    bucket="my-bucket",

    Iteration Modes

    Ordered Iteration

    Iterate through samples in their original order:

# With batching
for batch in dataset.ordered(batch_size=32):
    images = batch.image  # numpy array (32, H, W, C)

    Shuffled Iteration

    Iterate with randomized order at both shard and sample levels:

for batch in dataset.shuffled(batch_size=32):
    # Samples are shuffled
    process(batch)

    SampleBatch

    When iterating with a batch_size, each iteration yields a SampleBatch with automatic attribute aggregation.

@atdata.packable
class Sample:
    features: NDArray  # shape (256,)

    Type Transformations with Lenses

    View a dataset through a different sample type using registered lenses:

@atdata.packable
class SimplifiedSample:
    label: str

    Dataset Properties

    Shard List

    Get the list of individual tar files:

dataset = atdata.Dataset[Sample]("data-{000000..000009}.tar")
shards = dataset.shard_list
# ['data-000000.tar', 'data-000001.tar', ..., 'data-000009.tar']
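The brace-range expansion behind `shard_list` can be sketched in plain Python; this is a minimal re-implementation handling only a single zero-padded numeric range, not the full expansion WebDataset supports:

```python
import re

def expand_braces(url: str) -> list[str]:
    # Expand one {AAAA..BBBB} range, preserving zero padding.
    m = re.search(r"\{(\d+)\.\.(\d+)\}", url)
    if not m:
        return [url]
    lo, hi = m.group(1), m.group(2)
    width = len(lo)
    return [
        url[: m.start()] + str(i).zfill(width) + url[m.end():]
        for i in range(int(lo), int(hi) + 1)
    ]

shards = expand_braces("data-{000000..000009}.tar")
assert shards[0] == "data-000000.tar"
assert shards[-1] == "data-000009.tar"
assert len(shards) == 10
```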

    Metadata

    Datasets can have associated metadata from a URL:

dataset = atdata.Dataset[Sample](
    "data-{000000..000009}.tar",
    metadata_url="https://example.com/metadata.msgpack"

    Writing Datasets

    Use WebDataset’s TarWriter or ShardWriter to create datasets:

import webdataset as wds
import numpy as np

    Parquet Export

    Export dataset contents to parquet format:

# Export entire dataset
dataset.to_parquet("output.parquet")

    Dataset Properties

    Source

    Access the underlying DataSource:

dataset = atdata.Dataset[Sample]("data.tar")
source = dataset.source  # URLSource instance
print(source.shard_list)  # ['data.tar']

    Sample Type

    Get the type parameter used to create the dataset:

dataset = atdata.Dataset[ImageSample]("data.tar")
print(dataset.sample_type)  # <class 'ImageSample'>
print(dataset.batch_type)   # SampleBatch[ImageSample]

    Related


    Deployment Guide

S3 IAM Policy Example


    Lenses


    Overview

    Creating a Lens

    Use the @lens decorator to define a getter:

import atdata
from numpy.typing import NDArray

    Adding a Putter

    To enable bidirectional updates, add a putter:

@simplify.putter
def simplify_put(view: SimpleSample, source: FullSample) -> FullSample:
    return FullSample(

    Using Lenses with Datasets

    Lenses integrate with Dataset.as_type():

dataset = atdata.Dataset[FullSample]("data-{000000..000009}.tar")

# View through a different type

    Direct Lens Usage

    Lenses can also be called directly:

import numpy as np

full = FullSample(

    Lens Laws

    Well-behaved lenses should satisfy these properties:


    If you get a view and immediately put it back, the source is unchanged:

view = lens.get(source)
assert lens.put(view, source) == source

    If you put a view, getting it back yields that view:

updated = lens.put(view, source)
assert lens.get(updated) == view

    Putting twice is equivalent to putting once with the final value:

result1 = lens.put(v2, lens.put(v1, source))
result2 = lens.put(v2, source)
assert result1 == result2
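All three laws can be checked end-to-end on a toy lens built from plain dataclasses; this sketch does not use the atdata API, and the field names are invented for illustration:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Full:
    label: str
    confidence: float
    extra: int

@dataclass(frozen=True)
class Simple:
    label: str
    confidence: float

# Toy lens: get projects Full -> Simple, put writes the view back.
def get(source: Full) -> Simple:
    return Simple(label=source.label, confidence=source.confidence)

def put(view: Simple, source: Full) -> Full:
    return replace(source, label=view.label, confidence=view.confidence)

source = Full(label="cat", confidence=0.9, extra=7)
view = Simple(label="dog", confidence=0.5)
v2 = Simple(label="bird", confidence=0.1)

assert put(get(source), source) == source              # GetPut
assert get(put(view, source)) == view                  # PutGet
assert put(v2, put(view, source)) == put(v2, source)   # PutPut
```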

    Trivial Putter

    If no putter is defined, a trivial putter is used that ignores view updates:

@atdata.lens
def extract_label(src: FullSample) -> SimpleSample:
    return SimpleSample(label=src.label, confidence=src.confidence)

    LensNetwork Registry

    The LensNetwork is a singleton that stores all registered lenses:

from atdata.lens import LensNetwork

network = LensNetwork()

    Example: Feature Extraction

@atdata.packable
class RawSample:
    audio: NDArray

    Related


    load_dataset API

    +

    load_dataset API


    Overview

    Basic Usage

    import atdata
     from atdata import load_dataset
     from numpy.typing import NDArray

    Path Formats

    WebDataset Brace Notation

    # Range notation
     ds = load_dataset("data-{000000..000099}.tar", MySample, split="train")
     
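Brace notation expands to an explicit shard list. A minimal stdlib expansion helper for the `{A..B}` form (illustrative only; webdataset ships its own expansion logic):

```python
import re

def expand_braces(pattern):
    """Expand one WebDataset-style {A..B} numeric range into a shard list."""
    m = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if m is None:
        return [pattern]
    lo, hi = m.group(1), m.group(2)
    width = len(lo)  # preserve zero-padding
    return [
        pattern[:m.start()] + str(i).zfill(width) + pattern[m.end():]
        for i in range(int(lo), int(hi) + 1)
    ]

shards = expand_braces("data-{000000..000099}.tar")
# 100 shard names: data-000000.tar through data-000099.tar
```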

    Glob Patterns

    # Match all tar files
     ds = load_dataset("path/to/*.tar", MySample)
     

    Local Directory

    # Scans for .tar files
     ds = load_dataset("./my-dataset/", MySample)

    Remote URLs

    # S3 (public buckets)
     ds = load_dataset("s3://bucket/data-{000..099}.tar", MySample, split="train")
     

    Index Lookup

    from atdata.local import LocalIndex
     
     index = LocalIndex()

    DatasetDict

    When loading without split=, returns a DatasetDict:

    ds_dict = load_dataset("path/to/data/", MySample)
     
     # Access splits

    Explicit Data Files

    Override automatic detection with data_files:

    # Single pattern
     ds = load_dataset(
         "path/to/",

    Streaming Mode

    The streaming parameter signals intent for streaming mode:

    # Mark as streaming
     ds_dict = load_dataset("path/to/data.tar", MySample, streaming=True)
     

    Auto Type Resolution

    When using index lookup, the sample type can be resolved automatically:

    from atdata.local import LocalIndex
     
     index = LocalIndex()

    Error Handling

    try:
         ds = load_dataset("path/to/data.tar", MySample, split="train")
     except FileNotFoundError:

    Complete Example

    import numpy as np
     from numpy.typing import NDArray
     import atdata

    Related


    Local Storage


    Overview

    LocalIndex

    The index tracks datasets in Redis:

    from atdata.local import LocalIndex
     
     # Default connection (localhost:6379)

    Adding Entries

    import atdata
     from numpy.typing import NDArray
     

    Listing and Retrieving

    # Iterate all entries
     for entry in index.entries:
         print(f"{entry.name}: {entry.cid}")

    Repo (Deprecated)

    The Repo class combines S3 storage with Redis indexing:

    from atdata.local import Repo
     
     # From credentials file

    Preferred approach - Use LocalIndex with S3DataStore:

    from atdata.local import LocalIndex, S3DataStore
     
     store = S3DataStore(

    Inserting Datasets

    import webdataset as wds
     import numpy as np
     

    Insert Options

    entry, ds = repo.insert(
         dataset,
         name="my-dataset",

    LocalDatasetEntry

    Index entries provide content-addressable identification:

    entry = index.get_entry_by_name("my-dataset")
     
     # Core properties (IndexEntry protocol)

    Schema Storage

    Schemas can be stored and retrieved from the index:

    # Publish a schema
     schema_ref = index.publish_schema(
         ImageSample,

    S3DataStore

    For direct S3 operations without Redis indexing:

    from atdata.local import S3DataStore
     
     store = S3DataStore(

    Complete Workflow Example

    import numpy as np
     from numpy.typing import NDArray
     import atdata

    Related


    Packable Samples


    The @packable Decorator

    The recommended way to define a sample type is with the @packable decorator:

    import numpy as np
     from numpy.typing import NDArray
     import atdata

    Supported Field Types

    Primitives

    @atdata.packable
     class PrimitiveSample:
         name: str

    NumPy Arrays

    Fields annotated as NDArray are automatically converted:

    @atdata.packable
     class ArraySample:
         features: NDArray          # Required array

    Lists

    @atdata.packable
     class ListSample:
         tags: list[str]

    Serialization

    Packing to Bytes

    sample = ImageSample(
         image=np.random.rand(224, 224, 3).astype(np.float32),
         label="cat",

    Unpacking from Bytes

    # Deserialize from bytes
     restored = ImageSample.from_bytes(packed_bytes)
     

    WebDataset Format

    The as_wds property returns a dict ready for WebDataset:

    wds_dict = sample.as_wds
     # {'__key__': '1234...', 'msgpack': b'...'}

    Write samples to a tar file:

    import webdataset as wds
     
     with wds.writer.TarWriter("data-000000.tar") as sink:
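Under the hood this is ordinary tar writing. A stdlib-only sketch of the same layout, assuming one `{key}.msgpack` member per sample as in the `as_wds` example above:

```python
import io
import tarfile

# Each sample's as_wds dict becomes one {key}.msgpack member in the tar.
# (Member naming is assumed from the as_wds example; TarWriter handles
# this for you in practice.)
samples = [
    {"__key__": f"{i:06d}", "msgpack": b"payload"}
    for i in range(3)
]

buf = io.BytesIO()  # in-memory stand-in for data-000000.tar
with tarfile.open(fileobj=buf, mode="w") as tar:
    for sample in samples:
        payload = sample["msgpack"]
        info = tarfile.TarInfo(name=f"{sample['__key__']}.msgpack")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
```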

    Direct Inheritance (Alternative)

    You can also inherit directly from PackableSample:

    from dataclasses import dataclass
     
     @dataclass

    How It Works

    Serialization Flow


      The _ensure_good() Method

      This method runs automatically after construction and handles NDArray conversion:

      def _ensure_good(self):
           for field in dataclasses.fields(self):
               if _is_possibly_ndarray_type(field.type):
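The same post-construction coercion pattern can be sketched with a plain dataclass, using tuple as a stand-in for np.ndarray (illustrative only, not atdata's implementation):

```python
from dataclasses import dataclass, fields

@dataclass
class Sample:
    features: tuple  # stands in for NDArray
    label: str

    def __post_init__(self):
        # Coerce annotated fields after __init__, the way _ensure_good()
        # coerces array-like values to np.ndarray.
        for f in fields(self):
            if f.type in (tuple, "tuple"):
                setattr(self, f.name, tuple(getattr(self, f.name)))

s = Sample(features=[1, 2, 3], label="cat")
assert s.features == (1, 2, 3)  # list was coerced on construction
```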

      Best Practices

      @atdata.packable
       class GoodSample:
           features: NDArray           # Clear type annotation

      @atdata.packable
       class BadSample:
           # DON'T: Nested dataclasses not supported

      Related


    Promotion Workflow


    Overview

    Basic Usage

    from atdata.local import LocalIndex
     from atdata.atmosphere import AtmosphereClient
     from atdata.promote import promote_to_atmosphere

    With Metadata

    at_uri = promote_to_atmosphere(
         entry,
         local_index,

    Schema Deduplication

    The promotion workflow automatically checks for existing schemas:

    # First promotion: publishes schema
     uri1 = promote_to_atmosphere(entry1, local_index, client)
     

    Data Storage Options


    By default, promotion keeps the original data URLs:

    # Data stays in original S3 location
     at_uri = promote_to_atmosphere(entry, local_index, client)

    To copy data to a different storage location:

    from atdata.local import S3DataStore
     
     # Create new data store

    Complete Workflow Example

    import numpy as np
     from numpy.typing import NDArray
     import atdata

    Error Handling

    try:
         at_uri = promote_to_atmosphere(entry, local_index, client)
     except KeyError as e:

    Related


    Protocols


    Overview

    IndexEntry Protocol

    Represents a dataset entry in any index:

    from atdata._protocols import IndexEntry
     
     def process_entry(entry: IndexEntry) -> None:

    AbstractIndex Protocol

    Defines operations for managing schemas and datasets:

    from atdata._protocols import AbstractIndex
     
     def list_all_datasets(index: AbstractIndex) -> None:

    Dataset Operations

    # Insert a dataset
     entry = index.insert_dataset(
         dataset,

    Schema Operations

    # Publish a schema
     schema_ref = index.publish_schema(
         MySample,

    AbstractDataStore Protocol

    Abstracts over different storage backends:

    from atdata._protocols import AbstractDataStore
     
     def write_dataset(store: AbstractDataStore, dataset) -> list[str]:

    Methods

    # Write dataset shards
     urls = store.write_shards(
         dataset,

    DataSource Protocol

    Abstracts over different data source backends for streaming dataset shards:

    from atdata._protocols import DataSource
     
     def load_from_source(source: DataSource) -> None:

    Methods

    # Get list of shard identifiers
     shard_ids = source.shard_list  # ['data-000000.tar', 'data-000001.tar', ...]
     

    Creating Custom Data Sources

    Implement the DataSource protocol for custom backends:

    from typing import Iterator, IO
     from atdata._protocols import DataSource
     

    Using Protocols for Polymorphism

    Write code that works with any backend:

    from atdata._protocols import AbstractIndex, IndexEntry
     from atdata import Dataset
     

    Type Checking

    Protocols are runtime-checkable:

    from atdata._protocols import IndexEntry, AbstractIndex
     
     # Check if object implements protocol
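For a self-contained illustration of how runtime-checkable protocols behave, here is the same pattern using only typing (this IndexEntry is a stand-in, not the atdata class):

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class IndexEntry(Protocol):
    # isinstance() checks only the presence of these members,
    # not their types.
    name: str
    cid: str

class MyEntry:
    def __init__(self) -> None:
        self.name = "demo"
        self.cid = "bafy-example"

assert isinstance(MyEntry(), IndexEntry)        # has name and cid
assert not isinstance(object(), IndexEntry)     # missing both
```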

    Complete Example

    import atdata
     from atdata.local import LocalIndex, S3DataStore
     from atdata.atmosphere import AtmosphereClient, AtmosphereIndex

    Related


    Troubleshooting & FAQ


    Getting Help


    URI Specification


    Version Specifiers

    Examples

    Local Development

    from atdata.local import Index
     
     index = Index()

    Atmosphere (ATProto Federation)

    from atdata.atmosphere import Client
     
     client = Client()

    Legacy Format


    Atmosphere Publishing


    Prerequisites

    Setup

    import numpy as np
     from numpy.typing import NDArray
     import atdata

    Define Sample Types

    @atdata.packable
     class ImageSample:
         """A sample containing image data with metadata."""

    Type Introspection

    See what information is available from a PackableSample type:

    from dataclasses import fields, is_dataclass
     
     print(f"Sample type: {ImageSample.__name__}")

    AT URI Parsing

    Understanding AT URIs is essential for working with atmosphere datasets, as they’re how you reference schemas, datasets, and lenses.

    ATProto records are identified by AT URIs:

    uris = [
         "at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz789",
         "at://alice.bsky.social/ac.foundation.dataset.record/my-dataset",
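The three components of an AT URI (authority, collection, rkey) can be pulled apart with a few lines of stdlib string handling (illustrative; atproto client libraries ship real parsers with full validation):

```python
def parse_at_uri(uri):
    """Split an AT URI into authority, collection, and record key."""
    assert uri.startswith("at://"), "not an AT URI"
    authority, _, rest = uri[len("at://"):].partition("/")
    collection, _, rkey = rest.partition("/")
    return {"authority": authority, "collection": collection, "rkey": rkey}

parts = parse_at_uri(
    "at://did:plc:abc123/ac.foundation.dataset.sampleSchema/xyz789"
)
# authority: did:plc:abc123
# collection: ac.foundation.dataset.sampleSchema
# rkey: xyz789
```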

    Authentication

    The AtmosphereClient handles ATProto authentication. When you authenticate, you’re proving ownership of your decentralized identity (DID), which gives you permission to create and modify records in your Personal Data Server (PDS).

    Connect to ATProto:

    client = AtmosphereClient()
     client.login("your.handle.social", "your-app-password")
     

    Publish a Schema

    When you publish a schema to ATProto, it becomes a public, immutable record that others can reference. The schema CID ensures that anyone can verify they’re using exactly the same type definition you published.

    schema_publisher = SchemaPublisher(client)
     schema_uri = schema_publisher.publish(
         ImageSample,

    List Your Schemas

    schema_loader = SchemaLoader(client)
     schemas = schema_loader.list_all(limit=10)
     print(f"Found {len(schemas)} schema(s)")

    Publish a Dataset

    With External URLs

    dataset_publisher = DatasetPublisher(client)
     dataset_uri = dataset_publisher.publish_with_urls(
         urls=["s3://example-bucket/demo-data-{000000..000009}.tar"],

    With PDS
  • Federated replication: Relays can mirror your blobs for availability
For fully decentralized storage, use PDSBlobStore to store dataset shards directly as ATProto blobs in your PDS:

    # Create store and index with blob storage
     store = PDSBlobStore(client)
     index = AtmosphereIndex(client, data_store=store)

    Use BlobSource to stream directly from PDS blobs:

    # Create source from the blob URLs
     source = store.create_source(entry.data_urls)
     

    With External URLs

    For larger datasets that exceed PDS blob limits, or when you already have data in object storage, you can publish a dataset record that references external URLs. The ATProto record serves as the index entry while the actual data lives elsewhere.

    For larger datasets or when using existing object storage:

    dataset_publisher = DatasetPublisher(client)
     dataset_uri = dataset_publisher.publish_with_urls(
         urls=["s3://example-bucket/demo-data-{000000..000009}.tar"],

    List and Load Datasets

    dataset_loader = DatasetLoader(client)
     datasets = dataset_loader.list_all(limit=10)
     print(f"Found {len(datasets)} dataset(s)")

    Load a Dataset

    # Check storage type
     storage_type = dataset_loader.get_storage_type(str(blob_dataset_uri))
     print(f"Storage type: {storage_type}")
Complete Publishing Example

    Notice how similar this is to the local workflow—the same sample types and patterns, just with a different storage backend.

    This example shows the recommended workflow using PDSBlobStore for fully decentralized storage:

    # 1. Define and create samples
     @atdata.packable
     class FeatureSample:

    Next Steps


    Local Workflow


    Prerequisites

    Setup

    import numpy as np
     from numpy.typing import NDArray
     import atdata

    Define Sample Types

    @atdata.packable
     class TrainingSample:
         """A sample containing features and label for training."""

    LocalDatasetEntry

    CIDs are computed from the entry’s schema reference and data URLs, so the same logical dataset will have the same CID regardless of where it’s stored.

    Create entries with content-addressable CIDs:

    # Create an entry manually
     entry = LocalDatasetEntry(
         _name="my-dataset",
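The determinism property is easy to sketch with hashlib. The actual CID scheme atdata uses may differ, but the idea is the same: hash a canonical encoding of the schema reference and data URLs:

```python
import hashlib
import json

def dataset_cid(schema_ref, data_urls):
    """Deterministic content ID from schema ref + data URLs.

    Illustrative only: the real CID scheme may differ, but the same
    logical dataset always hashes to the same ID.
    """
    payload = json.dumps(
        {"schema": schema_ref, "urls": sorted(data_urls)},
        sort_keys=True,
    ).encode()
    return hashlib.sha256(payload).hexdigest()

a = dataset_cid("schema/v1", ["s3://b/d-000001.tar", "s3://b/d-000000.tar"])
b = dataset_cid("schema/v1", ["s3://b/d-000000.tar", "s3://b/d-000001.tar"])
assert a == b  # same logical dataset, same ID, regardless of URL order
```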

    LocalIndex

    The LocalIndex is your team’s dataset registry. It implements the AbstractIndex protocol, meaning code written against LocalIndex will also work with AtmosphereIndex when you’re ready for federated sharing.

    The index tracks datasets in Redis:

    from redis import Redis
     
     # Connect to Redis

    Schema Management

    Schema publishing is how you ensure type consistency across your team. When you publish a schema, atdata stores the complete type definition (field names, types, metadata) so anyone can reconstruct the Python class from just the schema reference.

    This enables a powerful workflow: share a dataset by sharing its name, and consumers can dynamically reconstruct the sample type without having the original Python code.
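The reconstruction step can be sketched with the stdlib's dataclasses.make_dataclass; the schema dict shape here is hypothetical, not atdata's wire format:

```python
from dataclasses import make_dataclass, fields

# Hypothetical schema payload a consumer might fetch from the index.
schema = {
    "name": "TrainingSample",
    "fields": [("features", list), ("label", str)],
}

# Rebuild a usable class with no access to the original source code.
TrainingSample = make_dataclass(
    schema["name"],
    [(fname, ftype) for fname, ftype in schema["fields"]],
)

sample = TrainingSample(features=[0.1, 0.2], label="cat")
assert [f.name for f in fields(sample)] == ["features", "label"]
```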

    # Publish a schema
     schema_ref = index.publish_schema(TrainingSample, version="1.0.0")
     print(f"Published schema: {schema_ref}")

    S3DataStore

    The data store handles uploading tar shards and creating signed URLs for streaming access.

    For direct S3 operations:

    creds = {
         "AWS_ENDPOINT": "http://localhost:9000",
         "AWS_ACCESS_KEY_ID": "minioadmin",
Complete Index Workflow

    The index composition pattern (LocalIndex(data_store=S3DataStore(...))) is deliberate—it separates the concern of “where is metadata?” from “where is data?”, making it easy to swap storage backends.

    Use LocalIndex with S3DataStore to store datasets with S3 storage and Redis indexing:

    # 1. Create sample data
     samples = [
         TrainingSample(

    Using load_dataset with Index

    The load_dataset() function provides a HuggingFace-style API that abstracts away the details of where data lives. When you pass an index, it can resolve @local/ prefixed paths to the actual data URLs and apply the correct credentials automatically.

    The load_dataset() function supports index lookup:

    from atdata import load_dataset
     
     # Load from local index

    Next Steps


    Promotion Workflow


    Overview

    Setup

    import numpy as np
     from numpy.typing import NDArray
     import atdata

    Prepare a Local Dataset

    First, set up a dataset in local storage:

    # 1. Define sample type
     @atdata.packable
     class ExperimentSample:

    Basic Promotion

    Promote the dataset to ATProto:

    # Connect to atmosphere
     client = AtmosphereClient()
     client.login("myhandle.bsky.social", "app-password")

    Promotion with Metadata

    Add description, tags, and license:

    at_uri = promote_to_atmosphere(
         local_entry,
         local_index,

    Schema Deduplication

    The promotion workflow automatically checks for existing schemas:

    from atdata.promote import _find_existing_schema
     
     # Check if schema already exists
    print("No existing schema found, will publish new one")

    When you promote multiple datasets with the same sample type:

    # First promotion: publishes schema
     uri1 = promote_to_atmosphere(entry1, local_index, client)
     

    Data Migration Options


    By default, promotion keeps the original data URLs:

    # Data stays in original S3 location
     at_uri = promote_to_atmosphere(local_entry, local_index, client)

    To copy data to a different storage location:

    from atdata.local import S3DataStore
     
     # Create new data store

    Verify on Atmosphere

    After promotion, verify the dataset is accessible:

    from atdata.atmosphere import AtmosphereIndex
     
     atm_index = AtmosphereIndex(client)

    Error Handling

    try:
         at_uri = promote_to_atmosphere(local_entry, local_index, client)
     except KeyError as e:

    Complete Workflow

    # Complete local-to-atmosphere workflow
     import numpy as np
     from numpy.typing import NDArray

    Next Steps


    Quick Start


Define a Sample Type
  • Round-trip fidelity: Data survives serialization without loss
Use the @packable decorator to create a typed sample:

    import numpy as np
     from numpy.typing import NDArray
     import atdata

    Create Sample Instances

    # Create a single sample
     sample = ImageSample(
         image=np.random.rand(224, 224, 3).astype(np.float32),

    Write a Dataset

    The as_wds property on your sample provides the dictionary format WebDataset expects:

    Use WebDataset’s TarWriter to create dataset files:

    import webdataset as wds
     
     # Create 100 samples

    Load and Iterate

    This eliminates boilerplate collation code and works automatically with any PackableSample type.
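The collation itself amounts to turning a list of per-sample dicts into one dict of per-field lists; a stdlib sketch (atdata's typed batches additionally stack array fields, which is not shown here):

```python
def collate(samples):
    """Turn a list of field dicts into one dict of field lists."""
    keys = samples[0].keys()
    return {k: [s[k] for s in samples] for k in keys}

batch = collate([
    {"image": "img0", "label": "cat"},
    {"image": "img1", "label": "dog"},
])
assert batch["label"] == ["cat", "dog"]
```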

    Create a typed Dataset and iterate with batching:

    # Load dataset with type
     dataset = atdata.Dataset[ImageSample]("my-dataset-000000.tar")
     

    Shuffled Iteration

    This approach balances randomness with streaming efficiency—you get well-shuffled data without needing random access to the entire dataset.

    For training, use shuffled iteration:

    for batch in dataset.shuffled(batch_size=32):
         # Samples are shuffled at shard and sample level
         images = batch.image
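The underlying trick is a fixed-size shuffle buffer; a stdlib sketch (illustrative only; atdata and webdataset implement their own versions):

```python
import random

def buffered_shuffle(iterable, buffer_size, seed=0):
    """Streaming shuffle: fill a fixed-size buffer, then emit a random
    element as each new one arrives. No random access to the full
    dataset is needed."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:  # drain what remains at end of stream
        yield buffer.pop(rng.randrange(len(buffer)))

shuffled = list(buffered_shuffle(range(10), buffer_size=4))
assert sorted(shuffled) == list(range(10))  # a permutation, nothing lost
```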
Use Lenses
  • Derived features: Compute fields on-the-fly during iteration
View datasets through different schemas:

    # Define a simplified view type
     @atdata.packable
     class SimplifiedSample:

    Next Steps

Returns type[PackableSample]: a new dataclass that inherits from PackableSample with the same name and annotations as the original class. The class satisfies the Packable protocol and can be used with Type[Packable] signatures.