Skip to content

Commit 5c9cde8

Browse files
Merge pull request #18 from forecast-bio/release/v0.3.4b1
release: v0.3.4b1
2 parents 16d267d + e32bf0a commit 5c9cde8

16 files changed

Lines changed: 1403 additions & 16 deletions

File tree

.chainlink/issues.db

0 Bytes
Binary file not shown.

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
66

77
## [Unreleased]
88

9+
## [0.3.4b1] - 2026-02-04
10+
11+
### Added
12+
- **Content checksums**: Per-shard SHA-256 digests computed at write time across all storage backends (`LocalDiskStore`, `S3DataStore`, `PDSBlobStore`). Checksums are carried via `ShardWriteResult` and automatically merged into index entry metadata
13+
- **`verify_checksums()`**: Utility function to verify stored checksums against shard files on disk; remote URLs (`s3://`, `at://`, `http://`) are gracefully skipped
14+
- **`atdata verify` CLI command**: Verify content integrity of indexed datasets from the command line
15+
- **AT URI support in `load_dataset()`**: `load_dataset("at://did:plc:abc/.../rkey")` now fetches dataset records from ATProto and resolves storage (blobs, HTTP, S3) into streamable datasets with automatic schema decoding
16+
- **Lens composition operators**: `@` (compose) and `|` (pipe) operators for chaining lenses, plus `identity_lens()` factory for pass-through transforms
17+
918
## [0.3.3b2] - 2026-02-04
1019

1120
### Testing

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "atdata"
3-
version = "0.3.3b2"
3+
version = "0.3.4b1"
44
description = "A loose federation of distributed, typed datasets"
55
readme = "README.md"
66
authors = [

src/atdata/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@
106106
LocalDiskStore as LocalDiskStore,
107107
)
108108

