diff --git a/.chainlink/issues.db b/.chainlink/issues.db
index 12b36ae..edbde07 100644
Binary files a/.chainlink/issues.db and b/.chainlink/issues.db differ
diff --git a/.github/workflows/uv-test.yml b/.github/workflows/uv-test.yml
index cc8b74d..2e18832 100644
--- a/.github/workflows/uv-test.yml
+++ b/.github/workflows/uv-test.yml
@@ -4,13 +4,11 @@ on:
push:
branches:
- main
- - release/*
pull_request:
- branches:
- - main
permissions:
contents: read
+ actions: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@@ -77,3 +75,60 @@ jobs:
with:
fail_ci_if_error: false
token: ${{ secrets.CODECOV_TOKEN }}
+
+ benchmark:
+ name: Benchmarks
+ runs-on: ubuntu-latest
+ needs: [lint]
+ permissions:
+ contents: write
+ actions: write
+ steps:
+ - uses: actions/checkout@v5
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.14"
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v6
+ with:
+ enable-cache: true
+
+ - name: Install just
+ uses: extractions/setup-just@v2
+
+ - name: Install the project
+ run: uv sync --locked --all-extras --dev
+
+ - name: Start Redis
+ uses: supercharge/redis-github-action@1.8.1
+ with:
+ redis-version: 7
+
+ - name: Run benchmarks
+ run: just bench
+
+ - name: Copy report to docs
+ run: |
+ mkdir -p docs/benchmarks
+ cp .bench/report.html docs/benchmarks/index.html
+
+ - name: Commit updated benchmark docs
+ if: github.event_name == 'push'
+ run: |
+ git config user.name "github-actions[bot]"
+ git config user.email "github-actions[bot]@users.noreply.github.com"
+ git add docs/benchmarks/index.html
+ git diff --cached --quiet || git commit -m "docs: update benchmark report [skip ci]"
+ git push
+
+ - name: Upload benchmark report
+ uses: actions/upload-artifact@v4
+ if: always()
+ with:
+ name: benchmark-report
+ path: |
+ .bench/report.html
+ .bench/*.json
diff --git a/.gitignore b/.gitignore
index 75ac226..97301a8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -52,6 +52,9 @@ MANIFEST
pip-log.txt
pip-delete-this-directory.txt
+# Benchmark results
+.bench/
+
# Unit test / coverage reports
htmlcov/
.tox/
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 3e03efa..56064c5 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -5,15 +5,19 @@
"atproto",
"creds",
"dtype",
+ "fastparquet",
"getattr",
"hgetall",
"hset",
+ "libipld",
"maxcount",
"minioadmin",
"msgpack",
"ndarray",
"NSID",
"ormsgpack",
+ "psycopg",
+ "pydantic",
"pypi",
"pyproject",
"pytest",
@@ -24,6 +28,8 @@
"schemamodels",
"shardlists",
"tariterators",
+ "tqdm",
+ "typer",
"unpackb",
"webdataset"
],
diff --git a/CHANGELOG.md b/CHANGELOG.md
index feeea7e..312972d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,111 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- **Comprehensive integration test suite**: 593 tests covering E2E flows, error handling, edge cases
### Changed
+- Investigate upload-artifact not finding benchmark output (#512)
+- Fix duplicate CI runs for push+PR overlap (#511)
+- Scope contents:write permission to benchmark job only (#510)
+- Add benchmark docs auto-commit to CI workflow (#509)
+- Submit PR for v0.3.0b1 release to upstream/main (#508)
+- Implement GH#39: Production hardening (observability, error handling, testing infra) (#504)
+- Add pluggable structured logging via atdata.configure_logging (#507)
+- Add PartialFailureError and shard-level error handling to Dataset.map (#506)
+- Add atdata.testing module with mock clients, fixtures, and helpers (#505)
+- Fix CI linting failures (20 ruff errors) (#503)
+- Adversarial review: Post-benchmark suite assessment (#494)
+- Remove redundant protocol docstrings that restate signatures (#500)
+- Add missing unit tests for _type_utils.py (#499)
+- Strengthen weak assertions (assert X is not None → value checks) (#498)
+- Trim verbose exception constructor docstrings (#501)
+- Analyze benchmark results for performance improvement opportunities (#502)
+- Consolidate remaining duplicate sample types in test files (#497)
+- Remove dead code: _repo_legacy.py legacy UUID field, unused imports (#496)
+- Trim verbose docstrings in dataset.py and _index.py (#495)
+- Benchmark report: replace mean/stddev with median/IQR, add per-sample columns (#492)
+- Add parameter descriptions to benchmark suite with automatic report introspection (#491)
+- HTML benchmark reports with CI integration (#487)
+- Add bench + render step to CI on highest Python version only (#490)
+- Update justfile bench commands to export JSON and render (#489)
+- Create render_report.py script to convert JSON to HTML (#488)
+- Increase test coverage for low-coverage modules (#480)
+- Add providers/_postgres.py tests (mock-based) (#485)
+- Add _stub_manager.py tests (#484)
+- Add manifest/_query.py tests (#483)
+- Add repository.py tests (#482)
+- Add CLI tests (cli/__init__, diagnose, local, preview, schema) (#481)
+- Check test coverage for CLI utils (#479)
+- Add performance benchmark suite for atdata (#471)
+- Verify benchmarks run (#478)
+- Update pyproject.toml and justfile (#477)
+- Create bench_atmosphere.py (#476)
+- Create bench_query.py (#475)
+- Create bench_dataset_io.py (#474)
+- Create bench_index_providers.py (#473)
+- Create benchmarks/conftest.py with shared fixtures (#472)
+- Add per-shard manifest and query system (GH #35) (#462)
+- Write unit and integration tests (#470)
+- Integrate manifest into write path and Dataset.query() (#469)
+- Implement QueryExecutor and SampleLocation (#468)
+- Implement ManifestWriter (JSON + parquet) (#467)
+- Implement ManifestBuilder (#465)
+- Implement ShardManifest data model (#466)
+- Implement aggregate collectors (categorical, numeric, set) (#464)
+- Implement ManifestField annotation and resolve_manifest_fields() (#463)
+- Migrate type annotations from PackableSample to Packable protocol (#461)
+- Remove LocalIndex factory — consolidate to Index (#460)
+- Split local.py monolith into local/ package (#452)
+- Verify tests and lint pass (#459)
+- Create __init__.py re-export facade and delete local.py (#458)
+- Create _repo_legacy.py with deprecated Repo class (#457)
+- Create _index.py with Index class and LocalIndex factory (#456)
+- Create _s3.py with S3DataStore and S3 helpers (#455)
+- Create _schema.py with schema models and helpers (#454)
+- Create _entry.py with LocalDatasetEntry and constants (#453)
+- Migrate CLI from argparse to typer (#449)
+- Investigate test failures (#450)
+- Fix ensure_stub receiving LocalSchemaRecord instead of dict (#451)
+- GH#38: Developer experience improvements (#437)
+- CLI: atdata preview command (#440)
+- CLI: atdata schema show/diff commands (#439)
+- CLI: atdata inspect command (#438)
+- Dataset.__len__ and Dataset.select() for sample count and indexed access (#447)
+- Dataset.to_pandas() and Dataset.to_dict() export methods (#446)
+- Dataset.filter() and Dataset.map() streaming transforms (#445)
+- Dataset.get(key) for keyed sample access (#442)
+- Dataset.describe() summary statistics (#444)
+- Dataset.schema property and column_names (#443)
+- Dataset.head(n) and Dataset.__iter__ convenience methods (#441)
+- Custom exception hierarchy with actionable error messages (#448)
+- Adversarial review: Post-Repository consolidation assessment (#430)
+- Remove backwards-compat dict-access methods from SchemaField and LocalSchemaRecord (#436)
+- Add missing test coverage for Repository prefix routing edge cases and error paths (#435)
+- Trim over-verbose docstrings in local.py module/class level (#434)
+- Fix formally incorrect test assertions (batch_size, CID, brace notation) (#433)
+- Consolidate duplicate test sample types across test files into conftest.py (#432)
+- Consolidate duplicate entry-creation logic in Index (add_entry vs _insert_dataset_to_provider) (#431)
+- Switch default Index provider from Redis to SQLite (#429)
+- Consolidated Index with Repository system (#424)
+- Phase 4: Deprecate AtmosphereIndex, update exports (#428)
+- Phase 3: Default Index singleton and load_dataset integration (#427)
+- Phase 2: Extend Index with repos/atmosphere params and prefix routing (#426)
+- Phase 1: Create Repository dataclass and _AtmosphereBackend in repository.py (#425)
+- Adversarial review: Post-IndexProvider pluggable storage assessment (#417)
+- Convert TODO comments to tracked issues or remove (#422)
+- Remove deprecated shard_list property references from docstrings (#421)
+- Replace bare except in _stub_manager.py and cli/local.py with specific exceptions (#423)
+- Tighten generic pytest.raises(Exception) to specific exception types in tests (#420)
+- Replace assert statements with ValueError in production code (#419)
+- Consolidate duplicated _parse_semver into _type_utils.py (#418)
+- feat: Add SQLite/PostgreSQL index providers (GH #42) (#409)
+- Update documentation and public API exports (#416)
+- Add tests for all providers (#415)
+- Refactor Index class to accept provider parameter (#414)
+- Implement PostgresIndexProvider (#413)
+- Implement SqliteIndexProvider (#412)
+- Implement RedisIndexProvider (extract from Index class) (#411)
+- Define IndexProvider protocol in _protocols.py (#410)
+- Add just lint command to justfile (#408)
+- Add SQLite/PostgreSQL providers for LocalIndex (in addition to Redis) (#407)
+- Fix type hints for @atdata.packable decorator to show PackableSample methods (#406)
- Review GitHub workflows and recommend CI improvements (#405)
- Fix type signatures for Dataset.ordered and Dataset.shuffled (GH#28) (#404)
- Investigate quartodoc Example section rendering - missing CSS classes on pre/code tags (#401)
diff --git a/CLAUDE.md b/CLAUDE.md
index 349c268..6b096b4 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -46,8 +46,10 @@ uv build
Development tasks are managed with [just](https://github.com/casey/just), a command runner. Available commands:
```bash
-# Build documentation (runs quartodoc + quarto)
-just docs
+just test # Run all tests with coverage
+just test tests/test_dataset.py # Run specific test file
+just lint # Run ruff check + format check
+just docs # Build documentation (runs quartodoc + quarto)
```
The `justfile` is in the project root. Add new dev tasks there rather than creating shell scripts.
diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/benchmarks/bench_atmosphere.py b/benchmarks/bench_atmosphere.py
new file mode 100644
index 0000000..cab470e
--- /dev/null
+++ b/benchmarks/bench_atmosphere.py
@@ -0,0 +1,220 @@
+"""Performance benchmarks for remote storage backends.
+
+Covers S3DataStore (via moto mock) and Atmosphere/ATProto (network-gated).
+S3 benchmarks use moto for reproducible local measurement.
+Atmosphere benchmarks are marked ``network`` and skip unless a live PDS is available.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+from moto import mock_aws
+
+import atdata
+
+from .conftest import (
+ IMAGE_SHAPE,
+ BenchBasicSample,
+ BenchManifestSample,
+ BenchNumpySample,
+ generate_basic_samples,
+ generate_manifest_samples,
+ generate_numpy_samples,
+ write_tar,
+)
+
+
+# =============================================================================
+# S3 Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def mock_s3():
+ """Provide mock S3 environment using moto."""
+ with mock_aws():
+ import boto3
+
+ creds = {
+ "AWS_ACCESS_KEY_ID": "testing",
+ "AWS_SECRET_ACCESS_KEY": "testing",
+ }
+ s3_client = boto3.client(
+ "s3",
+ aws_access_key_id=creds["AWS_ACCESS_KEY_ID"],
+ aws_secret_access_key=creds["AWS_SECRET_ACCESS_KEY"],
+ region_name="us-east-1",
+ )
+ bucket_name = "bench-bucket"
+ s3_client.create_bucket(Bucket=bucket_name)
+ yield {
+ "credentials": creds,
+ "bucket": bucket_name,
+ }
+
+
+def _make_s3_store(mock_s3_env):
+ from atdata.local._s3 import S3DataStore
+
+ return S3DataStore(
+ credentials=mock_s3_env["credentials"],
+ bucket=mock_s3_env["bucket"],
+ )
+
+
+def _make_source_dataset(tmp_path, samples):
+ """Create a local dataset from samples for use as S3 write source."""
+ tar_path = write_tar(tmp_path / "source-000000.tar", samples)
+ sample_type = type(samples[0])
+ return atdata.Dataset[sample_type](url=str(tar_path))
+
+
+# =============================================================================
+# S3 Write Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_s3
+@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
+@pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning")
+class TestS3WriteBenchmarks:
+ """S3 shard writing benchmarks via moto mock."""
+
+ PARAM_LABELS = {"n": "samples per shard"}
+
+ @pytest.mark.parametrize("n", [100, 500], ids=["100", "500"])
+ def test_s3_write_shards(self, benchmark, tmp_path, mock_s3, n):
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_basic_samples(n)
+ ds = _make_source_dataset(tmp_path, samples)
+ store = _make_s3_store(mock_s3)
+ counter = [0]
+
+ def _write():
+ idx = counter[0]
+ counter[0] += 1
+ store.write_shards(ds, prefix=f"bench/basic-{n}-{idx}")
+
+ benchmark(_write)
+
+ def test_s3_write_with_manifest(self, benchmark, tmp_path, mock_s3):
+ benchmark.extra_info["n_samples"] = 200
+ samples = generate_manifest_samples(200)
+ ds = _make_source_dataset(tmp_path, samples)
+ store = _make_s3_store(mock_s3)
+ counter = [0]
+
+ def _write():
+ idx = counter[0]
+ counter[0] += 1
+ store.write_shards(
+ ds, prefix=f"bench/manifest-{idx}", manifest=True,
+ cache_local=True,
+ )
+
+ benchmark(_write)
+
+ def test_s3_write_cache_local(self, benchmark, tmp_path, mock_s3):
+ benchmark.extra_info["n_samples"] = 200
+ samples = generate_basic_samples(200)
+ ds = _make_source_dataset(tmp_path, samples)
+ store = _make_s3_store(mock_s3)
+ counter = [0]
+
+ def _write():
+ idx = counter[0]
+ counter[0] += 1
+ store.write_shards(
+ ds, prefix=f"bench/cache-{idx}", cache_local=True
+ )
+
+ benchmark(_write)
+
+ def test_s3_write_direct(self, benchmark, tmp_path, mock_s3):
+ benchmark.extra_info["n_samples"] = 200
+ samples = generate_basic_samples(200)
+ ds = _make_source_dataset(tmp_path, samples)
+ store = _make_s3_store(mock_s3)
+ counter = [0]
+
+ def _write():
+ idx = counter[0]
+ counter[0] += 1
+ store.write_shards(
+ ds, prefix=f"bench/direct-{idx}", cache_local=False
+ )
+
+ benchmark(_write)
+
+ def test_s3_write_numpy(self, benchmark, tmp_path, mock_s3):
+ benchmark.extra_info["n_samples"] = 100
+ samples = generate_numpy_samples(100)
+ ds = _make_source_dataset(tmp_path, samples)
+ store = _make_s3_store(mock_s3)
+ counter = [0]
+
+ def _write():
+ idx = counter[0]
+ counter[0] += 1
+ store.write_shards(ds, prefix=f"bench/numpy-{idx}")
+
+ benchmark(_write)
+
+
+# =============================================================================
+# Atmosphere Benchmarks (network-gated)
+# =============================================================================
+
+
+@pytest.mark.network
+class TestAtmosphereBenchmarks:
+ """Atmosphere/ATProto benchmarks. Require live PDS access.
+
+ Run with: just bench -m network
+ """
+
+ def test_atmosphere_publish_dataset(self, benchmark, tmp_path):
+ """End-to-end dataset publish to Atmosphere."""
+ import os
+
+ handle = os.environ.get("ATDATA_BENCH_ATP_HANDLE")
+ password = os.environ.get("ATDATA_BENCH_ATP_PASSWORD")
+ if not handle or not password:
+ pytest.skip("ATDATA_BENCH_ATP_HANDLE/PASSWORD not set")
+
+ from atdata.atmosphere.client import AtmosphereClient
+
+ client = AtmosphereClient(handle=handle, password=password)
+
+ samples = generate_basic_samples(10)
+ tar_path = write_tar(tmp_path / "atmo-000000.tar", samples)
+ ds = atdata.Dataset[BenchBasicSample](url=str(tar_path))
+
+ counter = [0]
+
+ def _publish():
+ idx = counter[0]
+ counter[0] += 1
+ from atdata.atmosphere.records import DatasetPublisher
+
+ publisher = DatasetPublisher(client)
+ publisher.publish(ds, name=f"bench-atmo-{idx}")
+
+ benchmark(_publish)
+
+ def test_atmosphere_resolve_dataset(self, benchmark):
+ """Resolve a dataset record from Atmosphere (read-only, anonymous)."""
+ import os
+
+ ref = os.environ.get("ATDATA_BENCH_ATP_DATASET_REF")
+ if not ref:
+ pytest.skip("ATDATA_BENCH_ATP_DATASET_REF not set")
+
+ from atdata.local._index import Index
+
+ index = Index()
+
+ benchmark(index.get_dataset, ref)
diff --git a/benchmarks/bench_dataset_io.py b/benchmarks/bench_dataset_io.py
new file mode 100644
index 0000000..8668b09
--- /dev/null
+++ b/benchmarks/bench_dataset_io.py
@@ -0,0 +1,293 @@
+"""Performance benchmarks for dataset read/write operations.
+
+Measures shard writing throughput, read iteration speed, serialization
+overhead, and round-trip performance for basic and numpy sample types.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+import webdataset as wds
+
+import atdata
+
+from .conftest import (
+ IMAGE_DTYPE,
+ IMAGE_SHAPE,
+ TSERIES_DTYPE,
+ TSERIES_SHAPE,
+ BenchBasicSample,
+ BenchManifestSample,
+ BenchNumpySample,
+ generate_basic_samples,
+ generate_manifest_samples,
+ generate_numpy_samples,
+ write_tar,
+ write_tar_with_manifest,
+)
+
+
+# =============================================================================
+# Write Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_io
+class TestShardWriteBenchmarks:
+ """Shard writing throughput benchmarks."""
+
+ PARAM_LABELS = {"n": "samples per shard"}
+
+ @pytest.mark.parametrize("n", [100, 1000, 10000], ids=["100", "1k", "10k"])
+ def test_write_basic_shard(self, benchmark, tmp_path, n):
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_basic_samples(n)
+
+ def _write():
+ tar_path = tmp_path / f"basic-{n}.tar"
+ write_tar(tar_path, samples)
+
+ benchmark(_write)
+
+ @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"])
+ def test_write_numpy_shard(self, benchmark, tmp_path, n):
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_numpy_samples(n)
+
+ def _write():
+ tar_path = tmp_path / f"numpy-{n}.tar"
+ write_tar(tar_path, samples)
+
+ benchmark(_write)
+
+ def test_write_large_numpy_shard(self, benchmark, tmp_path):
+ benchmark.extra_info["n_samples"] = 10
+ samples = generate_numpy_samples(10, shape=TSERIES_SHAPE, dtype=TSERIES_DTYPE)
+
+ def _write():
+ tar_path = tmp_path / "numpy-large.tar"
+ write_tar(tar_path, samples)
+
+ benchmark(_write)
+
+ @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"])
+ def test_write_with_manifest(self, benchmark, tmp_path, n):
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_manifest_samples(n)
+ counter = [0]
+
+ def _write():
+ idx = counter[0]
+ counter[0] += 1
+ tar_path = tmp_path / f"manifest-{n}-{idx}.tar"
+ write_tar_with_manifest(tar_path, samples, BenchManifestSample)
+
+ benchmark(_write)
+
+ def test_write_multi_shard(self, benchmark, tmp_path):
+ benchmark.extra_info["n_samples"] = 10000
+ samples = generate_basic_samples(10000)
+ counter = [0]
+
+ def _write():
+ idx = counter[0]
+ counter[0] += 1
+ out_dir = tmp_path / f"multi-{idx}"
+ out_dir.mkdir(exist_ok=True)
+ pattern = str(out_dir / "shard-%06d.tar")
+ with wds.writer.ShardWriter(pattern, maxcount=1000) as sink:
+ for sample in samples:
+ sink.write(sample.as_wds)
+
+ benchmark(_write)
+
+
+# =============================================================================
+# Read Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_io
+class TestShardReadBenchmarks:
+ """Shard reading and iteration benchmarks."""
+
+ PARAM_LABELS = {"n": "samples in dataset", "batch_size": "samples per batch"}
+
+ @pytest.mark.parametrize("n", [100, 1000, 10000], ids=["100", "1k", "10k"])
+ def test_read_ordered(self, benchmark, tmp_path, n):
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_basic_samples(n)
+ tar_path = write_tar(tmp_path / f"read-ordered-{n}.tar", samples)
+ ds = atdata.Dataset[BenchBasicSample](url=str(tar_path))
+
+ def _read():
+ count = 0
+ for _ in ds.ordered():
+ count += 1
+ return count
+
+ result = benchmark(_read)
+ assert result == n
+
+ @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"])
+ def test_read_shuffled(self, benchmark, tmp_path, n):
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_basic_samples(n)
+ tar_path = write_tar(tmp_path / f"read-shuffled-{n}.tar", samples)
+ ds = atdata.Dataset[BenchBasicSample](url=str(tar_path))
+
+ def _read():
+ count = 0
+ for _ in ds.shuffled():
+ count += 1
+ return count
+
+ result = benchmark(_read)
+ assert result == n
+
+ @pytest.mark.parametrize(
+ "batch_size", [32, 128], ids=["batch32", "batch128"]
+ )
+ def test_read_batched(self, benchmark, tmp_path, batch_size):
+ n = 1000
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_basic_samples(n)
+ tar_path = write_tar(tmp_path / f"read-batched-{batch_size}.tar", samples)
+ ds = atdata.Dataset[BenchBasicSample](url=str(tar_path))
+
+ def _read():
+ count = 0
+            for _ in ds.ordered(batch_size=batch_size):
+ count += 1
+ return count
+
+ benchmark(_read)
+
+ @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"])
+ def test_read_numpy_ordered(self, benchmark, tmp_path, n):
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_numpy_samples(n)
+ tar_path = write_tar(tmp_path / f"read-numpy-{n}.tar", samples)
+ ds = atdata.Dataset[BenchNumpySample](url=str(tar_path))
+
+ def _read():
+ count = 0
+ for _ in ds.ordered():
+ count += 1
+ return count
+
+ result = benchmark(_read)
+ assert result == n
+
+
+# =============================================================================
+# Serialization Benchmarks (No I/O)
+# =============================================================================
+
+
+@pytest.mark.bench_serial
+class TestSerializationBenchmarks:
+ """Pure serialization/deserialization without disk I/O."""
+
+ def test_serialize_basic_sample(self, benchmark):
+ sample = BenchBasicSample(name="bench_sample", value=42)
+ benchmark(lambda: sample.packed)
+
+ def test_deserialize_basic_sample(self, benchmark):
+ sample = BenchBasicSample(name="bench_sample", value=42)
+ packed = sample.packed
+ benchmark(BenchBasicSample.from_bytes, packed)
+
+ def test_serialize_numpy_sample(self, benchmark):
+ sample = BenchNumpySample(
+ data=np.random.randint(0, 256, size=IMAGE_SHAPE, dtype=IMAGE_DTYPE),
+ label="bench",
+ )
+ benchmark(lambda: sample.packed)
+
+ def test_deserialize_numpy_sample(self, benchmark):
+ sample = BenchNumpySample(
+ data=np.random.randint(0, 256, size=IMAGE_SHAPE, dtype=IMAGE_DTYPE),
+ label="bench",
+ )
+ packed = sample.packed
+ benchmark(BenchNumpySample.from_bytes, packed)
+
+ def test_serialize_large_numpy(self, benchmark):
+ sample = BenchNumpySample(
+ data=np.random.randn(*TSERIES_SHAPE).astype(TSERIES_DTYPE),
+ label="large",
+ )
+ benchmark(lambda: sample.packed)
+
+ def test_deserialize_large_numpy(self, benchmark):
+ sample = BenchNumpySample(
+ data=np.random.randn(*TSERIES_SHAPE).astype(TSERIES_DTYPE),
+ label="large",
+ )
+ packed = sample.packed
+ benchmark(BenchNumpySample.from_bytes, packed)
+
+ def test_as_wds_basic(self, benchmark):
+ sample = BenchBasicSample(name="bench_sample", value=42)
+ benchmark(lambda: sample.as_wds)
+
+ def test_as_wds_numpy(self, benchmark):
+ sample = BenchNumpySample(
+ data=np.random.randint(0, 256, size=IMAGE_SHAPE, dtype=IMAGE_DTYPE),
+ label="bench",
+ )
+ benchmark(lambda: sample.as_wds)
+
+
+# =============================================================================
+# Round-Trip Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_io
+class TestRoundTripBenchmarks:
+ """Full write-then-read round-trip benchmarks."""
+
+ PARAM_LABELS = {"n": "samples round-tripped"}
+
+ @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"])
+ def test_roundtrip_basic(self, benchmark, tmp_path, n):
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_basic_samples(n)
+ counter = [0]
+
+ def _roundtrip():
+ idx = counter[0]
+ counter[0] += 1
+ tar_path = tmp_path / f"rt-basic-{n}-{idx}.tar"
+ write_tar(tar_path, samples)
+ ds = atdata.Dataset[BenchBasicSample](url=str(tar_path))
+ count = 0
+ for _ in ds.ordered():
+ count += 1
+ return count
+
+ result = benchmark(_roundtrip)
+ assert result == n
+
+ @pytest.mark.parametrize("n", [100, 500], ids=["100", "500"])
+ def test_roundtrip_numpy(self, benchmark, tmp_path, n):
+ benchmark.extra_info["n_samples"] = n
+ samples = generate_numpy_samples(n)
+ counter = [0]
+
+ def _roundtrip():
+ idx = counter[0]
+ counter[0] += 1
+ tar_path = tmp_path / f"rt-numpy-{n}-{idx}.tar"
+ write_tar(tar_path, samples)
+ ds = atdata.Dataset[BenchNumpySample](url=str(tar_path))
+ count = 0
+ for _ in ds.ordered():
+ count += 1
+ return count
+
+ result = benchmark(_roundtrip)
+ assert result == n
diff --git a/benchmarks/bench_index_providers.py b/benchmarks/bench_index_providers.py
new file mode 100644
index 0000000..4d0aeea
--- /dev/null
+++ b/benchmarks/bench_index_providers.py
@@ -0,0 +1,215 @@
+"""Performance benchmarks for index provider operations.
+
+Measures read/write latency and throughput for SQLite, Redis, and PostgreSQL
+providers. Each benchmark skips gracefully if the backend is unavailable.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from atdata.local._entry import LocalDatasetEntry
+
+from .conftest import BenchBasicSample, generate_basic_samples
+
+
+# =============================================================================
+# Helpers
+# =============================================================================
+
+
+def _make_entry(i: int) -> LocalDatasetEntry:
+ return LocalDatasetEntry(
+ name=f"bench_dataset_{i:06d}",
+        schema_ref="atdata://local/sampleSchema/BenchBasicSample@1.0.0",
+ data_urls=[f"/tmp/bench/data-{i:06d}.tar"],
+ metadata={"index": i, "split": "train"},
+ )
+
+
+def _make_schema_json(name: str, version: str) -> str:
+ return (
+ f'{{"name": "{name}", "version": "{version}", '
+ f'"fields": [{{"name": "x", "type": "int"}}]}}'
+ )
+
+
+def _prepopulate_entries(provider, n: int) -> list[LocalDatasetEntry]:
+ entries = [_make_entry(i) for i in range(n)]
+ for entry in entries:
+ provider.store_entry(entry)
+ return entries
+
+
+def _prepopulate_schemas(provider, name: str, n: int) -> list[str]:
+ versions = []
+ for i in range(n):
+ version = f"1.0.{i}"
+ provider.store_schema(name, version, _make_schema_json(name, version))
+ versions.append(version)
+ return versions
+
+
+# =============================================================================
+# Write Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_index
+class TestProviderWriteBenchmarks:
+ """Write operation benchmarks across all providers."""
+
+ PARAM_LABELS = {"n": "entries to store", "any_provider": "storage backend"}
+
+ def test_store_single_entry(self, benchmark, any_provider):
+ entry = _make_entry(0)
+ benchmark(any_provider.store_entry, entry)
+
+ @pytest.mark.parametrize("n", [10, 100, 1000], ids=["10", "100", "1k"])
+ def test_store_entries_bulk(self, benchmark, any_provider, n):
+ entries = [_make_entry(i) for i in range(n)]
+
+ def _store_all():
+ for entry in entries:
+ any_provider.store_entry(entry)
+
+ benchmark(_store_all)
+
+ def test_store_schema(self, benchmark, any_provider):
+ benchmark(
+ any_provider.store_schema,
+ "BenchSample",
+ "1.0.0",
+ _make_schema_json("BenchSample", "1.0.0"),
+ )
+
+ @pytest.mark.parametrize("n", [10, 50], ids=["10v", "50v"])
+ def test_store_schema_versions(self, benchmark, any_provider, n):
+ def _store_versions():
+ for i in range(n):
+ v = f"1.0.{i}"
+ any_provider.store_schema(
+ "BenchVersioned", v, _make_schema_json("BenchVersioned", v)
+ )
+
+ benchmark(_store_versions)
+
+
+# =============================================================================
+# Read Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_index
+class TestProviderReadBenchmarks:
+ """Read operation benchmarks across all providers."""
+
+ PARAM_LABELS = {"n": "entries in index", "any_provider": "storage backend"}
+
+ def test_get_entry_by_name(self, benchmark, any_provider):
+ entries = _prepopulate_entries(any_provider, 100)
+ target = entries[50]
+ benchmark(any_provider.get_entry_by_name, target.name)
+
+ def test_get_entry_by_cid(self, benchmark, any_provider):
+ entries = _prepopulate_entries(any_provider, 100)
+ target = entries[50]
+ benchmark(any_provider.get_entry_by_cid, target.cid)
+
+ @pytest.mark.parametrize("n", [10, 100, 1000], ids=["10", "100", "1k"])
+ def test_iter_entries(self, benchmark, any_provider, n):
+ _prepopulate_entries(any_provider, n)
+ benchmark(lambda: list(any_provider.iter_entries()))
+
+ def test_get_schema_json(self, benchmark, any_provider):
+ any_provider.store_schema(
+ "BenchRead", "1.0.0", _make_schema_json("BenchRead", "1.0.0")
+ )
+ benchmark(any_provider.get_schema_json, "BenchRead", "1.0.0")
+
+ @pytest.mark.parametrize("n", [5, 20, 50], ids=["5v", "20v", "50v"])
+ def test_find_latest_version(self, benchmark, any_provider, n):
+ _prepopulate_schemas(any_provider, "BenchLatest", n)
+ benchmark(any_provider.find_latest_version, "BenchLatest")
+
+ def test_iter_schemas(self, benchmark, any_provider):
+ _prepopulate_schemas(any_provider, "BenchIterSchema", 20)
+ benchmark(lambda: list(any_provider.iter_schemas()))
+
+
+# =============================================================================
+# Index-Level Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_index
+class TestIndexBenchmarks:
+ """Benchmarks through the full Index API."""
+
+ def test_index_insert_dataset(self, benchmark, tmp_path, sqlite_provider):
+ import atdata
+ from atdata.local._index import Index
+
+ samples = generate_basic_samples(10)
+ tar_path = tmp_path / "idx-bench-000000.tar"
+ from .conftest import write_tar
+
+ write_tar(tar_path, samples)
+
+ index = Index(provider=sqlite_provider, atmosphere=None)
+ ds = atdata.Dataset[BenchBasicSample](url=str(tar_path))
+
+ counter = [0]
+
+ def _insert():
+ name = f"bench_ds_{counter[0]:06d}"
+ counter[0] += 1
+ index.insert_dataset(ds, name=name)
+
+ benchmark(_insert)
+
+ def test_index_get_dataset(self, benchmark, tmp_path, sqlite_provider):
+ import atdata
+ from atdata.local._index import Index
+
+ samples = generate_basic_samples(10)
+ tar_path = tmp_path / "idx-get-000000.tar"
+ from .conftest import write_tar
+
+ write_tar(tar_path, samples)
+
+ index = Index(provider=sqlite_provider, atmosphere=None)
+ ds = atdata.Dataset[BenchBasicSample](url=str(tar_path))
+ index.insert_dataset(ds, name="bench_lookup_target")
+
+ benchmark(index.get_dataset, "bench_lookup_target")
+
+ def test_index_list_datasets(self, benchmark, tmp_path, sqlite_provider):
+ import atdata
+ from atdata.local._index import Index
+
+ samples = generate_basic_samples(5)
+ tar_path = tmp_path / "idx-list-000000.tar"
+ from .conftest import write_tar
+
+ write_tar(tar_path, samples)
+
+ index = Index(provider=sqlite_provider, atmosphere=None)
+ ds = atdata.Dataset[BenchBasicSample](url=str(tar_path))
+ for i in range(100):
+ index.insert_dataset(ds, name=f"bench_list_{i:04d}")
+
+ benchmark(index.list_datasets)
+
+ def test_index_publish_schema(self, benchmark, sqlite_provider):
+ from atdata.local._index import Index
+
+ index = Index(provider=sqlite_provider, atmosphere=None)
+ counter = [0]
+
+ def _publish():
+ v = f"1.0.{counter[0]}"
+ counter[0] += 1
+ index.publish_schema(BenchBasicSample, version=v)
+
+ benchmark(_publish)
diff --git a/benchmarks/bench_query.py b/benchmarks/bench_query.py
new file mode 100644
index 0000000..a5eccec
--- /dev/null
+++ b/benchmarks/bench_query.py
@@ -0,0 +1,278 @@
+"""Performance benchmarks for the manifest query system.
+
+Measures query execution speed across different predicate types,
+manifest loading performance, and scaling behavior with increasing
+shard counts and sample sizes.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from atdata.manifest import ManifestBuilder, ManifestWriter, QueryExecutor, ShardManifest
+
+from .conftest import (
+ BenchManifestSample,
+ generate_manifest_samples,
+ create_sharded_dataset,
+ write_tar_with_manifest,
+)
+
+
+# =============================================================================
+# Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def query_dataset_small(tmp_path):
+    """Small query dataset: 2 shards x 50 samples = 100 total.
+
+    Returns:
+        Tuple of (QueryExecutor over the shard directory, directory path).
+    """
+    samples = generate_manifest_samples(100)
+    create_sharded_dataset(
+        tmp_path, samples, 50, BenchManifestSample, with_manifests=True
+    )
+    executor = QueryExecutor.from_directory(tmp_path)
+    return executor, tmp_path
+
+
+@pytest.fixture
+def query_dataset_medium(tmp_path):
+    """Medium query dataset: 10 shards x 100 samples = 1000 total.
+
+    Returns:
+        Tuple of (QueryExecutor over the shard directory, directory path).
+    """
+    samples = generate_manifest_samples(1000)
+    create_sharded_dataset(
+        tmp_path, samples, 100, BenchManifestSample, with_manifests=True
+    )
+    executor = QueryExecutor.from_directory(tmp_path)
+    return executor, tmp_path
+
+
+@pytest.fixture
+def query_dataset_large(tmp_path):
+    """Large query dataset: 10 shards x 1000 samples = 10000 total.
+
+    Returns:
+        Tuple of (QueryExecutor over the shard directory, directory path).
+    """
+    samples = generate_manifest_samples(10000)
+    create_sharded_dataset(
+        tmp_path, samples, 1000, BenchManifestSample, with_manifests=True
+    )
+    executor = QueryExecutor.from_directory(tmp_path)
+    return executor, tmp_path
+
+
+# =============================================================================
+# Query Predicate Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_query
+class TestQueryPredicateBenchmarks:
+    """Benchmark different query predicate types on a medium dataset."""
+
+    # NOTE: render_report.py extracts these method docstrings into the
+    # HTML report, so each test carries a one-line description.
+
+    def test_query_simple_equality(self, benchmark, query_dataset_medium):
+        """Single categorical equality predicate."""
+        executor, _ = query_dataset_medium
+        benchmark(executor.query, where=lambda df: df["label"] == "dog")
+
+    def test_query_numeric_range(self, benchmark, query_dataset_medium):
+        """Single numeric range predicate."""
+        executor, _ = query_dataset_medium
+        benchmark(executor.query, where=lambda df: df["confidence"] > 0.8)
+
+    def test_query_combined(self, benchmark, query_dataset_medium):
+        """Conjunction of a categorical and a numeric predicate."""
+        executor, _ = query_dataset_medium
+        benchmark(
+            executor.query,
+            where=lambda df: (df["label"] == "dog") & (df["confidence"] > 0.8),
+        )
+
+    def test_query_isin(self, benchmark, query_dataset_medium):
+        """Set-membership predicate over the label column."""
+        executor, _ = query_dataset_medium
+        benchmark(
+            executor.query,
+            where=lambda df: df["label"].isin(["dog", "cat"]),
+        )
+
+    def test_query_no_results(self, benchmark, query_dataset_medium):
+        """Predicate matching nothing (confidence values never reach 999)."""
+        executor, _ = query_dataset_medium
+        benchmark(
+            executor.query,
+            where=lambda df: df["confidence"] > 999.0,
+        )
+
+    def test_query_all_results(self, benchmark, query_dataset_medium):
+        """Predicate matching every sample."""
+        executor, _ = query_dataset_medium
+        benchmark(
+            executor.query,
+            where=lambda df: df["confidence"] >= 0.0,
+        )
+ )
+
+
+# =============================================================================
+# Query Result Iteration Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_query
+class TestQueryIterationBenchmarks:
+    """Benchmark iterating through query results to access sample locations."""
+
+    # ``benchmark.extra_info["n_samples"]`` is read by render_report.py to
+    # compute the per-sample ("Med/sample", "Samples/s") report columns.
+
+    def test_iterate_equality_results(self, benchmark, query_dataset_medium):
+        """Query by equality, then collect every result key."""
+        executor, _ = query_dataset_medium
+        # Pre-run to capture result count for per-result normalization
+        n = len(executor.query(where=lambda df: df["label"] == "dog"))
+        benchmark.extra_info["n_samples"] = n
+
+        def _query_and_iterate():
+            results = executor.query(where=lambda df: df["label"] == "dog")
+            keys = [loc.key for loc in results]
+            return len(keys)
+
+        count = benchmark(_query_and_iterate)
+        assert count == n
+
+    def test_iterate_range_results(self, benchmark, query_dataset_medium):
+        """Query by range, then group result offsets by shard."""
+        executor, _ = query_dataset_medium
+        n = len(executor.query(where=lambda df: df["confidence"] > 0.5))
+        benchmark.extra_info["n_samples"] = n
+
+        def _query_and_iterate():
+            results = executor.query(where=lambda df: df["confidence"] > 0.5)
+            by_shard: dict[str, list[int]] = {}
+            for loc in results:
+                by_shard.setdefault(loc.shard, []).append(loc.offset)
+            return sum(len(v) for v in by_shard.values())
+
+        count = benchmark(_query_and_iterate)
+        assert count == n
+
+    def test_iterate_large_result_set(self, benchmark, query_dataset_large):
+        """Match all 10000 samples and collect every key."""
+        executor, _ = query_dataset_large
+        benchmark.extra_info["n_samples"] = 10000
+
+        def _query_and_iterate():
+            results = executor.query(where=lambda df: df["confidence"] >= 0.0)
+            keys = [loc.key for loc in results]
+            return len(keys)
+
+        count = benchmark(_query_and_iterate)
+        assert count == 10000
+
+
+# =============================================================================
+# Scale Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_query
+class TestQueryScaleBenchmarks:
+    """Benchmark query performance at different scales."""
+
+    # Same range predicate at 100 / 1000 / 10000 samples so the three
+    # results are directly comparable.
+
+    def test_query_small(self, benchmark, query_dataset_small):
+        """Range query over 100 samples (2 shards)."""
+        executor, _ = query_dataset_small
+        benchmark(executor.query, where=lambda df: df["confidence"] > 0.5)
+
+    def test_query_medium(self, benchmark, query_dataset_medium):
+        """Range query over 1000 samples (10 shards)."""
+        executor, _ = query_dataset_medium
+        benchmark(executor.query, where=lambda df: df["confidence"] > 0.5)
+
+    def test_query_large(self, benchmark, query_dataset_large):
+        """Range query over 10000 samples (10 shards)."""
+        executor, _ = query_dataset_large
+        benchmark(executor.query, where=lambda df: df["confidence"] > 0.5)
+
+
+# =============================================================================
+# Manifest Loading Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_query
+class TestManifestLoadBenchmarks:
+    """Benchmark manifest loading from disk."""
+
+    # PARAM_LABELS is read by render_report.py to render human-readable
+    # parameter descriptions in the HTML report.
+    PARAM_LABELS = {"n_shards": "number of shards (100 samples each)"}
+
+    @pytest.mark.parametrize("n_shards", [2, 5, 10, 20], ids=["2s", "5s", "10s", "20s"])
+    def test_load_from_directory(self, benchmark, tmp_path, n_shards):
+        """Build a QueryExecutor by scanning a shard directory."""
+        total = n_shards * 100
+        samples = generate_manifest_samples(total)
+        create_sharded_dataset(
+            tmp_path, samples, 100, BenchManifestSample, with_manifests=True
+        )
+        benchmark(QueryExecutor.from_directory, tmp_path)
+
+    @pytest.mark.parametrize("n_shards", [2, 5, 10], ids=["2s", "5s", "10s"])
+    def test_load_from_shard_urls(self, benchmark, tmp_path, n_shards):
+        """Build a QueryExecutor from an explicit list of shard URLs."""
+        total = n_shards * 100
+        samples = generate_manifest_samples(total)
+        tar_paths = create_sharded_dataset(
+            tmp_path, samples, 100, BenchManifestSample, with_manifests=True
+        )
+        shard_urls = [str(p) for p in tar_paths]
+        benchmark(QueryExecutor.from_shard_urls, shard_urls)
+
+
+# =============================================================================
+# Manifest Build Benchmarks
+# =============================================================================
+
+
+@pytest.mark.bench_query
+class TestManifestBuildBenchmarks:
+    """Benchmark manifest construction from samples."""
+
+    PARAM_LABELS = {"n": "samples in manifest"}
+
+    @pytest.mark.parametrize("n", [100, 1000, 5000], ids=["100", "1k", "5k"])
+    def test_manifest_build(self, benchmark, n):
+        """Build an in-memory manifest from n pre-generated samples."""
+        benchmark.extra_info["n_samples"] = n
+        samples = generate_manifest_samples(n)
+
+        def _build():
+            builder = ManifestBuilder(
+                sample_type=BenchManifestSample,
+                shard_id="bench-shard-000000",
+            )
+            offset = 0
+            for i, sample in enumerate(samples):
+                packed_size = len(sample.packed)
+                builder.add_sample(
+                    key=f"sample_{i:06d}",
+                    offset=offset,
+                    size=packed_size,
+                    sample=sample,
+                )
+                # NOTE(review): 512 approximates a tar header block, but
+                # unlike conftest.write_tar_with_manifest no 512-byte
+                # payload padding is added here — offsets are synthetic,
+                # which is fine for a pure build benchmark.
+                offset += 512 + packed_size
+            return builder.build()
+
+        manifest = benchmark(_build)
+        assert manifest.num_samples == n
+
+    @pytest.mark.parametrize("n", [100, 1000], ids=["100", "1k"])
+    def test_manifest_write(self, benchmark, tmp_path, n):
+        """Write a pre-built manifest to disk via ManifestWriter."""
+        benchmark.extra_info["n_samples"] = n
+        samples = generate_manifest_samples(n)
+
+        # Build once outside the timed region; only the write is measured.
+        builder = ManifestBuilder(
+            sample_type=BenchManifestSample,
+            shard_id="bench-write-000000",
+        )
+        offset = 0
+        for i, sample in enumerate(samples):
+            packed_size = len(sample.packed)
+            builder.add_sample(
+                key=f"sample_{i:06d}",
+                offset=offset,
+                size=packed_size,
+                sample=sample,
+            )
+            offset += 512 + packed_size
+        manifest = builder.build()
+
+        # Unique output path per round so each write hits fresh files.
+        counter = [0]
+
+        def _write():
+            idx = counter[0]
+            counter[0] += 1
+            writer = ManifestWriter(tmp_path / f"mw-{idx:06d}")
+            writer.write(manifest)
+
+        benchmark(_write)
diff --git a/benchmarks/conftest.py b/benchmarks/conftest.py
new file mode 100644
index 0000000..2c7f84a
--- /dev/null
+++ b/benchmarks/conftest.py
@@ -0,0 +1,345 @@
+"""Shared fixtures and helpers for performance benchmarks."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Annotated, Any
+
+import numpy as np
+import pytest
+import webdataset as wds
+from numpy.typing import NDArray, DTypeLike
+
+import atdata
+from atdata.manifest import ManifestBuilder, ManifestField, ManifestWriter
+
+
+# =============================================================================
+# Benchmark Sample Types
+# =============================================================================
+
+
+@atdata.packable
+class BenchBasicSample:
+    """Lightweight sample for throughput benchmarks."""
+
+    # Two scalar fields keep per-sample serialization cost minimal.
+    name: str
+    value: int
+
+
+@atdata.packable
+class BenchNumpySample:
+    """Sample with NDArray for array serialization benchmarks."""
+
+    # data dominates the payload size; label is a small identifier string.
+    data: NDArray
+    label: str
+
+
+@atdata.packable
+class BenchManifestSample:
+    """Sample with manifest-annotated fields for query benchmarks."""
+
+    # Only the ManifestField-annotated fields are queryable; ``data`` is
+    # payload-only.
+    data: NDArray
+    label: Annotated[str, ManifestField("categorical")]
+    confidence: Annotated[float, ManifestField("numeric")]
+    tags: Annotated[list[str], ManifestField("set")]
+
+
+# =============================================================================
+# Benchmark Constants
+# =============================================================================
+
+# Standard image: 3-channel 224x224 uint8 (ImageNet-style)
+IMAGE_SHAPE = (3, 224, 224)
+IMAGE_DTYPE = np.uint8
+
+# Large biological timeseries: 1024x1024 spatial x 60 frames, float32
+# (if 600 frames was intended, update TSERIES_SHAPE accordingly)
+TSERIES_SHAPE = (1024, 1024, 60)
+TSERIES_DTYPE = np.float32
+
+# Small array for manifest/overhead benchmarks (keeps manifests fast)
+MANIFEST_ARRAY_SHAPE = (4, 4)
+MANIFEST_ARRAY_DTYPE = np.float32
+
+
+# =============================================================================
+# Sample Generators
+# =============================================================================
+
+LABELS = ["dog", "cat", "bird", "fish", "horse"]
+TAG_POOLS = [
+ ["outdoor", "day"],
+ ["indoor"],
+ ["outdoor", "night"],
+ ["underwater"],
+ ["field", "day"],
+]
+
+
+def generate_basic_samples(n: int) -> list[BenchBasicSample]:
+    """Return *n* deterministic BenchBasicSample instances."""
+    return [BenchBasicSample(name=f"sample_{i:06d}", value=i) for i in range(n)]
+
+
+def generate_numpy_samples(
+    n: int,
+    shape: tuple[int, ...] = IMAGE_SHAPE,
+    # DTypeLike (not np.dtype): the default np.uint8 is a scalar *type*,
+    # and np.issubdtype / randint accept any dtype-like value.
+    dtype: DTypeLike = IMAGE_DTYPE,
+) -> list[BenchNumpySample]:
+    """Return *n* BenchNumpySample instances with random array payloads.
+
+    Integer dtypes get uniform values in [0, 256); all other dtypes get
+    standard-normal values cast to *dtype*.
+    """
+    return [
+        BenchNumpySample(
+            data=np.random.randint(0, 256, size=shape, dtype=dtype)
+            if np.issubdtype(dtype, np.integer)
+            else np.random.randn(*shape).astype(dtype),
+            label=f"array_{i:06d}",
+        )
+        for i in range(n)
+    ]
+
+
+def generate_manifest_samples(
+    n: int, shape: tuple[int, ...] = (4, 4)
+) -> list[BenchManifestSample]:
+    """Return *n* BenchManifestSample instances with varied metadata.
+
+    Labels and tags cycle through the module-level pools; confidence
+    sweeps 0.1..0.991 in steps of 0.009 (period 100), so predicates like
+    ``confidence > 0.5`` match a stable fraction of any large batch.
+    """
+    return [
+        BenchManifestSample(
+            data=np.random.randn(*shape).astype(np.float32),
+            label=LABELS[i % len(LABELS)],
+            confidence=0.1 + 0.9 * (i % 100) / 100.0,
+            tags=TAG_POOLS[i % len(TAG_POOLS)],
+        )
+        for i in range(n)
+    ]
+
+
+# =============================================================================
+# Tar/Dataset Helpers
+# =============================================================================
+
+
+def write_tar(tar_path: Path, samples: list) -> Path:
+    """Write samples to a WebDataset tar file.
+
+    Each sample must expose ``as_wds`` (a dict consumable by TarWriter).
+    Returns *tar_path* for convenient chaining.
+    """
+    tar_path.parent.mkdir(parents=True, exist_ok=True)
+    with wds.writer.TarWriter(str(tar_path)) as writer:
+        for sample in samples:
+            writer.write(sample.as_wds)
+    return tar_path
+
+
+def write_tar_with_manifest(
+    tar_path: Path,
+    samples: list,
+    sample_type: type,
+) -> tuple[Path, Path, Path]:
+    """Write samples to a tar file and generate companion manifest files.
+
+    Returns:
+        Tuple of (tar_path, json_path, parquet_path).
+    """
+    tar_path.parent.mkdir(parents=True, exist_ok=True)
+    shard_name = tar_path.stem
+    # Shard id is the tar path without extension; ManifestWriter below
+    # uses the same base so manifest files sit next to the tar.
+    shard_id = str(tar_path.parent / shard_name)
+
+    builder = ManifestBuilder(sample_type=sample_type, shard_id=shard_id)
+
+    offset = 0
+    with wds.writer.TarWriter(str(tar_path)) as writer:
+        for sample in samples:
+            wds_dict = sample.as_wds
+            writer.write(wds_dict)
+            # Size of the msgpack payload member only; other members in
+            # wds_dict are not counted.
+            packed_size = len(wds_dict.get("msgpack", b""))
+            builder.add_sample(
+                key=wds_dict["__key__"],
+                offset=offset,
+                size=packed_size,
+                sample=sample,
+            )
+            # 512-byte tar header + payload rounded up to the next
+            # 512-byte block. NOTE(review): assumes exactly one header
+            # block and one payload member per sample — long names/PAX
+            # headers or extra members would break these offsets; confirm
+            # against how consumers use manifest offsets.
+            offset += 512 + packed_size + (512 - packed_size % 512) % 512
+
+    manifest = builder.build()
+    manifest_writer = ManifestWriter(tar_path.parent / shard_name)
+    json_path, parquet_path = manifest_writer.write(manifest)
+
+    return tar_path, json_path, parquet_path
+
+
+def create_sharded_dataset(
+    base_dir: Path,
+    samples: list,
+    samples_per_shard: int,
+    sample_type: type,
+    with_manifests: bool = False,
+) -> list[Path]:
+    """Split samples across multiple shards. Returns list of tar paths.
+
+    Shards are named ``data-000000.tar``, ``data-000001.tar``, ... The
+    final shard may hold fewer than *samples_per_shard* samples. When
+    *with_manifests* is true, companion manifest files are written next
+    to each tar.
+    """
+    tar_paths: list[Path] = []
+    # shard_idx steps through sample offsets, not shard numbers; the
+    # shard number is recovered by integer division below.
+    for shard_idx in range(0, len(samples), samples_per_shard):
+        chunk = samples[shard_idx : shard_idx + samples_per_shard]
+        shard_name = f"data-{shard_idx // samples_per_shard:06d}"
+        tar_path = base_dir / f"{shard_name}.tar"
+
+        if with_manifests:
+            write_tar_with_manifest(tar_path, chunk, sample_type)
+        else:
+            write_tar(tar_path, chunk)
+
+        tar_paths.append(tar_path)
+
+    return tar_paths
+
+
+# =============================================================================
+# Provider Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def sqlite_provider(tmp_path):
+    """Fresh SQLite provider in a temp directory, closed on teardown."""
+    from atdata.providers._sqlite import SqliteProvider
+
+    provider = SqliteProvider(path=tmp_path / "bench.db")
+    yield provider
+    provider.close()
+
+
+@pytest.fixture
+def redis_provider():
+    """Real Redis provider, skip if unavailable.
+
+    Connects to the default localhost Redis; benchmark-prefixed keys are
+    deleted both before and after the test so runs don't interfere.
+    """
+    from redis import Redis
+
+    try:
+        conn = Redis()
+        conn.ping()
+    except Exception:
+        pytest.skip("Redis server not available")
+
+    from atdata.providers._redis import RedisProvider
+
+    provider = RedisProvider(conn)
+
+    # Clean up benchmark keys before/after
+    def _cleanup():
+        # Only bench-named entries/schemas are removed; unrelated keys in
+        # the same Redis instance are untouched.
+        for pattern in ("LocalDatasetEntry:bench_*", "LocalSchema:Bench*"):
+            for key in conn.scan_iter(match=pattern):
+                conn.delete(key)
+
+    _cleanup()
+    yield provider
+    _cleanup()
+    provider.close()
+
+
+@pytest.fixture
+def postgres_provider():
+    """Real PostgreSQL provider, skip if unavailable.
+
+    Requires psycopg to be installed and the ATDATA_BENCH_POSTGRES_DSN
+    environment variable to point at a reachable database.
+    """
+    try:
+        import psycopg  # noqa: F401
+    except ImportError:
+        pytest.skip("psycopg not installed")
+
+    import os
+
+    dsn = os.environ.get("ATDATA_BENCH_POSTGRES_DSN")
+    if not dsn:
+        pytest.skip("ATDATA_BENCH_POSTGRES_DSN not set")
+
+    from atdata.providers._postgres import PostgresProvider
+
+    provider = PostgresProvider(dsn=dsn)
+    yield provider
+    provider.close()
+
+
+@pytest.fixture(
+    params=["sqlite", "redis", "postgres"],
+    ids=["sqlite", "redis", "postgres"],
+)
+def any_provider(request, tmp_path):
+    """Parametrized fixture that yields each available provider.
+
+    NOTE(review): the three branches duplicate the standalone
+    sqlite_provider/redis_provider/postgres_provider fixtures above;
+    they are inlined here because pytest fixtures cannot be requested
+    dynamically by param without indirection.
+    """
+    backend = request.param
+
+    if backend == "sqlite":
+        from atdata.providers._sqlite import SqliteProvider
+
+        provider = SqliteProvider(path=tmp_path / "bench.db")
+        yield provider
+        provider.close()
+
+    elif backend == "redis":
+        from redis import Redis
+
+        try:
+            conn = Redis()
+            conn.ping()
+        except Exception:
+            pytest.skip("Redis server not available")
+
+        from atdata.providers._redis import RedisProvider
+
+        provider = RedisProvider(conn)
+
+        def _cleanup():
+            # Remove only bench-prefixed keys, before and after the test.
+            for pattern in ("LocalDatasetEntry:bench_*", "LocalSchema:Bench*"):
+                for key in conn.scan_iter(match=pattern):
+                    conn.delete(key)
+
+        _cleanup()
+        yield provider
+        _cleanup()
+        provider.close()
+
+    elif backend == "postgres":
+        try:
+            import psycopg  # noqa: F401
+        except ImportError:
+            pytest.skip("psycopg not installed")
+
+        import os
+
+        dsn = os.environ.get("ATDATA_BENCH_POSTGRES_DSN")
+        if not dsn:
+            pytest.skip("ATDATA_BENCH_POSTGRES_DSN not set")
+
+        from atdata.providers._postgres import PostgresProvider
+
+        provider = PostgresProvider(dsn=dsn)
+        yield provider
+        provider.close()
+
+
+# =============================================================================
+# Dataset Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def small_basic_dataset(tmp_path):
+    """100-sample basic dataset (1 shard).
+
+    Returns (Dataset, samples) so callers can compare against the source.
+    """
+    samples = generate_basic_samples(100)
+    tar_path = write_tar(tmp_path / "small-000000.tar", samples)
+    return atdata.Dataset[BenchBasicSample](url=str(tar_path)), samples
+
+
+@pytest.fixture
+def medium_basic_dataset(tmp_path):
+    """1000-sample basic dataset (1 shard).
+
+    Returns (Dataset, samples) so callers can compare against the source.
+    """
+    samples = generate_basic_samples(1000)
+    tar_path = write_tar(tmp_path / "medium-000000.tar", samples)
+    return atdata.Dataset[BenchBasicSample](url=str(tar_path)), samples
+
+
+@pytest.fixture
+def small_numpy_dataset(tmp_path):
+    """100-sample numpy dataset (1 shard, IMAGE_SHAPE 3x224x224 uint8 arrays).
+
+    Returns (Dataset, samples) so callers can compare against the source.
+    """
+    samples = generate_numpy_samples(100)
+    tar_path = write_tar(tmp_path / "numpy-000000.tar", samples)
+    return atdata.Dataset[BenchNumpySample](url=str(tar_path)), samples
+
+
+@pytest.fixture
+def manifest_dataset_small(tmp_path):
+    """100-sample manifest dataset with 2 shards.
+
+    Returns (shard directory, samples); manifests are written next to
+    each tar so query executors can be built from the directory.
+    """
+    samples = generate_manifest_samples(100)
+    create_sharded_dataset(
+        tmp_path, samples, 50, BenchManifestSample, with_manifests=True
+    )
+    return tmp_path, samples
diff --git a/benchmarks/render_report.py b/benchmarks/render_report.py
new file mode 100644
index 0000000..9cc1cd4
--- /dev/null
+++ b/benchmarks/render_report.py
@@ -0,0 +1,462 @@
+"""Render pytest-benchmark JSON results into a standalone HTML report.
+
+Usage:
+ uv run python -m benchmarks.render_report results/*.json -o bench-report.html
+
+Reads one or more pytest-benchmark JSON files (one per group) and produces
+a single HTML page with grouped tables and test descriptions extracted from
+the benchmark source files.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import json
+import sys
+from dataclasses import dataclass, field
+from pathlib import Path
+
+import jinja2
+
+
+# =============================================================================
+# Docstring extraction
+# =============================================================================
+
+
+@dataclass
+class _SourceMeta:
+    """Metadata extracted from benchmark source files via AST."""
+
+    # Qualified name ("Class" or "Class.method") -> docstring text.
+    docstrings: dict[str, str] = field(default_factory=dict)
+    # Class name -> its PARAM_LABELS dict (param key -> human label).
+    param_labels: dict[str, dict[str, str]] = field(default_factory=dict)
+
+
+def _extract_source_meta(bench_dir: Path) -> _SourceMeta:
+    """Walk benchmark .py files and extract class/method docstrings and PARAM_LABELS.
+
+    Returns a ``_SourceMeta`` with:
+    - ``docstrings``: mapping qualified names like
+      ``"TestSerializationBenchmarks"`` or
+      ``"TestSerializationBenchmarks.test_serialize_basic_sample"``
+      to their docstring text.
+    - ``param_labels``: mapping class names to their ``PARAM_LABELS`` dict
+      (e.g. ``{"n": "samples per shard"}``).
+    """
+    meta = _SourceMeta()
+    for py_file in sorted(bench_dir.glob("bench_*.py")):
+        tree = ast.parse(py_file.read_text())
+        # ast.walk finds classes at any nesting level; only direct class
+        # body items are inspected for methods and PARAM_LABELS.
+        for node in ast.walk(tree):
+            if isinstance(node, ast.ClassDef):
+                cls_doc = ast.get_docstring(node)
+                if cls_doc:
+                    meta.docstrings[node.name] = cls_doc
+                for item in node.body:
+                    if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
+                        fn_doc = ast.get_docstring(item)
+                        if fn_doc:
+                            meta.docstrings[f"{node.name}.{item.name}"] = fn_doc
+                    # Extract PARAM_LABELS = {...} class variable
+                    if isinstance(item, ast.Assign):
+                        for target in item.targets:
+                            if (
+                                isinstance(target, ast.Name)
+                                and target.id == "PARAM_LABELS"
+                            ):
+                                try:
+                                    # literal_eval rejects non-constant
+                                    # expressions; those are silently skipped.
+                                    value = ast.literal_eval(item.value)
+                                    if isinstance(value, dict):
+                                        meta.param_labels[node.name] = value
+                                except (ValueError, TypeError):
+                                    pass
+    return meta
+
+
+# =============================================================================
+# Data model
+# =============================================================================
+
+# Friendly names for each marker group
+GROUP_TITLES: dict[str, str] = {
+ "bench_serial": "Serialization (μs scale)",
+ "bench_index": "Index Providers (μs–ms scale)",
+ "bench_io": "Dataset I/O (ms scale)",
+ "bench_query": "Query System (ms scale)",
+ "bench_s3": "S3 Storage (ms+ scale)",
+}
+
+
+@dataclass
+class BenchRow:
+    """One rendered table row: a single benchmark result."""
+
+    name: str           # pytest test name, possibly with [param] suffix
+    fullname: str       # full pytest node id (file::Class::test[param])
+    description: str    # method docstring or readable fallback
+    raw_params: str     # raw param id, e.g. "sqlite-5v"
+    params_desc: str    # human-readable params via PARAM_LABELS
+    median: float       # seconds
+    iqr: float          # seconds
+    ops: float          # operations per second
+    n_samples: int | None = None  # from extra_info, for per-sample columns
+
+
+@dataclass
+class BenchGroup:
+    """One rendered report section: all rows from a single JSON file."""
+
+    key: str                 # file stem, e.g. "serial"
+    title: str               # friendly title from GROUP_TITLES
+    class_description: str   # joined docstrings of classes in the group
+    param_labels: dict[str, str] = field(default_factory=dict)
+    has_samples: bool = False  # whether any row has n_samples set
+    rows: list[BenchRow] = field(default_factory=list)
+
+
+def _format_time(seconds: float) -> str:
+    """Format seconds into a human-readable string with appropriate unit.
+
+    Unit breaks: <1µs -> ns, <1ms -> µs, <1s -> ms, else seconds.
+    """
+    if seconds < 1e-6:
+        return f"{seconds * 1e9:.1f} ns"
+    if seconds < 1e-3:
+        return f"{seconds * 1e6:.2f} μs"
+    if seconds < 1.0:
+        return f"{seconds * 1e3:.2f} ms"
+    return f"{seconds:.3f} s"
+
+
+def _format_ops(ops: float) -> str:
+    """Format operations per second with SI prefix (M, K, or none)."""
+    if ops >= 1e6:
+        return f"{ops / 1e6:.2f} Mops/s"
+    if ops >= 1e3:
+        return f"{ops / 1e3:.2f} Kops/s"
+    return f"{ops:.1f} ops/s"
+
+
+def _class_from_fullname(fullname: str) -> str:
+    """Extract class name from a pytest fullname like
+    ``benchmarks/bench_dataset_io.py::TestSerializationBenchmarks::test_foo``.
+
+    Returns "" for module-level tests (only two ``::`` parts).
+    """
+    parts = fullname.split("::")
+    if len(parts) >= 2:
+        return parts[1]
+    return ""
+
+
+def _method_from_fullname(fullname: str) -> str:
+    """Extract method name from a pytest fullname.
+
+    Falls back to the last component for module-level tests.
+    """
+    parts = fullname.split("::")
+    if len(parts) >= 3:
+        return parts[2]
+    return parts[-1]
+
+
+# =============================================================================
+# Build groups from JSON
+# =============================================================================
+
+
+def _format_params(
+    params_dict: dict | None,
+    param_labels: dict[str, str],
+) -> str:
+    """Format a benchmark's params dict using human-readable labels.
+
+    Given ``{"n": 1000}`` and labels ``{"n": "samples per shard"}``,
+    returns ``"1000 samples per shard"``. Unlabeled keys fall back to
+    ``key=value``. Returns "" for None or empty params.
+    """
+    if not params_dict:
+        return ""
+    parts: list[str] = []
+    for key, value in params_dict.items():
+        label = param_labels.get(key)
+        if label:
+            parts.append(f"{value} {label}")
+        else:
+            parts.append(f"{key}={value}")
+    return ", ".join(parts)
+
+
+def _build_groups(
+    json_paths: list[Path],
+    meta: _SourceMeta,
+) -> list[BenchGroup]:
+    """Load JSON files and assemble BenchGroup objects.
+
+    Files with no ``benchmarks`` entries are skipped. Groups are keyed by
+    file stem, so two files with the same stem would overwrite each other.
+    """
+    # Each JSON file corresponds to one marker group. We infer the group key
+    # from the filename produced by the justfile (e.g. ``serial.json``).
+    groups: dict[str, BenchGroup] = {}
+
+    for path in sorted(json_paths):
+        data = json.loads(path.read_text())
+        benchmarks = data.get("benchmarks", [])
+        if not benchmarks:
+            continue
+
+        group_key = path.stem  # e.g. "serial", "index", "io", "query", "s3"
+        marker_key = f"bench_{group_key}"
+        title = GROUP_TITLES.get(marker_key, group_key.title())
+
+        # Collect unique class descriptions and param labels for the group
+        class_names_seen: set[str] = set()
+        class_descs: list[str] = []
+        merged_param_labels: dict[str, str] = {}
+        rows: list[BenchRow] = []
+
+        for bench in benchmarks:
+            fullname = bench["fullname"]
+            stats = bench["stats"]
+            cls_name = _class_from_fullname(fullname)
+            method_name = _method_from_fullname(fullname)
+
+            # A class's PARAM_LABELS are merged when its first benchmark
+            # is seen, i.e. before any of its rows' params are formatted.
+            if cls_name and cls_name not in class_names_seen:
+                class_names_seen.add(cls_name)
+                cls_doc = meta.docstrings.get(cls_name, "")
+                if cls_doc:
+                    class_descs.append(cls_doc)
+                # Merge param labels from this class
+                cls_labels = meta.param_labels.get(cls_name, {})
+                for k, v in cls_labels.items():
+                    if k not in merged_param_labels:
+                        merged_param_labels[k] = v
+
+            # Build description: prefer method docstring, fall back to
+            # readable version of test name (strip bracket suffix)
+            base_method = method_name.split("[")[0]
+            qualified = f"{cls_name}.{base_method}" if cls_name else base_method
+            desc = meta.docstrings.get(qualified, "")
+            if not desc:
+                # Convert test_serialize_basic_sample -> Serialize basic sample
+                readable = base_method.removeprefix("test_").replace("_", " ").capitalize()
+                desc = readable
+
+            # Raw param ID for the top line (e.g. "sqlite-5v")
+            raw_param = bench.get("param") or ""
+            # Human-readable param description for the subtitle
+            params_dict = bench.get("params")
+            params_desc = _format_params(params_dict, merged_param_labels)
+
+            extra = bench.get("extra_info", {})
+            n_samples = extra.get("n_samples")
+            if n_samples is not None:
+                n_samples = int(n_samples)
+
+            rows.append(
+                BenchRow(
+                    name=bench["name"],
+                    fullname=fullname,
+                    description=desc,
+                    raw_params=str(raw_param) if raw_param else "",
+                    params_desc=params_desc,
+                    median=stats["median"],
+                    iqr=stats["iqr"],
+                    ops=stats["ops"],
+                    n_samples=n_samples,
+                )
+            )
+
+        groups[group_key] = BenchGroup(
+            key=group_key,
+            title=title,
+            class_description="; ".join(class_descs) if class_descs else "",
+            param_labels=merged_param_labels,
+            has_samples=any(r.n_samples is not None for r in rows),
+            rows=rows,
+        )
+
+    return list(groups.values())
+
+
+# =============================================================================
+# HTML template
+# =============================================================================
+
+TEMPLATE = jinja2.Template(
+ """\
+
+
+
+
+
+atdata benchmark report
+
+
+
+atdata benchmark report
+Generated from pytest-benchmark JSON output
+
+{% if machine %}
+
+ {{ machine.node }} —
+ {{ machine.cpu.brand_raw }} ({{ machine.cpu.count }} cores) ·
+ Python {{ machine.python_version }} ·
+ {{ machine.system }} {{ machine.release }}
+ {% if commit %}
+ · {{ commit.branch }}@{{ commit.id[:8] }}{% if commit.dirty %} (dirty){% endif %}
+ {% endif %}
+
+{% endif %}
+
+{% for group in groups %}
+
+ {{ group.title }}
+ {% if group.class_description %}
+ {{ group.class_description }}
+ {% endif %}
+ {% if group.param_labels %}
+ Parameters: {% for key, label in group.param_labels.items() %}{{ key }} = {{ label }}{% if not loop.last %}, {% endif %}{% endfor %}
+ {% endif %}
+
+
+
+ Test
+ Median
+ IQR
+ OPS
+ {% if group.has_samples %}Med/sample
+ Samples/s {% endif %}
+
+
+
+ {% for row in group.rows %}
+
+
+ {% if "[" in row.name %}{{ row.name.split("[")[0] }} [{{ row.name.split("[")[1] }} {% else %}{{ row.name }} {% endif %}
+ {{ row.description }}{% if row.params_desc %} [{{ row.params_desc }}]{% endif %}
+
+ {{ fmt_time(row.median) }}
+ {{ fmt_time(row.iqr) }}
+ {{ fmt_ops(row.ops) }}
+ {% if group.has_samples %}{% if row.n_samples %}{{ fmt_time(row.median / row.n_samples) }}{% else %}—{% endif %}
+ {% if row.n_samples %}{{ fmt_ops(row.n_samples / row.median) }}{% else %}—{% endif %} {% endif %}
+
+ {% endfor %}
+
+
+
+{% endfor %}
+
+
+ Report generated by benchmarks/render_report.py from
+ {{ groups | length }} benchmark group{{ "s" if groups | length != 1 }},
+ {{ total_benchmarks }} total benchmarks.
+
+
+
+""",
+ undefined=jinja2.StrictUndefined,
+)
+
+
+# =============================================================================
+# Main
+# =============================================================================
+
+
+def render_html(json_paths: list[Path], bench_dir: Path) -> str:
+    """Render benchmark JSON files into an HTML string.
+
+    Args:
+        json_paths: pytest-benchmark JSON result files, one per group.
+        bench_dir: directory of bench_*.py sources for docstring extraction.
+
+    Returns:
+        The full report as an HTML string.
+    """
+    meta = _extract_source_meta(bench_dir)
+    groups = _build_groups(json_paths, meta)
+
+    # Extract machine/commit info from the first JSON file
+    # (all files come from the same run, so any one suffices).
+    machine = None
+    commit = None
+    for path in json_paths:
+        data = json.loads(path.read_text())
+        if "machine_info" in data:
+            machine = data["machine_info"]
+            commit = data.get("commit_info")
+            break
+
+    total = sum(len(g.rows) for g in groups)
+
+    return TEMPLATE.render(
+        groups=groups,
+        machine=machine,
+        commit=commit,
+        total_benchmarks=total,
+        fmt_time=_format_time,
+        fmt_ops=_format_ops,
+    )
+
+
+def main(argv: list[str] | None = None) -> None:
+    """CLI entry point: parse args, render the report, write HTML.
+
+    Exits with status 1 if none of the given JSON files exist; missing
+    files among several are silently dropped.
+    """
+    parser = argparse.ArgumentParser(
+        description="Render pytest-benchmark JSON into HTML report",
+    )
+    parser.add_argument(
+        "json_files",
+        nargs="+",
+        type=Path,
+        help="One or more pytest-benchmark JSON files",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=Path,
+        default=Path("bench-report.html"),
+        help="Output HTML file (default: bench-report.html)",
+    )
+    parser.add_argument(
+        "--bench-dir",
+        type=Path,
+        default=Path(__file__).parent,
+        help="Directory containing bench_*.py files for docstring extraction",
+    )
+    args = parser.parse_args(argv)
+
+    # Tolerate globs that matched nothing for some groups; require at
+    # least one real file.
+    existing = [p for p in args.json_files if p.exists()]
+    if not existing:
+        print(f"Error: no JSON files found: {args.json_files}", file=sys.stderr)
+        sys.exit(1)
+
+    html = render_html(existing, args.bench_dir)
+    args.output.write_text(html)
+    print(f"Wrote {args.output} ({len(html):,} bytes, from {len(existing)} JSON files)")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/api/AbstractDataStore.html b/docs/api/AbstractDataStore.html
index b4b9393..4ceb07c 100644
--- a/docs/api/AbstractDataStore.html
+++ b/docs/api/AbstractDataStore.html
@@ -71,14 +71,10 @@
-
-
-
+
-
-
-
+
+
+
-
+