|
48 | 48 |
|
49 | 49 | if TYPE_CHECKING: |
50 | 50 | from ._protocols import AbstractIndex |
| 51 | + from .atmosphere.client import Atmosphere |
51 | 52 |
|
52 | 53 | ## |
53 | 54 | # Type variables |
@@ -478,6 +479,118 @@ def _group_shards_by_split(shards: list[str]) -> dict[str, list[str]]: |
478 | 479 | # Index-based path resolution |
479 | 480 |
|
480 | 481 |
|
| 482 | +def _is_at_uri(path: str) -> bool: |
| 483 | + """Check if path is an AT Protocol URI (at://...). |
| 484 | +
|
| 485 | + Examples: |
| 486 | + >>> _is_at_uri("at://did:plc:abc123/ac.foundation.dataset.record/my-ds") |
| 487 | + True |
| 488 | + >>> _is_at_uri("@local/my-dataset") |
| 489 | + False |
| 490 | + """ |
| 491 | + return path.startswith("at://") |
| 492 | + |
| 493 | + |
def _resolve_at_uri(
    path: str,
    sample_type: Type[ST] | None = None,
    client: "Atmosphere | None" = None,
) -> tuple[Dataset, Type]:
    """Resolve an AT URI to a Dataset by fetching the record from ATProto.

    Fetches the dataset record once, determines the storage type from the
    record's ``storage.$type`` (blobs, HTTP, S3, or external URLs), builds the
    matching data source, and optionally decodes the referenced schema to
    reconstruct the sample type.

    Args:
        path: AT URI pointing to a dataset record.
        sample_type: Optional sample type class. If None, the type is built
            from the record's ``schemaRef``; when the record has no
            ``schemaRef`` the result falls back to ``DictSample``.
        client: Optional Atmosphere client. If None, a default
            ``Atmosphere()`` client is constructed (intended for public,
            unauthenticated record access).

    Returns:
        Tuple of (Dataset, resolved_type).

    Raises:
        ValueError: If the record is not a dataset record, has no
            resolvable storage (no blob references or no shard URLs,
            including a missing or null ``storage`` field), or uses an
            unknown storage type.
    """
    from .atmosphere.client import Atmosphere
    from .atmosphere._types import AtUri, LEXICON_NAMESPACE
    from ._sources import BlobSource

    if client is None:
        client = Atmosphere()

    # Single fetch — all routing below is derived from this dict.
    record = client.get_record(path)
    expected_type = f"{LEXICON_NAMESPACE}.record"
    record_type = record.get("$type")
    if record_type != expected_type:
        raise ValueError(
            f"Record at {path} is not a dataset record. "
            f"Expected $type='{expected_type}', got '{record_type}'"
        )

    # ``or`` (rather than a .get default) also covers a field that is present
    # but null, so a malformed record ends in the documented ValueError below
    # instead of an AttributeError/TypeError.
    storage = record.get("storage") or {}
    storage_type = storage.get("$type") or ""

    if "storageBlobs" in storage_type:
        did = AtUri.parse(path).authority
        refs = []
        for entry in storage.get("blobs", []):
            # Entries may wrap the blob under a "blob" key or be the blob itself.
            blob = entry.get("blob", entry)
            ref = blob.get("ref", {})
            # The ref is either a {"$link": cid} mapping or some other object
            # whose string form is used as the CID.
            cid = ref.get("$link") if isinstance(ref, dict) else str(ref)
            if cid:
                refs.append({"did": did, "cid": cid})
        # Fail before resolving the PDS endpoint: with no usable blob
        # references there is nothing to read from it.
        if not refs:
            raise ValueError(f"Dataset record at {path} has no blob references")
        pds_endpoint = client._resolve_pds_endpoint(did)
        source: DataSource = BlobSource(blob_refs=refs, pds_endpoint=pds_endpoint)
    else:
        # Every remaining storage kind reduces to a list of shard URLs.
        if "storageHttp" in storage_type:
            urls = [shard["url"] for shard in storage.get("shards", [])]
        elif "storageS3" in storage_type:
            bucket = storage.get("bucket", "")
            endpoint = storage.get("endpoint")
            if endpoint:
                # Custom endpoint: address objects as <endpoint>/<bucket>/<key>.
                base = f"{endpoint.rstrip('/')}/{bucket}"
            else:
                base = f"s3://{bucket}"
            urls = [f"{base}/{shard['key']}" for shard in storage.get("shards", [])]
        elif "storageExternal" in storage_type:
            urls = storage.get("urls", [])
        else:
            raise ValueError(
                f"Unknown storage type in dataset record: {storage_type}"
            )

        if not urls:
            raise ValueError(f"Dataset record at {path} has no storage URLs")
        source = URLSource(_shards_to_wds_url(urls))

    # Resolve the sample type from the already-fetched record.
    if sample_type is not None:
        resolved_type = sample_type
    else:
        schema_ref = record.get("schemaRef")
        if schema_ref:
            from .atmosphere.schema import SchemaLoader
            from ._schema_codec import schema_to_type

            schema_record = SchemaLoader(client).get(schema_ref)
            resolved_type = schema_to_type(schema_record)
        else:
            # No schema reference on the record: fall back to DictSample.
            resolved_type = DictSample

    ds = Dataset[resolved_type](source)
    return ds, resolved_type
| 592 | + |
| 593 | + |
481 | 594 | def _is_indexed_path(path: str) -> bool: |
482 | 595 | """Check if path uses @handle/dataset notation for index lookup. |
483 | 596 |
|
@@ -680,6 +793,7 @@ def load_dataset( |
680 | 793 |
|
681 | 794 | Args: |
682 | 795 | path: Path to dataset. Can be: |
| 796 | + - AT URI: "at://did:plc:abc/ac.foundation.dataset.record/rkey" |
683 | 797 | - Index lookup: "@handle/dataset-name" or "@local/dataset-name" |
684 | 798 | - WebDataset brace notation: "path/to/{train,test}-{000..099}.tar" |
685 | 799 | - Local directory: "./data/" (scans for .tar files) |
@@ -746,6 +860,18 @@ def load_dataset( |
746 | 860 | sample_type.__name__ if sample_type is not None else "None", |
747 | 861 | ) |
748 | 862 |
|
| 863 | + # Handle at:// AT Protocol URI resolution |
| 864 | + if _is_at_uri(path): |
| 865 | + log.debug("load_dataset: resolving AT URI %s", path) |
| 866 | + ds, resolved_type = _resolve_at_uri(path, sample_type) |
| 867 | + |
| 868 | + if split is not None: |
| 869 | + return ds |
| 870 | + |
| 871 | + return DatasetDict( |
| 872 | + {"train": ds}, sample_type=resolved_type, streaming=streaming |
| 873 | + ) |
| 874 | + |
749 | 875 | # Handle @handle/dataset indexed path resolution |
750 | 876 | if _is_indexed_path(path): |
751 | 877 | if index is None: |
|
0 commit comments