From 25823eddf9e713a5945c37ce4a66f99f8ace9c61 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <220467675+maxine-at-forecast@users.noreply.github.com> Date: Wed, 1 Apr 2026 15:12:18 -0700 Subject: [PATCH 1/2] refactor: phases 3-5 architectural refactoring (#84) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 3 — Internal delegation + AppView + temporal coupling: - Move insert_dataset/write_samples logic from Index to Repository - Extract with_appview_fallback() helper to atmosphere/_appview.py - Switch _AtmosphereBackend to eager init, remove _ensure_loaders() Phase 4 — Provider fixes: - Fix Redis label resolution with created_at timestamp ordering - Fix Redis N+1 query in iter_entries using pipeline batch fetch - Add SQLite indexes for schema (name,version) and labels (name,created_at) Phase 5 — Atmosphere namespacing + cleanup: - Add RecordOps, BlobOps, XrpcClient namespace classes on Atmosphere - Switch all 18 DeprecationWarning sites to FutureWarning (visible by default) - Fix thread-safety of _atproto_client_class with double-checked locking - Split StubManager into SchemaStubManager/LensStubManager with _StubWriter base All 1735 tests pass. 
Co-Authored-By: Claude Opus 4.6 --- src/atdata/__init__.py | 4 +- src/atdata/_schema_codec.py | 2 +- src/atdata/_sources.py | 4 +- src/atdata/_stub_manager.py | 439 ++++++++++----------- src/atdata/atmosphere/__init__.py | 4 +- src/atdata/atmosphere/_appview.py | 56 +++ src/atdata/atmosphere/_types.py | 2 +- src/atdata/atmosphere/client.py | 167 +++++++- src/atdata/atmosphere/labels.py | 18 +- src/atdata/atmosphere/lens.py | 44 +-- src/atdata/atmosphere/records.py | 95 ++--- src/atdata/atmosphere/schema.py | 38 +- src/atdata/dataset.py | 2 +- src/atdata/index/_index.py | 135 +------ src/atdata/local/_repo_legacy.py | 2 +- src/atdata/providers/_redis.py | 42 +- src/atdata/providers/_sqlite.py | 6 + src/atdata/repository.py | 227 +++++++++-- tests/test_atmosphere.py | 8 +- tests/test_atmosphere_label_integration.py | 4 +- tests/test_dataset.py | 2 +- tests/test_index_write.py | 10 +- tests/test_local.py | 44 +-- tests/test_protocols.py | 2 +- tests/test_repository.py | 2 +- tests/test_repository_coverage.py | 40 +- tests/test_sources.py | 10 +- tests/test_workflow_atmosphere.py | 2 +- tests/test_workflow_cross_backend.py | 2 +- tests/test_workflow_local.py | 18 +- 30 files changed, 816 insertions(+), 615 deletions(-) create mode 100644 src/atdata/atmosphere/_appview.py diff --git a/src/atdata/__init__.py b/src/atdata/__init__.py index 511c006..6e306ae 100644 --- a/src/atdata/__init__.py +++ b/src/atdata/__init__.py @@ -155,7 +155,7 @@ def __getattr__(name: str): warnings.warn( "atdata.AbstractIndex is deprecated. 
Use atdata.Index directly " "as the type annotation instead.", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) from ._protocols import AbstractIndex @@ -174,7 +174,7 @@ def schema_to_type(schema: dict, *, use_cache: bool = True): warnings.warn( "atdata.schema_to_type() is deprecated, use index.get_schema_type() instead", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return _schema_to_type(schema, use_cache=use_cache) diff --git a/src/atdata/_schema_codec.py b/src/atdata/_schema_codec.py index 2ae8738..b3829a2 100644 --- a/src/atdata/_schema_codec.py +++ b/src/atdata/_schema_codec.py @@ -684,7 +684,7 @@ def schema_to_type( warnings.warn( "schema_to_type() is deprecated, use index.get_schema_type() instead", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return _schema_to_type(schema, use_cache=use_cache) diff --git a/src/atdata/_sources.py b/src/atdata/_sources.py index 024e852..b268e9b 100644 --- a/src/atdata/_sources.py +++ b/src/atdata/_sources.py @@ -75,7 +75,7 @@ def shard_list(self) -> list[str]: warnings.warn( "shard_list is deprecated, use list_shards()", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return self.list_shards() @@ -199,7 +199,7 @@ def shard_list(self) -> list[str]: warnings.warn( "shard_list is deprecated, use list_shards()", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return self.list_shards() diff --git a/src/atdata/_stub_manager.py b/src/atdata/_stub_manager.py index 28a443d..cc01eab 100644 --- a/src/atdata/_stub_manager.py +++ b/src/atdata/_stub_manager.py @@ -71,53 +71,17 @@ def _extract_authority(schema_ref: Optional[str]) -> str: return DEFAULT_AUTHORITY -class StubManager: - """Manages automatic generation of Python modules for decoded schemas. 
- - The StubManager handles: - - Determining module file paths from schema metadata - - Checking if modules exist and are current - - Generating modules atomically (write to temp, rename) - - Creating __init__.py files for proper package structure - - Importing classes from generated modules - - Cleaning up old modules - - Modules are organized by authority (from the schema ref URI) to avoid - collisions between schemas with the same name from different sources:: - - ~/.atdata/stubs/ - __init__.py - local/ - __init__.py - MySample_1_0_0.py - alice.bsky.social/ - __init__.py - MySample_1_0_0.py - did_plc_abc123/ - __init__.py - OtherSample_2_0_0.py +class _StubWriter: + """Shared infrastructure for writing stub modules atomically. - Args: - stub_dir: Directory to write module files. Defaults to ``~/.atdata/stubs/``. - - Examples: - >>> manager = StubManager() - >>> schema_dict = {"name": "MySample", "version": "1.0.0", "fields": [...]} - >>> SampleClass = manager.ensure_module(schema_dict) - >>> print(manager.stub_dir) - /Users/you/.atdata/stubs + Handles directory creation, ``__init__.py`` maintenance, and + atomic file writes via temp-file-then-rename. Subclasses provide + the filename convention and content generation. 
""" - def __init__(self, stub_dir: Optional[Union[str, Path]] = None): - if stub_dir is None: - self._stub_dir = DEFAULT_STUB_DIR - else: - self._stub_dir = Path(stub_dir) - + def __init__(self, stub_dir: Path) -> None: + self._stub_dir = stub_dir self._initialized = False - self._first_generation = True - # Cache of imported classes: (authority, name, version) -> class - self._class_cache: dict[tuple[str, str, str], Type] = {} @property def stub_dir(self) -> Path: @@ -128,69 +92,11 @@ def _ensure_dir_exists(self) -> None: """Create stub directory with __init__.py if it doesn't exist.""" if not self._initialized: self._stub_dir.mkdir(parents=True, exist_ok=True) - # Create root __init__.py init_path = self._stub_dir / "__init__.py" if not init_path.exists(): init_path.write_text('"""Auto-generated atdata schema modules."""\n') self._initialized = True - def _module_filename(self, name: str, version: str) -> str: - """Generate module filename from schema name and version. - - Replaces dots in version with underscores to avoid confusion - with file extensions. - - Args: - name: Schema name (e.g., "MySample") - version: Schema version (e.g., "1.0.0") - - Returns: - Filename like "MySample_1_0_0.py" - """ - safe_version = version.replace(".", "_") - return f"{name}_{safe_version}.py" - - def _module_path( - self, name: str, version: str, authority: str = DEFAULT_AUTHORITY - ) -> Path: - """Get full path to module file for a schema. - - Args: - name: Schema name - version: Schema version - authority: Authority from schema ref (e.g., "local", "alice.bsky.social") - - Returns: - Path like ~/.atdata/stubs/local/MySample_1_0_0.py - """ - return self._stub_dir / authority / self._module_filename(name, version) - - def _module_is_current(self, path: Path, version: str) -> bool: - """Check if an existing module file matches the expected version. - - Reads the module docstring to extract the version and compares - it to the expected version. 
- - Args: - path: Path to the module file - version: Expected schema version - - Returns: - True if module exists and version matches - """ - if not path.exists(): - return False - - try: - with open(path, "r", encoding="utf-8") as f: - content = f.read(500) # Read first 500 chars for docstring - match = _VERSION_PATTERN.search(content) - if match: - return match.group(1) == version - return False - except (OSError, IOError): - return False - def _ensure_authority_package(self, authority: str) -> None: """Ensure authority subdirectory exists with __init__.py.""" self._ensure_dir_exists() @@ -218,55 +124,122 @@ def _write_module_atomic(self, path: Path, content: str, authority: str) -> None # Create temp file in same directory for atomic rename fd, temp_path = tempfile.mkstemp( suffix=".py.tmp", - dir=path.parent, # Use parent dir (authority subdir) for atomic rename + dir=path.parent, ) temp_path = Path(temp_path) try: with os.fdopen(fd, "w", encoding="utf-8") as f: - # Try to get exclusive lock (non-blocking, ignore if unavailable) - # File locking is best-effort - not all filesystems support it try: fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) except (OSError, IOError): - # Lock unavailable (NFS, Windows, etc.) - proceed without lock - # Atomic rename provides the real protection pass f.write(content) f.flush() os.fsync(f.fileno()) - # Atomic rename (on POSIX systems) temp_path.rename(path) except Exception: - # Clean up temp file on error - best effort, ignore failures try: temp_path.unlink() except OSError: - pass # Temp file cleanup failed, re-raising original error + pass raise - def ensure_stub(self, schema: dict) -> Optional[Path]: - """Ensure a module file exists for the given schema. 
+ def _clean_empty_dirs(self) -> None: + """Remove empty authority directories (including lone __init__.py).""" + if not self._stub_dir.exists(): + return + for subdir in self._stub_dir.iterdir(): + if subdir.is_dir(): + contents = list(subdir.iterdir()) + if len(contents) == 0: + try: + subdir.rmdir() + except OSError: + continue + elif len(contents) == 1 and contents[0].name == "__init__.py": + try: + contents[0].unlink() + subdir.rmdir() + except OSError: + continue + + +class SchemaStubManager(_StubWriter): + """Manages automatic generation of Python modules for decoded schemas. + + Modules are organised by authority (from the schema ref URI) to avoid + collisions between schemas with the same name from different sources:: + + ~/.atdata/stubs/ + __init__.py + local/ + __init__.py + MySample_1_0_0.py + alice.bsky.social/ + __init__.py + MySample_1_0_0.py + + Args: + stub_dir: Directory to write module files. Defaults to ``~/.atdata/stubs/``. + + Examples: + >>> manager = SchemaStubManager() + >>> schema_dict = {"name": "MySample", "version": "1.0.0", "fields": [...]} + >>> SampleClass = manager.ensure_module(schema_dict) + """ + + def __init__(self, stub_dir: Optional[Union[str, Path]] = None): + super().__init__(Path(stub_dir) if stub_dir else DEFAULT_STUB_DIR) + self._first_generation = True + self._class_cache: dict[tuple[str, str, str], Type] = {} + + def _module_filename(self, name: str, version: str) -> str: + """Generate module filename from schema name and version. 
+ + Args: + name: Schema name (e.g., "MySample") + version: Schema version (e.g., "1.0.0") + + Returns: + Filename like "MySample_1_0_0.py" + """ + safe_version = version.replace(".", "_") + return f"{name}_{safe_version}.py" + + def _module_path( + self, name: str, version: str, authority: str = DEFAULT_AUTHORITY + ) -> Path: + """Get full path to module file for a schema.""" + return self._stub_dir / authority / self._module_filename(name, version) + + def _module_is_current(self, path: Path, version: str) -> bool: + """Check if an existing module file matches the expected version.""" + if not path.exists(): + return False - If a current module already exists, returns its path without - regenerating. Otherwise, generates the module and writes it. + try: + with open(path, "r", encoding="utf-8") as f: + content = f.read(500) + match = _VERSION_PATTERN.search(content) + if match: + return match.group(1) == version + return False + except (OSError, IOError): + return False - Modules are namespaced by the authority from the schema's $ref URI - to avoid collisions between schemas with the same name from - different sources. + def ensure_stub(self, schema: dict) -> Optional[Path]: + """Ensure a module file exists for the given schema. Args: schema: Schema dict with 'name', 'version', and 'fields' keys. - Can also be a LocalSchemaRecord (supports dict-style access). - Should include '$ref' for proper namespacing. Returns: Path to the module file, or None if schema is missing required fields. 
""" - # Extract schema metadata (works with dict or LocalSchemaRecord) name = schema.get("name") if hasattr(schema, "get") else None version = schema.get("version", "1.0.0") if hasattr(schema, "get") else "1.0.0" schema_ref = schema.get("$ref") if hasattr(schema, "get") else None @@ -274,16 +247,12 @@ def ensure_stub(self, schema: dict) -> Optional[Path]: if not name: return None - # Extract authority from schema ref for namespacing authority = _extract_authority(schema_ref) path = self._module_path(name, version, authority) - # Skip if current module exists if self._module_is_current(path, version): return path - # Generate and write module - # Convert to dict if needed for generate_module if hasattr(schema, "to_dict"): schema_dict = schema.to_dict() else: @@ -292,7 +261,6 @@ def ensure_stub(self, schema: dict) -> Optional[Path]: content = generate_module(schema_dict) self._write_module_atomic(path, content, authority) - # Print helpful message on first generation if self._first_generation: self._first_generation = False self._print_ide_hint() @@ -302,20 +270,12 @@ def ensure_stub(self, schema: dict) -> Optional[Path]: def ensure_module(self, schema: dict) -> Optional[Type]: """Ensure a module exists and return the class from it. - This is the primary method for getting a properly-typed class from - a schema. It generates the module if needed, imports the class, - and returns it with proper type information. - Args: schema: Schema dict with 'name', 'version', and 'fields' keys. - Can also be a LocalSchemaRecord (supports dict-style access). - Should include '$ref' for proper namespacing. Returns: - The PackableSample subclass from the generated module, or None - if schema is missing required fields. + The PackableSample subclass from the generated module, or None. 
""" - # Extract schema metadata name = schema.get("name") if hasattr(schema, "get") else None version = schema.get("version", "1.0.0") if hasattr(schema, "get") else "1.0.0" schema_ref = schema.get("$ref") if hasattr(schema, "get") else None @@ -325,17 +285,14 @@ def ensure_module(self, schema: dict) -> Optional[Type]: authority = _extract_authority(schema_ref) - # Check cache first cache_key = (authority, name, version) if cache_key in self._class_cache: return self._class_cache[cache_key] - # Ensure module exists path = self.ensure_stub(schema) if path is None: return None - # Import and cache the class cls = self._import_class_from_module(path, name) if cls is not None: self._class_cache[cache_key] = cls @@ -345,40 +302,20 @@ def ensure_module(self, schema: dict) -> Optional[Type]: def _import_class_from_module( self, module_path: Path, class_name: str ) -> Optional[Type]: - """Import a class from a generated module file. - - Uses importlib to dynamically load the module and extract the class. 
- - Args: - module_path: Path to the .py module file - class_name: Name of the class to import - - Returns: - The imported class, or None if import fails - """ + """Import a class from a generated module file.""" if not module_path.exists(): return None try: - # Create a unique module name based on the path module_name = f"_atdata_generated_{module_path.stem}" - - # Load the module spec spec = importlib.util.spec_from_file_location(module_name, module_path) if spec is None or spec.loader is None: return None - - # Create and execute the module module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module spec.loader.exec_module(module) - - # Get the class from the module - cls = getattr(module, class_name, None) - return cls - + return getattr(module, class_name, None) except (ModuleNotFoundError, AttributeError, ImportError, OSError): - # Import failed - return None and let caller fall back to dynamic generation return None def _print_ide_hint(self) -> None: @@ -397,52 +334,25 @@ def _print_ide_hint(self) -> None: def get_stub_path( self, name: str, version: str, authority: str = DEFAULT_AUTHORITY ) -> Optional[Path]: - """Get the path to an existing stub file. - - Args: - name: Schema name - version: Schema version - authority: Authority namespace (default: "local") - - Returns: - Path if stub exists, None otherwise - """ + """Get the path to an existing stub file.""" path = self._module_path(name, version, authority) return path if path.exists() else None def list_stubs(self, authority: Optional[str] = None) -> list[Path]: - """List all module files in the stub directory. - - Args: - authority: If provided, only list modules for this authority. - If None, lists all modules across all authorities. 
- - Returns: - List of paths to existing module files (excludes __init__.py) - """ + """List all schema module files in the stub directory.""" if not self._stub_dir.exists(): return [] if authority: - # List modules for specific authority authority_dir = self._stub_dir / authority if not authority_dir.exists(): return [] return [p for p in authority_dir.glob("*.py") if p.name != "__init__.py"] - # List all modules across all authorities (recursive, excluding __init__.py) return [p for p in self._stub_dir.glob("**/*.py") if p.name != "__init__.py"] def clear_stubs(self, authority: Optional[str] = None) -> int: - """Remove module files from the stub directory. - - Args: - authority: If provided, only clear modules for this authority. - If None, clears all modules across all authorities. - - Returns: - Number of files removed - """ + """Remove schema module files from the stub directory.""" stubs = self.list_stubs(authority) removed = 0 for path in stubs: @@ -450,10 +360,8 @@ def clear_stubs(self, authority: Optional[str] = None) -> int: path.unlink() removed += 1 except OSError: - # File already removed or permission denied - skip and continue continue - # Clear the class cache for removed modules if authority: keys_to_remove = [k for k in self._class_cache if k[0] == authority] else: @@ -461,32 +369,45 @@ def clear_stubs(self, authority: Optional[str] = None) -> int: for key in keys_to_remove: del self._class_cache[key] - # Clean up empty authority directories (including __init__.py) - if self._stub_dir.exists(): - for subdir in self._stub_dir.iterdir(): - if subdir.is_dir(): - # Check if only __init__.py remains - contents = list(subdir.iterdir()) - if len(contents) == 0: - try: - subdir.rmdir() - except OSError: - continue - elif len(contents) == 1 and contents[0].name == "__init__.py": - try: - contents[0].unlink() - subdir.rmdir() - except OSError: - continue - + self._clean_empty_dirs() return removed - # 
------------------------------------------------------------------ - # Lens stub management - # ------------------------------------------------------------------ + def clear_stub( + self, name: str, version: str, authority: str = DEFAULT_AUTHORITY + ) -> bool: + """Remove a specific schema module file.""" + path = self._module_path(name, version, authority) + if path.exists(): + try: + path.unlink() + cache_key = (authority, name, version) + if cache_key in self._class_cache: + del self._class_cache[cache_key] + return True + except OSError: + return False + return False + + +class LensStubManager(_StubWriter): + """Manages automatic generation of Python modules for decoded lenses. + + Lens stubs follow the same directory layout as schema stubs but use + a ``lens_`` filename prefix to avoid collisions. + + Args: + stub_dir: Directory to write module files. Defaults to ``~/.atdata/stubs/``. + + Examples: + >>> manager = LensStubManager() + >>> path = manager.ensure_lens_stub({"name": "my_lens", "version": "1.0.0"}) + """ + + def __init__(self, stub_dir: Optional[Union[str, Path]] = None): + super().__init__(Path(stub_dir) if stub_dir else DEFAULT_STUB_DIR) def _lens_module_filename(self, name: str, version: str) -> str: - """Generate lens module filename from lens name and version. + """Generate lens module filename. Args: name: Lens name (e.g., "image_to_grayscale") @@ -522,7 +443,6 @@ def ensure_lens_stub(self, record: dict) -> Optional[Path]: authority = DEFAULT_AUTHORITY path = self._lens_module_path(name, version, authority) - # Skip if file already exists if path.exists(): return path @@ -532,14 +452,7 @@ def ensure_lens_stub(self, record: dict) -> Optional[Path]: return path def list_lens_stubs(self, authority: Optional[str] = None) -> list[Path]: - """List all lens stub files in the stub directory. - - Args: - authority: If provided, only list lens stubs for this authority. - - Returns: - List of paths to lens stub files. 
- """ + """List all lens stub files in the stub directory.""" if not self._stub_dir.exists(): return [] @@ -554,35 +467,79 @@ def list_lens_stubs(self, authority: Optional[str] = None) -> list[Path]: p for p in self._stub_dir.glob(f"**/{pattern}") if p.name != "__init__.py" ] + +class StubManager: + """Backward-compatible facade composing SchemaStubManager and LensStubManager. + + New code should use ``SchemaStubManager`` or ``LensStubManager`` directly. + + Args: + stub_dir: Directory to write module files. Defaults to ``~/.atdata/stubs/``. + + Examples: + >>> manager = StubManager() + >>> schema_dict = {"name": "MySample", "version": "1.0.0", "fields": [...]} + >>> SampleClass = manager.ensure_module(schema_dict) + >>> print(manager.stub_dir) + /Users/you/.atdata/stubs + """ + + def __init__(self, stub_dir: Optional[Union[str, Path]] = None): + self._schemas = SchemaStubManager(stub_dir) + self._lenses = LensStubManager(stub_dir) + + def __getattr__(self, name: str): + """Delegate attribute access to the schema sub-manager for backward compat.""" + return getattr(self._schemas, name) + + @property + def stub_dir(self) -> Path: + """The directory where module files are written.""" + return self._schemas.stub_dir + + # Schema delegation + def ensure_stub(self, schema: dict) -> Optional[Path]: + """Ensure a schema module file exists. See :meth:`SchemaStubManager.ensure_stub`.""" + return self._schemas.ensure_stub(schema) + + def ensure_module(self, schema: dict) -> Optional[Type]: + """Ensure a schema module exists and return the class. 
See :meth:`SchemaStubManager.ensure_module`.""" + return self._schemas.ensure_module(schema) + + def get_stub_path( + self, name: str, version: str, authority: str = DEFAULT_AUTHORITY + ) -> Optional[Path]: + """Get path to an existing schema stub.""" + return self._schemas.get_stub_path(name, version, authority) + + def list_stubs(self, authority: Optional[str] = None) -> list[Path]: + """List all schema module files.""" + return self._schemas.list_stubs(authority) + + def clear_stubs(self, authority: Optional[str] = None) -> int: + """Remove schema module files.""" + return self._schemas.clear_stubs(authority) + def clear_stub( self, name: str, version: str, authority: str = DEFAULT_AUTHORITY ) -> bool: - """Remove a specific module file. + """Remove a specific schema module file.""" + return self._schemas.clear_stub(name, version, authority) - Args: - name: Schema name - version: Schema version - authority: Authority namespace (default: "local") + # Lens delegation + def ensure_lens_stub(self, record: dict) -> Optional[Path]: + """Ensure a lens stub file exists. 
See :meth:`LensStubManager.ensure_lens_stub`.""" + return self._lenses.ensure_lens_stub(record) - Returns: - True if file was removed, False if it didn't exist - """ - path = self._module_path(name, version, authority) - if path.exists(): - try: - path.unlink() - # Clear from class cache - cache_key = (authority, name, version) - if cache_key in self._class_cache: - del self._class_cache[cache_key] - return True - except OSError: - return False - return False + def list_lens_stubs(self, authority: Optional[str] = None) -> list[Path]: + """List all lens stub files.""" + return self._lenses.list_lens_stubs(authority) __all__ = [ "StubManager", + "SchemaStubManager", + "LensStubManager", "DEFAULT_STUB_DIR", "DEFAULT_AUTHORITY", ] diff --git a/src/atdata/atmosphere/__init__.py b/src/atdata/atmosphere/__init__.py index ba59604..23101b7 100644 --- a/src/atdata/atmosphere/__init__.py +++ b/src/atdata/atmosphere/__init__.py @@ -218,7 +218,7 @@ def __init__( warnings.warn( "AtmosphereIndex is deprecated. Use atdata.Index(atmosphere=client) " "instead for unified index access.", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) self.client = client @@ -405,7 +405,7 @@ def decode_schema(self, ref: str) -> "Type[Packable]": warnings.warn( "Atmosphere.decode_schema() is deprecated, use Atmosphere.get_schema_type() instead", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return self.get_schema_type(ref) diff --git a/src/atdata/atmosphere/_appview.py b/src/atdata/atmosphere/_appview.py new file mode 100644 index 0000000..ccf0e4f --- /dev/null +++ b/src/atdata/atmosphere/_appview.py @@ -0,0 +1,56 @@ +"""AppView fallback helper for atmosphere operations. + +Provides a helper function that wraps ATProto AppView calls with automatic +fallback to client-side resolution when the AppView is unreachable or +returns an error. 
+""" + +from __future__ import annotations + +from typing import Any, Callable, TypeVar + +import httpx + +from .._exceptions import AppViewError + +T = TypeVar("T") + +_APPVIEW_ERRORS = ( + httpx.HTTPStatusError, + httpx.ConnectError, + httpx.TimeoutException, + AppViewError, +) + + +def with_appview_fallback( + appview_fn: Callable[..., T], + fallback_fn: Callable[..., T], + *, + client: Any, + operation: str, +) -> T: + """Try an AppView operation, falling back on network/HTTP errors. + + Args: + appview_fn: Callable that performs the AppView request. + fallback_fn: Callable that performs the client-side fallback. + client: The Atmosphere client (checked for ``has_appview``). + operation: Human-readable operation name for log messages. + + Returns: + Result from *appview_fn* if successful, otherwise from *fallback_fn*. + """ + if getattr(client, "has_appview", False) is True: + try: + return appview_fn() + except _APPVIEW_ERRORS: + from .._logging import get_logger + + get_logger().warning( + "AppView %s failed, falling back to client-side", + operation, + exc_info=True, + ) + + return fallback_fn() diff --git a/src/atdata/atmosphere/_types.py b/src/atdata/atmosphere/_types.py index 5d57f32..7898cc4 100644 --- a/src/atdata/atmosphere/_types.py +++ b/src/atdata/atmosphere/_types.py @@ -105,7 +105,7 @@ def __getattr__(name: str) -> Any: new_name, hint = _DEPRECATED_ALIASES[name] warnings.warn( f"{name} is deprecated: {hint}.", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) from . 
import _lexicon_types diff --git a/src/atdata/atmosphere/client.py b/src/atdata/atmosphere/client.py index 4ceb76d..422e987 100644 --- a/src/atdata/atmosphere/client.py +++ b/src/atdata/atmosphere/client.py @@ -8,25 +8,31 @@ from typing import Optional, Any +import threading + from ._types import AtUri, LEXICON_NAMESPACE -# Lazy import to avoid requiring atproto if not using atmosphere features +# Lazy import to avoid requiring atproto if not using atmosphere features. +# Protected by double-checked locking for thread safety. _atproto_client_class: Optional[type] = None +_atproto_lock = threading.Lock() def _get_atproto_client_class(): - """Lazily import the atproto Client class.""" + """Lazily import the atproto Client class (thread-safe).""" global _atproto_client_class if _atproto_client_class is None: - try: - from atproto import Client - - _atproto_client_class = Client - except ImportError as e: - raise ImportError( - "The 'atproto' package is required for ATProto integration. " - "Install it with: pip install atproto" - ) from e + with _atproto_lock: + if _atproto_client_class is None: + try: + from atproto import Client + + _atproto_client_class = Client + except ImportError as e: + raise ImportError( + "The 'atproto' package is required for ATProto integration. 
" + "Install it with: pip install atproto" + ) from e return _atproto_client_class @@ -1034,3 +1040,142 @@ def get_entry_stats( f"{LEXICON_NAMESPACE}.getEntryStats", params={"uri": uri, "period": period}, ) + + # ------------------------------------------------------------------ # + # Namespaced access — power-user operations grouped by concern + # ------------------------------------------------------------------ # + + @property + def records(self) -> RecordOps: + """Namespaced record operations (create, get, put, delete, list).""" + if not hasattr(self, "_records_ops"): + self._records_ops = RecordOps(self) + return self._records_ops + + @property + def blobs(self) -> BlobOps: + """Namespaced blob operations (upload, get, get_url).""" + if not hasattr(self, "_blobs_ops"): + self._blobs_ops = BlobOps(self) + return self._blobs_ops + + @property + def xrpc(self) -> XrpcClient: + """Namespaced XRPC transport (query, procedure).""" + if not hasattr(self, "_xrpc_client"): + self._xrpc_client = XrpcClient(self) + return self._xrpc_client + + +class RecordOps: + """Namespaced ATProto record operations. + + Accessed via ``atmo.records``. Provides create, get, put, delete, + and list operations on ATProto records. + """ + + __slots__ = ("_atmo",) + + def __init__(self, atmo: Atmosphere) -> None: + self._atmo = atmo + + def create( + self, + collection: str, + record: dict, + *, + rkey: Optional[str] = None, + validate: bool = False, + ) -> AtUri: + """Create a record. See :meth:`Atmosphere.create_record`.""" + return self._atmo.create_record(collection, record, rkey=rkey, validate=validate) + + def put( + self, + collection: str, + rkey: str, + record: dict, + *, + validate: bool = False, + swap_commit: Optional[str] = None, + ) -> AtUri: + """Create or update a record. 
See :meth:`Atmosphere.put_record`.""" + return self._atmo.put_record( + collection, rkey, record, validate=validate, swap_commit=swap_commit + ) + + def get(self, uri: str | AtUri) -> dict: + """Fetch a record by AT URI. See :meth:`Atmosphere.get_record`.""" + return self._atmo.get_record(uri) + + def delete( + self, + uri: str | AtUri, + *, + swap_commit: Optional[str] = None, + ) -> None: + """Delete a record. See :meth:`Atmosphere.delete_record`.""" + self._atmo.delete_record(uri, swap_commit=swap_commit) + + def list( + self, + collection: str, + *, + repo: Optional[str] = None, + limit: int = 100, + cursor: Optional[str] = None, + ) -> tuple[list[dict], Optional[str]]: + """List records in a collection. See :meth:`Atmosphere.list_records`.""" + return self._atmo.list_records(collection, repo=repo, limit=limit, cursor=cursor) + + +class BlobOps: + """Namespaced ATProto blob operations. + + Accessed via ``atmo.blobs``. Provides upload, download, and URL + generation for PDS blobs. + """ + + __slots__ = ("_atmo",) + + def __init__(self, atmo: Atmosphere) -> None: + self._atmo = atmo + + def upload( + self, + data: bytes, + mime_type: str = "application/octet-stream", + *, + timeout: float | None = None, + ) -> dict: + """Upload a blob. See :meth:`Atmosphere.upload_blob`.""" + return self._atmo.upload_blob(data, mime_type, timeout=timeout) + + def get(self, did: str, cid: str) -> bytes: + """Download a blob. See :meth:`Atmosphere.get_blob`.""" + return self._atmo.get_blob(did, cid) + + def get_url(self, did: str, cid: str) -> str: + """Get direct URL for a blob. See :meth:`Atmosphere.get_blob_url`.""" + return self._atmo.get_blob_url(did, cid) + + +class XrpcClient: + """Namespaced XRPC transport operations. + + Accessed via ``atmo.xrpc``. Provides raw query (GET) and procedure + (POST) calls to the AppView. 
+ """ + + __slots__ = ("_atmo",) + + def __init__(self, atmo: Atmosphere) -> None: + self._atmo = atmo + + def query(self, nsid: str, params: dict | None = None) -> dict: + """Call an XRPC query (GET). See :meth:`Atmosphere.xrpc_query`.""" + return self._atmo.xrpc_query(nsid, params=params) + + def procedure(self, nsid: str, input: dict | None = None) -> dict: + """Call an XRPC procedure (POST). See :meth:`Atmosphere.xrpc_procedure`.""" + return self._atmo.xrpc_procedure(nsid, input=input) diff --git a/src/atdata/atmosphere/labels.py b/src/atdata/atmosphere/labels.py index 1a98bd9..4c97966 100644 --- a/src/atdata/atmosphere/labels.py +++ b/src/atdata/atmosphere/labels.py @@ -203,18 +203,14 @@ def resolve( Raises: KeyError: If no matching label is found. """ - if getattr(self.client, "has_appview", False) is True: - try: - return self._resolve_via_appview(handle_or_did, name, version) - except Exception: - from .._logging import get_logger - - get_logger().warning( - "AppView label resolution failed, falling back to client-side", - exc_info=True, - ) + from ._appview import with_appview_fallback - return self._resolve_client_side(handle_or_did, name, version) + return with_appview_fallback( + lambda: self._resolve_via_appview(handle_or_did, name, version), + lambda: self._resolve_client_side(handle_or_did, name, version), + client=self.client, + operation="label resolution", + ) def _resolve_via_appview( self, diff --git a/src/atdata/atmosphere/lens.py b/src/atdata/atmosphere/lens.py index e888efa..8a65c19 100644 --- a/src/atdata/atmosphere/lens.py +++ b/src/atdata/atmosphere/lens.py @@ -302,18 +302,14 @@ def list_all( Returns: List of lens records. 
""" - if getattr(self.client, "has_appview", False) is True: - try: - return self._list_via_appview(repo=repo, limit=limit) - except Exception: - from .._logging import get_logger - - get_logger().warning( - "AppView listLenses failed, falling back to client-side", - exc_info=True, - ) + from ._appview import with_appview_fallback - return self.client.list_lenses(repo=repo, limit=limit) + return with_appview_fallback( + lambda: self._list_via_appview(repo=repo, limit=limit), + lambda: self.client.list_lenses(repo=repo, limit=limit), + client=self.client, + operation="listLenses", + ) def _list_via_appview( self, @@ -367,21 +363,17 @@ def find_by_schemas( Returns: List of matching lens records. """ - if getattr(self.client, "has_appview", False) is True: - try: - return self._find_by_schemas_via_appview( - source_schema_uri, target_schema_uri - ) - except Exception: - from .._logging import get_logger - - get_logger().warning( - "AppView searchLenses failed, falling back to client-side", - exc_info=True, - ) - - return self._find_by_schemas_client_side( - source_schema_uri, target_schema_uri, repo + from ._appview import with_appview_fallback + + return with_appview_fallback( + lambda: self._find_by_schemas_via_appview( + source_schema_uri, target_schema_uri + ), + lambda: self._find_by_schemas_client_side( + source_schema_uri, target_schema_uri, repo + ), + client=self.client, + operation="searchLenses", ) def _find_by_schemas_via_appview( diff --git a/src/atdata/atmosphere/records.py b/src/atdata/atmosphere/records.py index b91465f..3b3ba0d 100644 --- a/src/atdata/atmosphere/records.py +++ b/src/atdata/atmosphere/records.py @@ -559,27 +559,24 @@ def get(self, uri: str | AtUri) -> dict: Raises: ValueError: If the record is not a dataset record. 
""" - if getattr(self.client, "has_appview", False) is True: - try: - return self._get_via_appview(uri) - except Exception: - from .._logging import get_logger - - get_logger().warning( - "AppView getEntry failed, falling back to client-side", - exc_info=True, - ) - - record = self.client.get_record(uri) + from ._appview import with_appview_fallback - expected_type = f"{LEXICON_NAMESPACE}.entry" - if record.get("$type") != expected_type: - raise ValueError( - f"Record at {uri} is not a dataset record. " - f"Expected $type='{expected_type}', got '{record.get('$type')}'" - ) + def _fallback(): + record = self.client.get_record(uri) + expected_type = f"{LEXICON_NAMESPACE}.entry" + if record.get("$type") != expected_type: + raise ValueError( + f"Record at {uri} is not a dataset record. " + f"Expected $type='{expected_type}', got '{record.get('$type')}'" + ) + return record - return record + return with_appview_fallback( + lambda: self._get_via_appview(uri), + _fallback, + client=self.client, + operation="getEntry", + ) def _get_via_appview(self, uri: str | AtUri) -> dict: """Fetch a dataset entry via AppView XRPC query.""" @@ -620,18 +617,14 @@ def list_all( Returns: List of dataset records. 
""" - if getattr(self.client, "has_appview", False) is True: - try: - return self._list_via_appview(repo=repo, limit=limit) - except Exception: - from .._logging import get_logger - - get_logger().warning( - "AppView listEntries failed, falling back to client-side", - exc_info=True, - ) + from ._appview import with_appview_fallback - return self.client.list_datasets(repo=repo, limit=limit) + return with_appview_fallback( + lambda: self._list_via_appview(repo=repo, limit=limit), + lambda: self.client.list_datasets(repo=repo, limit=limit), + client=self.client, + operation="listEntries", + ) def _list_via_appview( self, @@ -803,31 +796,27 @@ def get_blob_urls(self, uri: str | AtUri) -> list[str]: else: parsed_uri = uri - if getattr(self.client, "has_appview", False) is True: - try: - return self._get_blob_urls_via_appview(str(parsed_uri)) - except Exception: - from .._logging import get_logger - - get_logger().warning( - "AppView resolveBlobs failed, falling back to client-side", - exc_info=True, - ) + from ._appview import with_appview_fallback - blob_entries = self.get_blobs(uri) - did = parsed_uri.authority - - urls = [] - for entry in blob_entries: - # Handle both new blobEntry format and legacy bare blob format - blob = entry.get("blob", entry) - ref = blob.get("ref", {}) - cid = ref.get("$link") if isinstance(ref, dict) else str(ref) - if cid: - url = self.client.get_blob_url(did, cid) - urls.append(url) + def _fallback(): + blob_entries = self.get_blobs(uri) + did = parsed_uri.authority + urls = [] + for entry in blob_entries: + blob = entry.get("blob", entry) + ref = blob.get("ref", {}) + cid = ref.get("$link") if isinstance(ref, dict) else str(ref) + if cid: + url = self.client.get_blob_url(did, cid) + urls.append(url) + return urls - return urls + return with_appview_fallback( + lambda: self._get_blob_urls_via_appview(str(parsed_uri)), + _fallback, + client=self.client, + operation="resolveBlobs", + ) def _get_blob_urls_via_appview(self, uri: str) -> 
list[str]: """Resolve blob URLs via AppView batch endpoint.""" diff --git a/src/atdata/atmosphere/schema.py b/src/atdata/atmosphere/schema.py index ea70887..e23ac7e 100644 --- a/src/atdata/atmosphere/schema.py +++ b/src/atdata/atmosphere/schema.py @@ -341,18 +341,14 @@ def resolve( Raises: KeyError: If no matching schema is found. """ - if getattr(self.client, "has_appview", False) is True: - try: - return self._resolve_via_appview(handle_or_did, schema_id, version) - except Exception: - from .._logging import get_logger + from ._appview import with_appview_fallback - get_logger().warning( - "AppView schema resolution failed, falling back to client-side", - exc_info=True, - ) - - return self._resolve_client_side(handle_or_did, schema_id, version) + return with_appview_fallback( + lambda: self._resolve_via_appview(handle_or_did, schema_id, version), + lambda: self._resolve_client_side(handle_or_did, schema_id, version), + client=self.client, + operation="schema resolution", + ) def _resolve_via_appview( self, @@ -447,18 +443,14 @@ def list_all( Returns: List of schema records. 
""" - if getattr(self.client, "has_appview", False) is True: - try: - return self._list_via_appview(repo=repo, limit=limit) - except Exception: - from .._logging import get_logger - - get_logger().warning( - "AppView schema listing failed, falling back to client-side", - exc_info=True, - ) - - return self.client.list_schemas(repo=repo, limit=limit) + from ._appview import with_appview_fallback + + return with_appview_fallback( + lambda: self._list_via_appview(repo=repo, limit=limit), + lambda: self.client.list_schemas(repo=repo, limit=limit), + client=self.client, + operation="schema listing", + ) def _list_via_appview( self, diff --git a/src/atdata/dataset.py b/src/atdata/dataset.py index 3c77535..a4ac333 100644 --- a/src/atdata/dataset.py +++ b/src/atdata/dataset.py @@ -626,7 +626,7 @@ def shard_list(self) -> list[str]: warnings.warn( "shard_list is deprecated, use list_shards() instead", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return self.list_shards() diff --git a/src/atdata/index/_index.py b/src/atdata/index/_index.py index f0b1ece..32c1594 100644 --- a/src/atdata/index/_index.py +++ b/src/atdata/index/_index.py @@ -463,7 +463,7 @@ def load_schema(self, ref: str) -> Type[Packable]: warnings.warn( "Index.load_schema() is deprecated, use Index.get_schema_type() instead", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return self.get_schema_type(ref, register=True) @@ -553,15 +553,16 @@ def add_entry( warnings.warn( "Index.add_entry() is deprecated, use Index.insert_dataset()", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) - return self._insert_dataset_to_provider( + repo = self._repos.get("local") + if repo is None: + raise RuntimeError("No local repository configured") + return repo.insert_dataset( ds, name=name, schema_ref=schema_ref, - provider=self._provider, - store=None, metadata=metadata, ) @@ -595,103 +596,6 @@ def get_entry_by_name(self, name: str) -> LocalDatasetEntry: # 
Index protocol methods - @staticmethod - def _ensure_schema_stored( - schema_ref: str, - sample_type: type, - provider: "IndexProvider", # noqa: F821 - ) -> None: - """Persist the schema definition if not already stored. - - Called during dataset insertion so that ``get_schema_type()`` can - reconstruct the type later without the caller needing to publish - the schema separately. - """ - schema_name, version = _parse_schema_ref(schema_ref) - if provider.get_schema_json(schema_name, version) is None: - record = _build_schema_record(sample_type, version=version) - provider.store_schema(schema_name, version, json.dumps(record)) - - def _insert_dataset_to_provider( - self, - ds: Dataset, - *, - name: str, - schema_ref: str | None = None, - provider: "IndexProvider", # noqa: F821 - store: AbstractDataStore | None = None, - **kwargs, - ) -> LocalDatasetEntry: - """Insert a dataset into a specific provider/store pair. - - This is the internal implementation shared by all local and named - repository inserts. 
- """ - from atdata._logging import get_logger - - log = get_logger() - metadata = kwargs.get("metadata") - - if store is not None: - prefix = kwargs.get("prefix", name) - cache_local = kwargs.get("cache_local", False) - log.debug( - "_insert_dataset_to_provider: name=%s, store=%s", - name, - type(store).__name__, - ) - - written_urls = store.write_shards( - ds, - prefix=prefix, - cache_local=cache_local, - ) - log.info( - "_insert_dataset_to_provider: %d shard(s) written for %s", - len(written_urls), - name, - ) - - if schema_ref is None: - schema_ref = _schema_ref_from_type(ds.sample_type, version="1.0.0") - - self._ensure_schema_stored(schema_ref, ds.sample_type, provider) - - entry_metadata = metadata if metadata is not None else ds._metadata - entry_metadata = _merge_checksums(entry_metadata, written_urls) - entry = LocalDatasetEntry( - name=name, - schema_ref=schema_ref, - data_urls=written_urls, - metadata=entry_metadata, - ) - else: - # No data store - just index the existing URL - if schema_ref is None: - schema_ref = _schema_ref_from_type(ds.sample_type, version="1.0.0") - - self._ensure_schema_stored(schema_ref, ds.sample_type, provider) - - data_urls = [ds.url] - entry_metadata = metadata if metadata is not None else ds._metadata - - entry = LocalDatasetEntry( - name=name, - schema_ref=schema_ref, - data_urls=data_urls, - metadata=entry_metadata, - ) - - provider.store_entry(entry) - provider.store_label( - name=name, - cid=entry.cid, - version=kwargs.get("version"), - description=kwargs.get("description"), - ) - log.debug("_insert_dataset_to_provider: entry stored for %s", name) - return entry - def insert_dataset( self, ds: Dataset, @@ -860,13 +764,11 @@ def insert_dataset( if repo is None: raise KeyError(f"Unknown repository {backend_key!r} in name {name!r}") - effective_store = data_store or repo.data_store - return self._insert_dataset_to_provider( + return repo.insert_dataset( ds, name=resolved_name, schema_ref=schema_ref, - 
provider=repo.provider, - store=effective_store, + store=data_store, metadata=metadata, **kwargs, ) @@ -1019,11 +921,7 @@ def write_samples( ds, prefix=resolved_name ) - # If write_shards returned blob refs (e.g. ShardUploadResult), - # use storageBlobs so the PDS retains the uploaded blobs. - # Fall back to storageExternal with AT URIs otherwise. blob_refs = getattr(written_urls, "blob_refs", None) or None - shard_checksums = _extract_blob_checksums(written_urls, blob_refs) effective_metadata = ( metadata @@ -1046,14 +944,13 @@ def write_samples( _checksums=shard_checksums, ) - # Local / named repo path + # Local / named repo path — delegate to Repository repo = self._repos.get(backend_key) - if repo is not None and effective_store is not None: - return self._insert_dataset_to_provider( + if repo is not None: + return repo.insert_dataset( ds, name=resolved_name, schema_ref=schema_ref, - provider=repo.provider, store=effective_store, metadata=metadata, ) @@ -1084,7 +981,7 @@ def write( warnings.warn( "Index.write() is deprecated, use Index.write_samples()", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return self.write_samples(samples, name=name, **kwargs) @@ -1535,7 +1432,7 @@ def decode_schema(self, ref: str) -> Type[Packable]: warnings.warn( "Index.decode_schema() is deprecated, use Index.get_schema_type() instead", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return self.get_schema_type(ref, register=False) @@ -1550,7 +1447,7 @@ def decode_schema_as(self, ref: str, type_hint: type[T]) -> type[T]: warnings.warn( "Index.decode_schema_as() is deprecated, use Index.get_schema_type() instead", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) from typing import cast @@ -1609,7 +1506,7 @@ def promote_entry( warnings.warn( "Index.promote_entry() is deprecated, use Index.insert_dataset()", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) from atdata.promote import 
_find_or_publish_schema @@ -1702,7 +1599,7 @@ def promote_dataset( warnings.warn( "Index.promote_dataset() is deprecated, use Index.insert_dataset()", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) from atdata.promote import _find_or_publish_schema diff --git a/src/atdata/local/_repo_legacy.py b/src/atdata/local/_repo_legacy.py index 10c4fd7..7e170d1 100644 --- a/src/atdata/local/_repo_legacy.py +++ b/src/atdata/local/_repo_legacy.py @@ -70,7 +70,7 @@ def __init__( " store = S3DataStore(credentials, bucket='my-bucket')\n" " index = Index(redis=redis, data_store=store)\n" " entry = index.insert_dataset(ds, name='my-dataset')", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) diff --git a/src/atdata/providers/_redis.py b/src/atdata/providers/_redis.py index 0577f65..b0b1169 100644 --- a/src/atdata/providers/_redis.py +++ b/src/atdata/providers/_redis.py @@ -7,6 +7,7 @@ from __future__ import annotations +from datetime import datetime, timezone from typing import Iterator import msgpack @@ -73,10 +74,24 @@ def get_entry_by_name(self, name: str) -> "LocalDatasetEntry": # noqa: F821 def iter_entries(self) -> Iterator["LocalDatasetEntry"]: # noqa: F821 prefix = f"{_KEY_DATASET_ENTRY}:" + # Collect keys first, then batch-fetch with a pipeline to avoid + # N+1 round-trips (one per entry). 
+ keys: list[str] = [] for key in self._redis.scan_iter(match=f"{prefix}*"): key_str = key.decode("utf-8") if isinstance(key, bytes) else key - cid = key_str[len(prefix) :] - yield self.get_entry_by_cid(cid) + keys.append(key_str) + + if not keys: + return + + pipe = self._redis.pipeline(transaction=False) + for key_str in keys: + pipe.hgetall(key_str) + results = pipe.execute() + + for raw_data in results: + if raw_data: + yield _entry_from_redis_hash(raw_data) # ------------------------------------------------------------------ # Schema operations @@ -135,7 +150,12 @@ def store_label( ) -> None: ver_key = version or "" redis_key = f"Label:{name}@{ver_key}" - data: dict[str, str] = {"cid": cid, "name": name, "version": ver_key} + data: dict[str, str] = { + "cid": cid, + "name": name, + "version": ver_key, + "created_at": datetime.now(timezone.utc).isoformat(), + } if description is not None: data["description"] = description self._redis.hset(redis_key, mapping=data) # type: ignore[arg-type] @@ -156,10 +176,13 @@ def get_label( } return (raw_typed["cid"], version) - # No version specified — scan for all labels with this name, pick latest + # No version specified — scan for all labels with this name, + # pick the most recently created one (by created_at timestamp). + # Old labels without created_at sort last. prefix = f"Label:{name}@" best_cid: str | None = None best_ver: str | None = None + best_ts: str = "" for key in self._redis.scan_iter(match=f"{prefix}*"): raw = self._redis.hgetall(key) if not raw: @@ -170,11 +193,12 @@ def get_label( ) for k, v in raw.items() } - # Pick any match; Redis doesn't have created_at ordering so we - # just return the last one found (consistent with scan order). 
- best_cid = raw_typed["cid"] - ver = raw_typed.get("version", "") - best_ver = ver if ver else None + ts = raw_typed.get("created_at", "") + if best_cid is None or ts > best_ts: + best_cid = raw_typed["cid"] + ver = raw_typed.get("version", "") + best_ver = ver if ver else None + best_ts = ts if best_cid is None: raise KeyError(f"No label with name: {name!r}") diff --git a/src/atdata/providers/_sqlite.py b/src/atdata/providers/_sqlite.py index a741a4d..f5f7407 100644 --- a/src/atdata/providers/_sqlite.py +++ b/src/atdata/providers/_sqlite.py @@ -59,6 +59,12 @@ CREATE INDEX IF NOT EXISTS idx_labels_cid ON labels(cid); + +CREATE INDEX IF NOT EXISTS idx_schemas_name_version + ON schemas(name, version); + +CREATE INDEX IF NOT EXISTS idx_labels_name_created + ON labels(name, created_at DESC); """ diff --git a/src/atdata/repository.py b/src/atdata/repository.py index 65c9a59..1fa129b 100644 --- a/src/atdata/repository.py +++ b/src/atdata/repository.py @@ -24,12 +24,14 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Iterator, Optional, TYPE_CHECKING +from typing import Any, Iterable, Iterator, Optional, TYPE_CHECKING from ._protocols import AbstractDataStore if TYPE_CHECKING: from .providers._base import IndexProvider + from .index._entry import LocalDatasetEntry + from .dataset import Dataset @dataclass @@ -62,6 +64,179 @@ class Repository: provider: IndexProvider data_store: AbstractDataStore | None = None + def insert_dataset( + self, + ds: Dataset, + *, + name: str, + schema_ref: str | None = None, + store: AbstractDataStore | None = None, + **kwargs: Any, + ) -> LocalDatasetEntry: + """Insert a dataset into this repository's provider/store. + + Args: + ds: The Dataset to register. + name: Human-readable name for the dataset. + schema_ref: Optional schema reference. Auto-generated if ``None``. + store: Explicit data store override. Falls back to + ``self.data_store`` if ``None``. 
+ **kwargs: Extra options forwarded to provider (metadata, version, + description, prefix, cache_local). + + Returns: + LocalDatasetEntry for the inserted dataset. + """ + from atdata._logging import get_logger + from atdata.index._schema import ( + _schema_ref_from_type, + ) + from atdata.index._entry import LocalDatasetEntry as _LDE + + log = get_logger() + effective_store = store or self.data_store + metadata = kwargs.get("metadata") + + if effective_store is not None: + prefix = kwargs.get("prefix", name) + cache_local = kwargs.get("cache_local", False) + log.debug( + "Repository.insert_dataset: name=%s, store=%s", + name, + type(effective_store).__name__, + ) + + written_urls = effective_store.write_shards( + ds, + prefix=prefix, + cache_local=cache_local, + ) + log.info( + "Repository.insert_dataset: %d shard(s) written for %s", + len(written_urls), + name, + ) + + if schema_ref is None: + schema_ref = _schema_ref_from_type(ds.sample_type, version="1.0.0") + + self._ensure_schema_stored(schema_ref, ds.sample_type) + + entry_metadata = metadata if metadata is not None else ds._metadata + from atdata.index._index import _merge_checksums + + entry_metadata = _merge_checksums(entry_metadata, written_urls) + entry = _LDE( + name=name, + schema_ref=schema_ref, + data_urls=written_urls, + metadata=entry_metadata, + ) + else: + if schema_ref is None: + schema_ref = _schema_ref_from_type(ds.sample_type, version="1.0.0") + + self._ensure_schema_stored(schema_ref, ds.sample_type) + + data_urls = [ds.url] + entry_metadata = metadata if metadata is not None else ds._metadata + + entry = _LDE( + name=name, + schema_ref=schema_ref, + data_urls=data_urls, + metadata=entry_metadata, + ) + + self.provider.store_entry(entry) + self.provider.store_label( + name=name, + cid=entry.cid, + version=kwargs.get("version"), + description=kwargs.get("description"), + ) + log.debug("Repository.insert_dataset: entry stored for %s", name) + return entry + + def write_samples( + self, + 
samples: Iterable, + *, + name: str, + schema_ref: str | None = None, + maxcount: int = 10_000, + maxsize: int | None = None, + manifest: bool = False, + data_store: AbstractDataStore | None = None, + metadata: dict | None = None, + **kwargs: Any, + ) -> LocalDatasetEntry: + """Write samples and create an index entry in one step. + + Serialises samples to WebDataset tar files, writes them through + this repository's store, and creates an index entry. + + When the repository has no ``data_store`` configured and no explicit + *data_store* is provided, a ``LocalDiskStore`` is created + automatically at ``~/.atdata/data/``. + + Args: + samples: Iterable of ``Packable`` samples. Must be non-empty. + name: Dataset name. + schema_ref: Optional schema reference. Auto-generated if ``None``. + maxcount: Max samples per shard. + maxsize: Max bytes per shard. + manifest: Write per-shard manifest sidecar files. + data_store: Explicit data store override. + metadata: Optional metadata dict. + **kwargs: Extra options forwarded to ``insert_dataset``. + + Returns: + LocalDatasetEntry for the created dataset. 
+ """ + import tempfile + + from atdata.dataset import write_samples as _write_samples + + effective_store = data_store or self.data_store + if effective_store is None: + from atdata.stores._disk import LocalDiskStore + + effective_store = LocalDiskStore() + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) / "data.tar" + ds = _write_samples( + samples, + tmp_path, + maxcount=maxcount, + maxsize=maxsize, + manifest=manifest, + ) + return self.insert_dataset( + ds, + name=name, + schema_ref=schema_ref, + store=effective_store, + metadata=metadata, + **kwargs, + ) + + def _ensure_schema_stored( + self, + schema_ref: str, + sample_type: type, + ) -> None: + """Persist the schema definition if not already stored.""" + import json + + from atdata.index._schema import _parse_schema_ref, _build_schema_record + + schema_name, version = _parse_schema_ref(schema_ref) + if self.provider.get_schema_json(schema_name, version) is None: + record = _build_schema_record(sample_type, version=version) + self.provider.store_schema(schema_name, version, json.dumps(record)) + def create_repository( provider: str = "sqlite", @@ -122,32 +297,20 @@ def __init__( data_store: Optional[AbstractDataStore] = None, ) -> None: from .atmosphere.client import Atmosphere + from .atmosphere.schema import SchemaPublisher, SchemaLoader + from .atmosphere.records import DatasetPublisher, DatasetLoader + from .atmosphere.labels import LabelPublisher, LabelLoader if not isinstance(client, Atmosphere): raise TypeError(f"Expected Atmosphere, got {type(client).__name__}") self.client: Atmosphere = client self._data_store = data_store - self._schema_publisher: Any = None - self._schema_loader: Any = None - self._dataset_publisher: Any = None - self._dataset_loader: Any = None - self._label_publisher: Any = None - self._label_loader: Any = None - - def _ensure_loaders(self) -> None: - """Lazily create publishers/loaders on first use.""" - if self._schema_loader is not None: - return - 
from .atmosphere.schema import SchemaPublisher, SchemaLoader - from .atmosphere.records import DatasetPublisher, DatasetLoader - from .atmosphere.labels import LabelPublisher, LabelLoader - - self._schema_publisher = SchemaPublisher(self.client) - self._schema_loader = SchemaLoader(self.client) - self._dataset_publisher = DatasetPublisher(self.client) - self._dataset_loader = DatasetLoader(self.client) - self._label_publisher = LabelPublisher(self.client) - self._label_loader = LabelLoader(self.client) + self._schema_publisher = SchemaPublisher(client) + self._schema_loader = SchemaLoader(client) + self._dataset_publisher = DatasetPublisher(client) + self._dataset_loader = DatasetLoader(client) + self._label_publisher = LabelPublisher(client) + self._label_loader = LabelLoader(client) @property def data_store(self) -> Optional[AbstractDataStore]: @@ -168,7 +331,7 @@ def get_dataset(self, ref: str) -> Any: Raises: ValueError: If record is not a dataset. """ - self._ensure_loaders() + from .atmosphere import AtmosphereIndexEntry record = self._dataset_loader.get(ref) @@ -183,7 +346,7 @@ def list_datasets(self, repo: str | None = None) -> list[Any]: Returns: List of AtmosphereIndexEntry for each dataset. """ - self._ensure_loaders() + from .atmosphere import AtmosphereIndexEntry records = self._dataset_loader.list_all(repo=repo) @@ -201,7 +364,7 @@ def iter_datasets(self, repo: str | None = None) -> Iterator[Any]: Yields: AtmosphereIndexEntry for each dataset. """ - self._ensure_loaders() + from .atmosphere import AtmosphereIndexEntry records = self._dataset_loader.list_all(repo=repo) @@ -243,7 +406,7 @@ def insert_dataset( Returns: AtmosphereIndexEntry for the inserted dataset. """ - self._ensure_loaders() + from .atmosphere import AtmosphereIndexEntry if blob_refs is not None or data_urls is not None: @@ -339,7 +502,7 @@ def resolve_label( Raises: KeyError: If no matching label is found. 
""" - self._ensure_loaders() + return self._label_loader.resolve(handle_or_did, name, version) # -- Schema operations -- @@ -361,7 +524,7 @@ def publish_schema( Returns: AT URI of the schema record. """ - self._ensure_loaders() + uri = self._schema_publisher.publish( sample_type, version=version, @@ -380,7 +543,7 @@ def get_schema(self, ref: str) -> dict: Returns: Schema record dictionary. """ - self._ensure_loaders() + return self._schema_loader.get(ref) def list_schemas(self, repo: str | None = None) -> list[dict]: @@ -392,7 +555,7 @@ def list_schemas(self, repo: str | None = None) -> list[dict]: Returns: List of schema records as dictionaries. """ - self._ensure_loaders() + records = self._schema_loader.list_all(repo=repo) return [rec.get("value", rec) for rec in records] @@ -402,7 +565,7 @@ def iter_schemas(self) -> Iterator[dict]: Yields: Schema records as dictionaries. """ - self._ensure_loaders() + records = self._schema_loader.list_all() for rec in records: yield rec.get("value", rec) @@ -431,7 +594,7 @@ def decode_schema(self, ref: str) -> type: warnings.warn( "Repository.decode_schema() is deprecated, use Repository.get_schema_type() instead", - DeprecationWarning, + FutureWarning, # Removal: v1.0 stacklevel=2, ) return self.get_schema_type(ref) diff --git a/tests/test_atmosphere.py b/tests/test_atmosphere.py index 4893ead..49f42eb 100644 --- a/tests/test_atmosphere.py +++ b/tests/test_atmosphere.py @@ -178,7 +178,7 @@ def test_parse_atdata_namespace(self): # ============================================================================= -@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::FutureWarning") class TestFieldType: """Tests for deprecated FieldType shim dataclass.""" @@ -229,7 +229,7 @@ def test_array_type(self): # ============================================================================= -@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::FutureWarning") 
class TestFieldDef: """Tests for deprecated FieldDef shim dataclass.""" @@ -482,7 +482,7 @@ def test_from_record_new_key_takes_precedence_over_old(self): # ============================================================================= -@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::FutureWarning") class TestStorageLocation: """Tests for deprecated StorageLocation shim dataclass.""" @@ -3594,7 +3594,7 @@ def test_data_urls_resolves_storage_blobs(self, mock_resolve): mock_resolve.assert_called_once_with("did:plc:testdid") -@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::FutureWarning") class TestAtmosphereIndex: """Tests for AtmosphereIndex unified interface.""" diff --git a/tests/test_atmosphere_label_integration.py b/tests/test_atmosphere_label_integration.py index cdd0a64..e5c703d 100644 --- a/tests/test_atmosphere_label_integration.py +++ b/tests/test_atmosphere_label_integration.py @@ -180,7 +180,7 @@ def test_insert_dataset_label_without_version(self, mock_atmo): assert "version" not in labels[0] def test_real_backend_initializes_label_publisher(self): - """_AtmosphereBackend._ensure_loaders initializes label publisher/loader.""" + """_AtmosphereBackend eagerly initializes label publisher/loader.""" from atdata.atmosphere.client import Atmosphere mock_sdk = MagicMock() @@ -192,7 +192,6 @@ def test_real_backend_initializes_label_publisher(self): atmo._login("test.social", "pass") backend = _AtmosphereBackend(atmo) - backend._ensure_loaders() assert backend._label_publisher is not None assert backend._label_loader is not None @@ -485,7 +484,6 @@ def test_resolve_label_delegates_to_loader(self): atmo._login("t.social", "pass") backend = _AtmosphereBackend(atmo) - backend._ensure_loaders() # Mock the label loader's resolve method backend._label_loader = MagicMock() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index be789c7..394e30a 100644 --- 
a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -448,7 +448,7 @@ class BatchTypeSample: assert batch_type.__origin__ == atdata.SampleBatch -@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::FutureWarning") def test_dataset_shard_list_property(tmp_path): """Test Dataset.shard_list property returns list of shard URLs.""" diff --git a/tests/test_index_write.py b/tests/test_index_write.py index c6558e8..0dcfdf8 100644 --- a/tests/test_index_write.py +++ b/tests/test_index_write.py @@ -252,7 +252,7 @@ def test_promote_entry_calls_atmosphere(self, sqlite_provider, tmp_path: Path): # --------------------------------------------------------------------------- -@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::FutureWarning") class TestIndexPromoteDataset: """Tests for deprecated Index.promote_dataset().""" @@ -750,7 +750,7 @@ def test_write_emits_deprecation(self, index): warnings.simplefilter("always") index.write(samples, name="dep-write") - dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + dep_warnings = [x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning))] assert any("write_samples" in str(dw.message) for dw in dep_warnings) def test_add_entry_emits_deprecation(self, index, tmp_path): @@ -761,7 +761,7 @@ def test_add_entry_emits_deprecation(self, index, tmp_path): warnings.simplefilter("always") index.add_entry(ds, name="dep-add") - dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + dep_warnings = [x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning))] assert any("insert_dataset" in str(dw.message) for dw in dep_warnings) def test_promote_entry_emits_deprecation(self, index): @@ -772,7 +772,7 @@ def test_promote_entry_emits_deprecation(self, index): except (ValueError, KeyError): pass - dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + 
dep_warnings = [x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning))] assert any("insert_dataset" in str(dw.message) for dw in dep_warnings) def test_promote_dataset_emits_deprecation(self, index): @@ -784,5 +784,5 @@ def test_promote_dataset_emits_deprecation(self, index): except (ValueError, KeyError): pass - dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + dep_warnings = [x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning))] assert any("insert_dataset" in str(dw.message) for dw in dep_warnings) diff --git a/tests/test_local.py b/tests/test_local.py index c412f44..7bcfd6a 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -597,7 +597,7 @@ def test_index_list_datasets(clean_redis): # Note: Repo is deprecated; these tests verify backwards compatibility -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_init_no_s3(): """Test creating a Repo without S3 credentials.""" repo = atlocal.Repo() @@ -610,7 +610,7 @@ def test_repo_init_no_s3(): assert isinstance(repo.index, atlocal.Index) -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_init_with_s3_dict(): """Test creating a Repo with S3 credentials as a dictionary.""" creds = { @@ -628,7 +628,7 @@ def test_repo_init_with_s3_dict(): assert repo.hive_bucket == "test-bucket" -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_init_with_s3_path(tmp_path): """Test creating a Repo with S3 credentials from a .env file.""" env_file = tmp_path / ".env" @@ -647,7 +647,7 @@ def test_repo_init_with_s3_path(tmp_path): assert repo.hive_bucket == "test-bucket" -@pytest.mark.filterwarnings("ignore:Repo is 
deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_init_s3_without_hive_path(): """Test that creating a Repo with S3 but no hive_path raises ValueError.""" creds = { @@ -660,7 +660,7 @@ def test_repo_init_s3_without_hive_path(): atlocal.Repo(s3_credentials=creds) -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_init_hive_path_parsing(): """Test that hive_path is correctly parsed to extract bucket name.""" creds = { @@ -675,7 +675,7 @@ def test_repo_init_hive_path_parsing(): assert repo.hive_path == Path("my-bucket/path/to/datasets") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_init_with_custom_redis(): """Test creating a Repo with a custom Redis connection.""" custom_redis = Redis() @@ -688,7 +688,7 @@ def test_repo_init_with_custom_redis(): # Repo tests - Insert functionality -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_without_s3(): """Test that inserting a dataset without S3 configured raises ValueError.""" repo = atlocal.Repo() @@ -700,7 +700,7 @@ def test_repo_insert_without_s3(): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_single_shard(mock_s3, clean_redis, sample_dataset): """Test inserting a small dataset that fits in a single shard.""" repo = atlocal.Repo( @@ -726,7 +726,7 @@ def test_repo_insert_single_shard(mock_s3, clean_redis, sample_dataset): 
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_multiple_shards(mock_s3, clean_redis, tmp_path): """Test inserting a large dataset that spans multiple shards.""" ds = make_simple_dataset(tmp_path, num_samples=50, name="large") @@ -747,7 +747,7 @@ def test_repo_insert_multiple_shards(mock_s3, clean_redis, tmp_path): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_with_metadata(mock_s3, clean_redis, tmp_path): """Test inserting a dataset with metadata.""" ds = make_simple_dataset(tmp_path, num_samples=5) @@ -768,7 +768,7 @@ def test_repo_insert_with_metadata(mock_s3, clean_redis, tmp_path): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_without_metadata(mock_s3, clean_redis, tmp_path): """Test inserting a dataset without metadata.""" ds = make_simple_dataset(tmp_path, num_samples=5) @@ -786,7 +786,7 @@ def test_repo_insert_without_metadata(mock_s3, clean_redis, tmp_path): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") 
+@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_cache_local_false(mock_s3, clean_redis, sample_dataset): """Test inserting with cache_local=False (direct S3 write).""" repo = atlocal.Repo( @@ -806,7 +806,7 @@ def test_repo_insert_cache_local_false(mock_s3, clean_redis, sample_dataset): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_cache_local_true(mock_s3, clean_redis, sample_dataset): """Test inserting with cache_local=True (local cache then copy).""" repo = atlocal.Repo( @@ -826,7 +826,7 @@ def test_repo_insert_cache_local_true(mock_s3, clean_redis, sample_dataset): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_creates_index_entry(mock_s3, clean_redis, sample_dataset): """Test that insert() creates a valid index entry.""" repo = atlocal.Repo( @@ -848,7 +848,7 @@ def test_repo_insert_creates_index_entry(mock_s3, clean_redis, sample_dataset): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_cid_generation(mock_s3, clean_redis, sample_dataset): """Test that insert() generates unique CIDs for each dataset.""" repo = atlocal.Repo( @@ -867,7 +867,7 @@ def 
test_repo_insert_cid_generation(mock_s3, clean_redis, sample_dataset): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_empty_dataset(mock_s3, clean_redis, tmp_path): """Test inserting an empty dataset.""" dataset_path = tmp_path / "empty-dataset-000000.tar" @@ -889,7 +889,7 @@ def test_repo_insert_empty_dataset(mock_s3, clean_redis, tmp_path): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_preserves_sample_type(mock_s3, clean_redis, sample_dataset): """Test that the returned Dataset preserves the original sample type.""" repo = atlocal.Repo( @@ -906,7 +906,7 @@ def test_repo_insert_preserves_sample_type(mock_s3, clean_redis, sample_dataset) @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_with_shard_writer_kwargs(mock_s3, clean_redis, tmp_path): """Test that insert() passes additional kwargs to ShardWriter.""" ds = make_simple_dataset(tmp_path, num_samples=30, name="large") @@ -923,7 +923,7 @@ def test_repo_insert_with_shard_writer_kwargs(mock_s3, clean_redis, tmp_path): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") 
-@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_insert_numpy_arrays(mock_s3, clean_redis, tmp_path): """Test inserting a dataset containing samples with numpy arrays.""" ds = make_array_dataset(tmp_path, num_samples=3, array_shape=(10, 10)) @@ -945,7 +945,7 @@ def test_repo_insert_numpy_arrays(mock_s3, clean_redis, tmp_path): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_repo_index_integration(mock_s3, clean_redis, sample_dataset): """Test that Repo and Index work together correctly.""" repo = atlocal.Repo( @@ -964,7 +964,7 @@ def test_repo_index_integration(mock_s3, clean_redis, sample_dataset): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_multiple_datasets_same_type(mock_s3, clean_redis, sample_dataset): """Test inserting multiple datasets of the same sample type.""" repo = atlocal.Repo( @@ -989,7 +989,7 @@ def test_multiple_datasets_same_type(mock_s3, clean_redis, sample_dataset): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") +@pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_multiple_datasets_different_types(mock_s3, clean_redis, tmp_path): """Test inserting datasets with different sample 
types.""" simple_ds = make_simple_dataset(tmp_path, num_samples=3, name="simple") diff --git a/tests/test_protocols.py b/tests/test_protocols.py index 0274a11..2ca6a9d 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -83,7 +83,7 @@ def test_local_index_has_required_methods(self, tmp_path): with pytest.raises(KeyError): index.get_dataset("nonexistent") - @pytest.mark.filterwarnings("ignore::DeprecationWarning") + @pytest.mark.filterwarnings("ignore::FutureWarning") def test_atmosphere_index_has_required_methods(self): """AtmosphereIndex should have all required Index methods.""" mock_client = Mock() diff --git a/tests/test_repository.py b/tests/test_repository.py index 6007e45..f842050 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -375,5 +375,5 @@ def test_deprecation_warning(self): mock_client.is_authenticated = True mock_client.did = "did:plc:test" - with pytest.warns(DeprecationWarning, match="AtmosphereIndex is deprecated"): + with pytest.warns(FutureWarning, match="AtmosphereIndex is deprecated"): AtmosphereIndex(mock_client) diff --git a/tests/test_repository_coverage.py b/tests/test_repository_coverage.py index b4879e4..2157b71 100644 --- a/tests/test_repository_coverage.py +++ b/tests/test_repository_coverage.py @@ -64,32 +64,18 @@ def test_atmosphere_backend_data_store_none(backend) -> None: # --------------------------------------------------------------------------- -def test_ensure_loaders_lazy_init(backend) -> None: - """Loaders/publishers are None until _ensure_loaders is called.""" - assert backend._schema_loader is None - assert backend._dataset_loader is None - assert backend._label_publisher is None - assert backend._label_loader is None - - with ( - patch("atdata.atmosphere.schema.SchemaPublisher") as MockSP, - patch("atdata.atmosphere.schema.SchemaLoader") as MockSL, - patch("atdata.atmosphere.records.DatasetPublisher") as MockDP, - patch("atdata.atmosphere.records.DatasetLoader") as MockDL, - 
patch("atdata.atmosphere.labels.LabelPublisher") as MockLP, - patch("atdata.atmosphere.labels.LabelLoader") as MockLL, - ): - backend._ensure_loaders() - - MockSP.assert_called_once_with(backend.client) - MockSL.assert_called_once_with(backend.client) - MockDP.assert_called_once_with(backend.client) - MockDL.assert_called_once_with(backend.client) - MockLP.assert_called_once_with(backend.client) - MockLL.assert_called_once_with(backend.client) - - # Second call is a no-op (already initialised) - backend._ensure_loaders() +def test_eager_init_creates_loaders(backend) -> None: + """All publishers/loaders are created eagerly in __init__.""" + from atdata.atmosphere.schema import SchemaPublisher, SchemaLoader + from atdata.atmosphere.records import DatasetPublisher, DatasetLoader + from atdata.atmosphere.labels import LabelPublisher, LabelLoader + + assert isinstance(backend._schema_publisher, SchemaPublisher) + assert isinstance(backend._schema_loader, SchemaLoader) + assert isinstance(backend._dataset_publisher, DatasetPublisher) + assert isinstance(backend._dataset_loader, DatasetLoader) + assert isinstance(backend._label_publisher, LabelPublisher) + assert isinstance(backend._label_loader, LabelLoader) # --------------------------------------------------------------------------- @@ -98,7 +84,7 @@ def test_ensure_loaders_lazy_init(backend) -> None: def _patch_loaders(backend): - """Patch _ensure_loaders to inject mocks directly.""" + """Replace eagerly-created loaders with mocks.""" backend._schema_publisher = MagicMock() backend._schema_loader = MagicMock() backend._dataset_publisher = MagicMock() diff --git a/tests/test_sources.py b/tests/test_sources.py index 3b81a46..2707a63 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -36,13 +36,13 @@ def test_conforms_to_protocol(self): source = URLSource("http://example.com/data.tar") assert isinstance(source, DataSource) - @pytest.mark.filterwarnings("ignore::DeprecationWarning") + 
@pytest.mark.filterwarnings("ignore::FutureWarning") def test_shard_list_single_url(self): """shard_list returns single URL unchanged.""" source = URLSource("http://example.com/data.tar") assert source.shard_list == ["http://example.com/data.tar"] - @pytest.mark.filterwarnings("ignore::DeprecationWarning") + @pytest.mark.filterwarnings("ignore::FutureWarning") def test_shard_list_brace_expansion(self): """shard_list expands brace patterns.""" source = URLSource("data-{000..002}.tar") @@ -52,7 +52,7 @@ def test_shard_list_brace_expansion(self): "data-002.tar", ] - @pytest.mark.filterwarnings("ignore::DeprecationWarning") + @pytest.mark.filterwarnings("ignore::FutureWarning") def test_shard_list_complex_brace_pattern(self): """shard_list handles complex brace patterns.""" source = URLSource("s3://bucket/{train,test}-{00..01}.tar") @@ -125,7 +125,7 @@ def test_conforms_to_protocol(self): source = S3Source(bucket="test", keys=["data.tar"]) assert isinstance(source, DataSource) - @pytest.mark.filterwarnings("ignore::DeprecationWarning") + @pytest.mark.filterwarnings("ignore::FutureWarning") def test_shard_list(self): """shard_list returns S3 URIs.""" source = S3Source(bucket="my-bucket", keys=["a.tar", "b.tar"]) @@ -462,7 +462,7 @@ def test_open_shard_invalid_format(self): class TestDatasetWithDataSource: """Integration tests for Dataset with different DataSource types.""" - @pytest.mark.filterwarnings("ignore::DeprecationWarning") + @pytest.mark.filterwarnings("ignore::FutureWarning") def test_dataset_accepts_url_source(self, tmp_path): """Dataset can be created with URLSource.""" tar_path = tmp_path / "test.tar" diff --git a/tests/test_workflow_atmosphere.py b/tests/test_workflow_atmosphere.py index 418051c..c45b675 100644 --- a/tests/test_workflow_atmosphere.py +++ b/tests/test_workflow_atmosphere.py @@ -218,7 +218,7 @@ def test_get_schema_by_uri(self, authenticated_client, mock_atproto_client): assert schema["version"] == "2.0.0" 
-@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::FutureWarning") class TestAtmosphereIndex: """Tests for AtmosphereIndex Index protocol compliance.""" diff --git a/tests/test_workflow_cross_backend.py b/tests/test_workflow_cross_backend.py index 6dab8d6..ee84cf2 100644 --- a/tests/test_workflow_cross_backend.py +++ b/tests/test_workflow_cross_backend.py @@ -84,7 +84,7 @@ def atmosphere_index(authenticated_atmosphere_client): import warnings with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", FutureWarning) return AtmosphereIndex(authenticated_atmosphere_client) diff --git a/tests/test_workflow_local.py b/tests/test_workflow_local.py index 330be1e..aee3e47 100644 --- a/tests/test_workflow_local.py +++ b/tests/test_workflow_local.py @@ -120,7 +120,7 @@ class TestFullRepoWorkflow: @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") - @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") + @pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_init_publish_schema_insert_query(self, mock_s3, clean_redis, tmp_path): """Full workflow: init repo → publish schema → insert → query entry.""" # Initialize repo @@ -151,7 +151,7 @@ def test_init_publish_schema_insert_query(self, mock_s3, clean_redis, tmp_path): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") - @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") + @pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_multiple_datasets_same_schema(self, mock_s3, clean_redis, tmp_path): """Insert multiple datasets with same schema type.""" repo = atlocal.Repo( @@ -184,7 +184,7 @@ def 
test_multiple_datasets_same_schema(self, mock_s3, clean_redis, tmp_path): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") - @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") + @pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_different_schema_types(self, mock_s3, clean_redis, tmp_path): """Insert datasets with different schema types.""" repo = atlocal.Repo( @@ -447,7 +447,7 @@ class TestMetadataPersistence: @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") - @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") + @pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_metadata_preserved_through_insert(self, mock_s3, clean_redis, tmp_path): """Metadata should be preserved when inserting dataset.""" repo = atlocal.Repo( @@ -519,7 +519,7 @@ class TestCacheLocalModes: @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") - @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") + @pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_cache_local_true_produces_valid_entry( self, mock_s3, clean_redis, tmp_path ): @@ -539,7 +539,7 @@ def test_cache_local_true_produces_valid_entry( @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") - @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") + @pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_cache_local_false_produces_valid_entry( self, mock_s3, clean_redis, tmp_path ): @@ -559,7 +559,7 @@ def 
test_cache_local_false_produces_valid_entry( @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") - @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") + @pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_both_modes_produce_same_structure(self, mock_s3, clean_redis, tmp_path): """Both cache modes should produce entries with same structure.""" repo = atlocal.Repo( @@ -636,7 +636,7 @@ class TestMultiShardStorage: @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") - @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") + @pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_large_dataset_creates_multiple_shards( self, mock_s3, clean_redis, tmp_path ): @@ -668,7 +668,7 @@ def test_large_dataset_creates_multiple_shards( @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.filterwarnings("ignore:coroutine.*was never awaited:RuntimeWarning") - @pytest.mark.filterwarnings("ignore:Repo is deprecated:DeprecationWarning") + @pytest.mark.filterwarnings("ignore:Repo is deprecated:FutureWarning") def test_single_shard_no_brace_notation(self, mock_s3, clean_redis, tmp_path): """Small dataset should result in single shard without brace notation.""" repo = atlocal.Repo( From 1fadf1c1a252ff6ad81df20ce2365431d6f93d1f Mon Sep 17 00:00:00 2001 From: Maxine Levesque <220467675+maxine-at-forecast@users.noreply.github.com> Date: Wed, 1 Apr 2026 15:18:09 -0700 Subject: [PATCH 2/2] style: fix ruff formatting in client.py and test_index_write.py Co-Authored-By: Claude Opus 4.6 --- src/atdata/atmosphere/client.py | 8 ++++++-- tests/test_index_write.py | 16 ++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git 
a/src/atdata/atmosphere/client.py b/src/atdata/atmosphere/client.py index 422e987..d49fbbe 100644 --- a/src/atdata/atmosphere/client.py +++ b/src/atdata/atmosphere/client.py @@ -1088,7 +1088,9 @@ def create( validate: bool = False, ) -> AtUri: """Create a record. See :meth:`Atmosphere.create_record`.""" - return self._atmo.create_record(collection, record, rkey=rkey, validate=validate) + return self._atmo.create_record( + collection, record, rkey=rkey, validate=validate + ) def put( self, @@ -1126,7 +1128,9 @@ def list( cursor: Optional[str] = None, ) -> tuple[list[dict], Optional[str]]: """List records in a collection. See :meth:`Atmosphere.list_records`.""" - return self._atmo.list_records(collection, repo=repo, limit=limit, cursor=cursor) + return self._atmo.list_records( + collection, repo=repo, limit=limit, cursor=cursor + ) class BlobOps: diff --git a/tests/test_index_write.py b/tests/test_index_write.py index 0dcfdf8..22efe3a 100644 --- a/tests/test_index_write.py +++ b/tests/test_index_write.py @@ -750,7 +750,9 @@ def test_write_emits_deprecation(self, index): warnings.simplefilter("always") index.write(samples, name="dep-write") - dep_warnings = [x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning))] + dep_warnings = [ + x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning)) + ] assert any("write_samples" in str(dw.message) for dw in dep_warnings) def test_add_entry_emits_deprecation(self, index, tmp_path): @@ -761,7 +763,9 @@ def test_add_entry_emits_deprecation(self, index, tmp_path): warnings.simplefilter("always") index.add_entry(ds, name="dep-add") - dep_warnings = [x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning))] + dep_warnings = [ + x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning)) + ] assert any("insert_dataset" in str(dw.message) for dw in dep_warnings) def test_promote_entry_emits_deprecation(self, index): @@ -772,7 +776,9 @@ def 
test_promote_entry_emits_deprecation(self, index): except (ValueError, KeyError): pass - dep_warnings = [x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning))] + dep_warnings = [ + x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning)) + ] assert any("insert_dataset" in str(dw.message) for dw in dep_warnings) def test_promote_dataset_emits_deprecation(self, index): @@ -784,5 +790,7 @@ def test_promote_dataset_emits_deprecation(self, index): except (ValueError, KeyError): pass - dep_warnings = [x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning))] + dep_warnings = [ + x for x in w if issubclass(x.category, (DeprecationWarning, FutureWarning)) + ] assert any("insert_dataset" in str(dw.message) for dw in dep_warnings)