109+
from ._helpers import (
110+
verify_checksums as verify_checksums,
111+
ShardWriteResult as ShardWriteResult,
112+
)
113+
109114
from ._cid import (
110115
generate_cid as generate_cid,
111116
verify_cid as verify_cid,

src/atdata/_helpers.py

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,37 @@
1-
"""Helper utilities for numpy array serialization.
1+
"""Helper utilities for numpy array serialization and content checksums.
22
33
This module provides utility functions for converting numpy arrays to and from
4-
bytes for msgpack serialization.
4+
bytes for msgpack serialization, as well as SHA-256 checksum utilities for
5+
verifying dataset shard integrity.
56
67
Functions:
78
- ``array_to_bytes()``: Serialize numpy array to bytes
89
- ``bytes_to_array()``: Deserialize bytes to numpy array
10+
- ``sha256_file()``: Compute SHA-256 hex digest of a file
11+
- ``sha256_bytes()``: Compute SHA-256 hex digest of in-memory bytes
12+
- ``verify_checksums()``: Verify stored checksums against shard data
913
10-
These helpers are used internally by ``PackableSample`` to enable transparent
11-
handling of NDArray fields during msgpack packing/unpacking.
14+
Classes:
15+
- ``ShardWriteResult``: ``list[str]`` subclass carrying per-shard checksums
1216
"""
1317

18+
from __future__ import annotations
19+
1420
##
1521
# Imports
1622

23+
import hashlib
1724
import struct
1825
from io import BytesIO
26+
from typing import TYPE_CHECKING
1927

2028
import numpy as np
2129

30+
if TYPE_CHECKING:
31+
from pathlib import Path
32+
33+
from atdata._protocols import IndexEntry
34+
2235
# .npy format magic prefix (used for backward-compatible deserialization)
2336
_NPY_MAGIC = b"\x93NUMPY"
2437

@@ -84,3 +97,114 @@ def bytes_to_array(b: bytes) -> np.ndarray:
8497
shape = struct.unpack_from(f"<{ndim}q", b, offset)
8598
offset += ndim * 8
8699
return np.frombuffer(b, dtype=dtype, offset=offset).reshape(shape).copy()
100+
101+
102+
##
103+
# Checksum utilities
104+
105+
106+
def sha256_file(path: str | Path, *, chunk_size: int = 8192) -> str:
    """Return the SHA-256 hex digest of the file at ``path``.

    The file is consumed in fixed-size chunks, so arbitrarily large shard
    files can be hashed with constant memory.

    Args:
        path: Path to the file.
        chunk_size: Read buffer size in bytes.

    Returns:
        Hex-encoded SHA-256 digest string (64 characters).

    Examples:
        >>> digest = sha256_file("/path/to/shard.tar")
        >>> len(digest)
        64
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # iter(callable, sentinel) yields chunks until read() returns b"".
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
129+
130+
131+
def sha256_bytes(data: bytes) -> str:
    """Return the SHA-256 hex digest of in-memory ``data``.

    Args:
        data: Raw bytes to hash.

    Returns:
        Hex-encoded SHA-256 digest string (64 characters).

    Examples:
        >>> sha256_bytes(b"hello")
        '2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824'
    """
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.hexdigest()
145+
146+
147+
class ShardWriteResult(list):
    """List of shard URLs that also carries per-shard SHA-256 checksums.

    Subclassing ``list[str]`` keeps this compatible with the
    ``AbstractDataStore.write_shards()`` return type (``list[str]``) while
    attaching checksum metadata alongside the URLs.

    Attributes:
        checksums: Dict mapping each shard URL to its SHA-256 hex digest.

    Examples:
        >>> result = ShardWriteResult(["shard-0.tar"], {"shard-0.tar": "abcd..."})
        >>> result[0]
        'shard-0.tar'
        >>> result.checksums["shard-0.tar"]
        'abcd...'
    """

    # URL -> SHA-256 hex digest for every shard written
    checksums: dict[str, str]

    def __init__(self, urls: list[str], checksums: dict[str, str]) -> None:
        self.checksums = checksums
        super().__init__(urls)
169+
170+
171+
def verify_checksums(entry: "IndexEntry") -> dict[str, str]:
    """Check stored SHA-256 checksums for every shard of an index entry.

    Stored digests (from ``entry.metadata["checksums"]``) are compared
    against freshly computed ones. A shard with no stored checksum, or one
    living at a remote URL (``s3://``, ``at://``, ``http://``, ``https://``),
    is reported as ``"skipped"`` rather than verified — only local file
    paths can be hashed here.

    Args:
        entry: An IndexEntry with ``data_urls`` and optional metadata checksums.

    Returns:
        Dict mapping each shard URL to one of:
        ``"ok"``, ``"mismatch"``, ``"skipped"``, or ``"error:<message>"``.

    Examples:
        >>> results = verify_checksums(entry)
        >>> assert all(v == "ok" for v in results.values())
    """
    metadata = entry.metadata or {}
    stored: dict[str, str] = metadata.get("checksums", {})

    remote_prefixes = ("s3://", "at://", "http://", "https://")
    results: dict[str, str] = {}
    for url in entry.data_urls:
        if url not in stored:
            # Nothing recorded for this shard — nothing to compare against.
            results[url] = "skipped"
        elif url.startswith(remote_prefixes):
            # Remote shards cannot be hashed locally.
            results[url] = "skipped"
        else:
            try:
                actual = sha256_file(url)
            except Exception as e:
                # Unreadable/missing file: surface the reason, keep going.
                results[url] = f"error:{e}"
            else:
                results[url] = "ok" if actual == stored[url] else "mismatch"
    return results

src/atdata/_hf_api.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
if TYPE_CHECKING:
5050
from ._protocols import AbstractIndex
51+
from .atmosphere.client import Atmosphere
5152

5253
##
5354
# Type variables
@@ -478,6 +479,118 @@ def _group_shards_by_split(shards: list[str]) -> dict[str, list[str]]:
478479
# Index-based path resolution
479480

480481

482+
def _is_at_uri(path: str) -> bool:
483+
"""Check if path is an AT Protocol URI (at://...).
484+
485+
Examples:
486+
>>> _is_at_uri("at://did:plc:abc123/ac.foundation.dataset.record/my-ds")
487+
True
488+
>>> _is_at_uri("@local/my-dataset")
489+
False
490+
"""
491+
return path.startswith("at://")
492+
493+
494+
def _resolve_at_uri(
    path: str,
    sample_type: Type[ST] | None = None,
    client: "Atmosphere | None" = None,
) -> tuple[Dataset, Type]:
    """Resolve an AT URI to a Dataset by fetching the record from ATProto.

    Fetches the dataset record once, determines storage type (blobs, HTTP, S3),
    resolves shard URLs, and optionally decodes the schema to reconstruct
    the sample type.

    Args:
        path: AT URI pointing to a dataset record.
        sample_type: Optional sample type class. If None, the schema is
            decoded from the referenced schema record.
        client: Optional Atmosphere client. If None, an unauthenticated
            client is created for public record access.

    Returns:
        Tuple of (Dataset, resolved_type).

    Raises:
        ValueError: If the record is not a dataset record, has no
            resolvable storage, or uses an unknown storage type.
    """
    # Imported lazily — presumably to avoid import cycles at module load;
    # TODO confirm against the package's import graph.
    from .atmosphere.client import Atmosphere
    from .atmosphere._types import AtUri, LEXICON_NAMESPACE
    from ._sources import BlobSource

    # No client supplied: fall back to an unauthenticated client, which is
    # sufficient for fetching public records.
    if client is None:
        client = Atmosphere()

    # Single fetch — all routing derived from this dict
    record = client.get_record(path)
    expected_type = f"{LEXICON_NAMESPACE}.record"
    if record.get("$type") != expected_type:
        raise ValueError(
            f"Record at {path} is not a dataset record. "
            f"Expected $type='{expected_type}', got '{record.get('$type')}'"
        )

    storage = record.get("storage", {})
    storage_type = storage.get("$type", "")

    # Storage variants are matched by substring of the $type string
    # (e.g. "...storageBlobs"), not by exact equality.
    if "storageBlobs" in storage_type:
        # Blob storage: shards live in the repo owner's PDS, addressed by
        # (did, cid) pairs extracted from each blob ref.
        parsed = AtUri.parse(path)
        did = parsed.authority
        refs = []
        for entry in storage.get("blobs", []):
            # Each entry may be a {"blob": {...}} wrapper or the blob dict
            # itself; entry.get("blob", entry) accepts both shapes.
            blob = entry.get("blob", entry)
            ref = blob.get("ref", {})
            # Ref may be a dict ({"$link": cid}) or an object stringifiable
            # to the CID.
            cid = ref.get("$link") if isinstance(ref, dict) else str(ref)
            if cid:
                refs.append({"did": did, "cid": cid})
        # NOTE(review): unlike the other branches, an empty ``refs`` list
        # does not raise here — confirm whether that is intentional.
        # Uses a private client method to locate the PDS for this DID.
        pds_endpoint = client._resolve_pds_endpoint(did)
        source: DataSource = BlobSource(blob_refs=refs, pds_endpoint=pds_endpoint)
    elif "storageHttp" in storage_type:
        urls = [s["url"] for s in storage.get("shards", [])]
        if not urls:
            raise ValueError(f"Dataset record at {path} has no storage URLs")
        source = URLSource(_shards_to_wds_url(urls))
    elif "storageS3" in storage_type:
        bucket = storage.get("bucket", "")
        endpoint = storage.get("endpoint")
        urls = []
        for s in storage.get("shards", []):
            if endpoint:
                # Custom endpoint present: build an HTTP(S)-style URL.
                urls.append(f"{endpoint.rstrip('/')}/{bucket}/{s['key']}")
            else:
                # No endpoint: emit a plain s3:// URL.
                urls.append(f"s3://{bucket}/{s['key']}")
        if not urls:
            raise ValueError(f"Dataset record at {path} has no storage URLs")
        source = URLSource(_shards_to_wds_url(urls))
    elif "storageExternal" in storage_type:
        urls = storage.get("urls", [])
        if not urls:
            raise ValueError(f"Dataset record at {path} has no storage URLs")
        source = URLSource(_shards_to_wds_url(urls))
    else:
        raise ValueError(f"Unknown storage type in dataset record: {storage_type}")

    # Resolve sample type from the already-fetched record
    if sample_type is None:
        schema_ref = record.get("schemaRef")
        if schema_ref:
            # Deferred imports again; only needed when decoding a schema.
            from .atmosphere.schema import SchemaLoader
            from ._schema_codec import schema_to_type

            schema_loader = SchemaLoader(client)
            schema_record = schema_loader.get(schema_ref)
            resolved_type = schema_to_type(schema_record)
        else:
            # No schema reference: fall back to untyped dict samples.
            resolved_type = DictSample
    else:
        resolved_type = sample_type

    ds = Dataset[resolved_type](source)
    return ds, resolved_type
592+
593+
481594
def _is_indexed_path(path: str) -> bool:
482595
"""Check if path uses @handle/dataset notation for index lookup.
483596
@@ -680,6 +793,7 @@ def load_dataset(
680793
681794
Args:
682795
path: Path to dataset. Can be:
796+
- AT URI: "at://did:plc:abc/ac.foundation.dataset.record/rkey"
683797
- Index lookup: "@handle/dataset-name" or "@local/dataset-name"
684798
- WebDataset brace notation: "path/to/{train,test}-{000..099}.tar"
685799
- Local directory: "./data/" (scans for .tar files)
@@ -746,6 +860,18 @@ def load_dataset(
746860
sample_type.__name__ if sample_type is not None else "None",
747861
)
748862

863+
# Handle at:// AT Protocol URI resolution
864+
if _is_at_uri(path):
865+
log.debug("load_dataset: resolving AT URI %s", path)
866+
ds, resolved_type = _resolve_at_uri(path, sample_type)
867+
868+
if split is not None:
869+
return ds
870+
871+
return DatasetDict(
872+
{"train": ds}, sample_type=resolved_type, streaming=streaming
873+
)
874+
749875
# Handle @handle/dataset indexed path resolution
750876
if _is_indexed_path(path):
751877
if index is None:

0 commit comments

Comments
 (0)