diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index c06e62b7..079ca28b 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -52,10 +52,7 @@ if TYPE_CHECKING:
     from spatialdata._core.query.spatial_query import BaseSpatialRequest
-    from spatialdata._io.format import (
-        SpatialDataContainerFormatType,
-        SpatialDataFormatType,
-    )
+    from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormatType
 
 # schema for elements
 Label2D_s = Labels2DModel()
@@ -241,9 +238,7 @@ def get_annotated_regions(table: AnnData) -> list[str]:
         -------
         The annotated regions.
         """
-        from spatialdata.models.models import (
-            _get_region_metadata_from_region_key_column,
-        )
+        from spatialdata.models.models import _get_region_metadata_from_region_key_column
 
         return _get_region_metadata_from_region_key_column(table)
@@ -695,18 +690,14 @@ def _filter_tables(
                 continue
             # each mode here requires paths or elements, using assert here to avoid mypy errors.
             if by == "cs":
-                from spatialdata._core.query.relational_query import (
-                    _filter_table_by_element_names,
-                )
+                from spatialdata._core.query.relational_query import _filter_table_by_element_names
 
                 assert element_names is not None
                 table = _filter_table_by_element_names(table, element_names)
                 if table is not None and len(table) != 0:
                     tables[table_name] = table
             elif by == "elements":
-                from spatialdata._core.query.relational_query import (
-                    _filter_table_by_elements,
-                )
+                from spatialdata._core.query.relational_query import _filter_table_by_elements
 
                 assert elements_dict is not None
                 table = _filter_table_by_elements(table, elements_dict=elements_dict)
@@ -731,10 +722,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None:
         The method does not allow to rename a coordinate system into an existing one, unless the existing one is also
         renamed in the same call.
""" - from spatialdata.transformations.operations import ( - get_transformation, - set_transformation, - ) + from spatialdata.transformations.operations import get_transformation, set_transformation # check that the rename_dict is valid old_names = self.coordinate_systems @@ -1110,7 +1098,7 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, update_sdata_path: bool = True, - sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1215,15 +1203,12 @@ def _write_element( ) root_group, element_type_group, element_group = _get_groups_for_element( - zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, use_consolidated=False - ) - from spatialdata._io import ( - write_image, - write_labels, - write_points, - write_shapes, - write_table, + zarr_path=zarr_container_path, + element_type=element_type, + element_name=element_name, + use_consolidated=False, ) + from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table from spatialdata._io.format import _parse_formats if parsed_formats is None: @@ -1270,7 +1255,7 @@ def write_element( self, element_name: str | list[str], overwrite: bool = False, - sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1548,7 +1533,10 @@ def write_channel_names(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have the check in the conditional if element_type == "images" and self.path is not None: _, _, element_group = _get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False + zarr_path=Path(self.path), + element_type=element_type, + element_name=element_name, + use_consolidated=False, ) from spatialdata._io._utils import overwrite_channel_names @@ -1599,19 +1587,18 @@ def write_transformations(self, element_name: str | None = None) -> None: ) axes = get_axes_names(element) if isinstance(element, DataArray | DataTree): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_raster, - ) + from spatialdata._io._utils import overwrite_coordinate_transformations_raster from spatialdata._io.format import RasterFormats raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]] overwrite_coordinate_transformations_raster( - group=element_group, axes=axes, transformations=transformations, raster_format=raster_format + group=element_group, + axes=axes, + transformations=transformations, + raster_format=raster_format, ) elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_non_raster, - ) + from spatialdata._io._utils import overwrite_coordinate_transformations_non_raster overwrite_coordinate_transformations_non_raster( group=element_group, @@ -1830,6 +1817,7 @@ def read( file_path: str | Path | UPath | zarr.Group, selection: tuple[str] | None = None, reconsolidate_metadata: bool = False, + lazy: bool = False, ) -> SpatialData: """ Read a SpatialData object from a Zarr storage (on-disk or remote). 
diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py
index 719c9c1a..8bb1de0d 100644
--- a/src/spatialdata/_io/io_table.py
+++ b/src/spatialdata/_io/io_table.py
@@ -7,18 +7,45 @@
 from anndata._io.specs import write_elem as write_adata
 from ome_zarr.format import Format
 
-from spatialdata._io.format import (
-    CurrentTablesFormat,
-    TablesFormats,
-    TablesFormatV01,
-    TablesFormatV02,
-    _parse_version,
-)
+from spatialdata._io.format import CurrentTablesFormat, TablesFormats, TablesFormatV01, TablesFormatV02, _parse_version
 from spatialdata.models import TableModel, get_table_keys
 
 
-def _read_table(store: str | Path) -> AnnData:
-    table = read_anndata_zarr(str(store))
+def _read_table(store: str | Path, lazy: bool = False) -> AnnData:
+    """
+    Read a table from a zarr store.
+
+    Parameters
+    ----------
+    store
+        Path to the zarr store containing the table.
+    lazy
+        If True, read the table lazily using anndata.experimental.read_lazy.
+        This requires anndata >= 0.12. If the installed version does not support
+        lazy reading, a warning is emitted and the table is read eagerly.
+
+    Returns
+    -------
+    The AnnData table, either lazily loaded or in-memory.
+    """
+    if lazy:
+        try:
+            from anndata.experimental import read_lazy
+
+            table = read_lazy(str(store))
+        except ImportError:
+            import warnings
+
+            warnings.warn(
+                "Lazy reading of tables requires anndata >= 0.12. "
+                "Falling back to eager reading. To enable lazy reading, "
+                "upgrade anndata with: pip install 'anndata>=0.12'",
+                UserWarning,
+                stacklevel=2,
+            )
+            table = read_anndata_zarr(str(store))
+    else:
+        table = read_anndata_zarr(str(store))
     f = zarr.open(store, mode="r")
     version = _parse_version(f, expect_attrs_key=False)
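The guarded import above is the whole compatibility story. A standalone sketch of the same fallback pattern, using only public anndata APIs (`anndata.read_zarr`, `anndata.experimental.read_lazy`); the function name is illustrative:

```python
# Prefer the lazy reader when available, otherwise warn and fall back to the
# eager reader. read_lazy only exists in anndata >= 0.12.
import warnings

from anndata import AnnData, read_zarr


def read_table_sketch(store: str, lazy: bool = False) -> AnnData:
    if lazy:
        try:
            from anndata.experimental import read_lazy  # anndata >= 0.12

            return read_lazy(store)
        except ImportError:
            warnings.warn(
                "anndata >= 0.12 is required for lazy reading; reading eagerly.",
                UserWarning,
                stacklevel=2,
            )
    # Eager path: both lazy=False and the fallback end up here.
    return read_zarr(store)
```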
""" - from spatialdata._io.format import ( - sdata_zarr_version_to_ome_zarr_format, - sdata_zarr_version_to_raster_format, - ) + from spatialdata._io.format import sdata_zarr_version_to_ome_zarr_format, sdata_zarr_version_to_raster_format if sdata_version == "0.1": group_version = group.metadata.attributes["multiscales"][0]["version"] @@ -124,6 +118,7 @@ def read_zarr( store: str | Path | UPath | zarr.Group, selection: None | tuple[str] = None, on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, + lazy: bool = False, ) -> SpatialData: """ Read a SpatialData dataset from a zarr store (on-disk or remote). @@ -147,6 +142,12 @@ def read_zarr( object is returned containing only elements that could be read. Failures can only be determined from the warnings. + lazy + If True, read tables lazily using anndata.experimental.read_lazy. + This keeps large tables out of memory until needed. Requires anndata >= 0.12. + Note: Images, labels, and points are always read lazily (using Dask). + This parameter only affects tables, which are normally loaded into memory. + Returns ------- A SpatialData object. @@ -193,7 +194,7 @@ def read_zarr( "labels": (_read_multiscale, "labels", labels), "points": (_read_points, "points", points), "shapes": (_read_shapes, "shapes", shapes), - "tables": (_read_table, "tables", tables), + "tables": (partial(_read_table, lazy=lazy), "tables", tables), } for group_name, ( read_func, diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index e834ad78..36a3a973 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -23,12 +23,7 @@ from shapely.io import from_geojson, from_ragged_array from spatial_image import to_spatial_image from xarray import DataArray, DataTree -from xarray_schema.components import ( - ArrayTypeSchema, - AttrSchema, - AttrsSchema, - DimsSchema, -) +from xarray_schema.components import ArrayTypeSchema, AttrSchema, AttrsSchema, DimsSchema from xarray_schema.dataarray import DataArraySchema from spatialdata._core.validation import validate_table_attr_keys @@ -45,11 +40,7 @@ _validate_mapping_to_coordinate_system_type, convert_region_column_to_categorical, ) -from spatialdata.transformations._utils import ( - _get_transformations, - _set_transformations, - compute_coordinates, -) +from spatialdata.transformations._utils import _get_transformations, _set_transformations, compute_coordinates from spatialdata.transformations.transformations import BaseTransformation, Identity # Types @@ -60,6 +51,25 @@ ATTRS_KEY = "spatialdata_attrs" +def _is_lazy_anndata(adata: AnnData) -> bool: + """Check if an AnnData object is lazily loaded. + + Lazy AnnData objects (from anndata.experimental.read_lazy) have obs/var + stored as xarray Dataset2D instead of pandas DataFrame. + + Parameters + ---------- + adata + The AnnData object to check. + + Returns + ------- + True if the AnnData is lazily loaded, False otherwise. + """ + # Check if obs is not a pandas DataFrame (lazy AnnData uses xarray Dataset2D) + return not isinstance(adata.obs, pd.DataFrame) + + def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None: _validate_mapping_to_coordinate_system_type(transformations) transformations_in_element = _get_transformations(element) @@ -1036,6 +1046,13 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. 
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py
index e834ad78..36a3a973 100644
--- a/src/spatialdata/models/models.py
+++ b/src/spatialdata/models/models.py
@@ -23,12 +23,7 @@
 from shapely.io import from_geojson, from_ragged_array
 from spatial_image import to_spatial_image
 from xarray import DataArray, DataTree
-from xarray_schema.components import (
-    ArrayTypeSchema,
-    AttrSchema,
-    AttrsSchema,
-    DimsSchema,
-)
+from xarray_schema.components import ArrayTypeSchema, AttrSchema, AttrsSchema, DimsSchema
 from xarray_schema.dataarray import DataArraySchema
 
 from spatialdata._core.validation import validate_table_attr_keys
@@ -45,11 +40,7 @@
     _validate_mapping_to_coordinate_system_type,
     convert_region_column_to_categorical,
 )
-from spatialdata.transformations._utils import (
-    _get_transformations,
-    _set_transformations,
-    compute_coordinates,
-)
+from spatialdata.transformations._utils import _get_transformations, _set_transformations, compute_coordinates
 from spatialdata.transformations.transformations import BaseTransformation, Identity
 
 # Types
@@ -60,6 +51,25 @@
 ATTRS_KEY = "spatialdata_attrs"
 
 
+def _is_lazy_anndata(adata: AnnData) -> bool:
+    """Check if an AnnData object is lazily loaded.
+
+    Lazy AnnData objects (from anndata.experimental.read_lazy) have obs/var
+    stored as xarray Dataset2D instead of pandas DataFrame.
+
+    Parameters
+    ----------
+    adata
+        The AnnData object to check.
+
+    Returns
+    -------
+    True if the AnnData is lazily loaded, False otherwise.
+    """
+    # Check if obs is not a pandas DataFrame (lazy AnnData uses xarray Dataset2D)
+    return not isinstance(adata.obs, pd.DataFrame)
+
+
 def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None:
     _validate_mapping_to_coordinate_system_type(transformations)
     transformations_in_element = _get_transformations(element)
@@ -1036,6 +1046,13 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None:
             raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.")
         if attr[self.INSTANCE_KEY] not in data.obs:
             raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.")
+
+        # Skip detailed dtype/value validation for lazy-loaded AnnData:
+        # these checks would trigger data loading, defeating the purpose of
+        # lazy loading. Validation occurs when data is actually computed/accessed.
+        if _is_lazy_anndata(data):
+            return
+
         if (
             (dtype := data.obs[attr[self.INSTANCE_KEY]].dtype)
             not in [
@@ -1064,7 +1081,10 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None:
         # Warning for object/string columns with NaN in region_key or instance_key
         instance_key = attr[self.INSTANCE_KEY]
         region_key = attr[self.REGION_KEY_KEY]
-        for key_name, key_value in [("region_key", region_key), ("instance_key", instance_key)]:
+        for key_name, key_value in [
+            ("region_key", region_key),
+            ("instance_key", instance_key),
+        ]:
             if key_value in data.obs:
                 col = data.obs[key_value]
                 col_dtype = col.dtype
@@ -1099,6 +1119,10 @@ def validate(
         if ATTRS_KEY not in data.uns:
             return data
 
+        # Check if this is a lazy-loaded AnnData (from anndata.experimental.read_lazy);
+        # lazy AnnData has xarray-based obs/var, which requires different validation
+        is_lazy = _is_lazy_anndata(data)
+
         _, region_key, instance_key = get_table_keys(data)
         if region_key is not None:
             if region_key not in data.obs:
@@ -1106,7 +1130,8 @@ def validate(
                     f"Region key `{region_key}` not in `adata.obs`. Please create the column and parse "
                     f"using TableModel.parse(adata)."
                 )
-            if not isinstance(data.obs[region_key].dtype, CategoricalDtype):
+            # Skip dtype validation for lazy tables (it would require loading the data)
+            if not is_lazy and not isinstance(data.obs[region_key].dtype, CategoricalDtype):
                 raise ValueError(
                     f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`."
                 )
@@ -1116,7 +1141,8 @@ def validate(
                     f"Instance key `{instance_key}` not in `adata.obs`. Please create the column and parse"
                     f" using TableModel.parse(adata)."
                 )
-            if data.obs[instance_key].isnull().values.any():
+            # Skip null check for lazy tables (it would require loading the data)
+            if not is_lazy and data.obs[instance_key].isnull().values.any():
                 raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.")
 
         self._validate_table_annotation_metadata(data)
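The lazy check is a plain duck-type test on `.obs`. A small sketch of its behavior (the helper name is illustrative; the Dataset2D claim for lazy tables comes from the docstring above):

```python
# An in-memory AnnData keeps obs as a pandas DataFrame, while read_lazy
# returns an xarray-backed obs, so the isinstance test distinguishes the
# two without touching any stored data.
import numpy as np
import pandas as pd
from anndata import AnnData


def is_lazy_anndata_sketch(adata: AnnData) -> bool:
    return not isinstance(adata.obs, pd.DataFrame)


adata = AnnData(X=np.zeros((3, 2)), obs=pd.DataFrame(index=["a", "b", "c"]))
assert is_lazy_anndata_sketch(adata) is False
# For a table produced by anndata.experimental.read_lazy, obs is a Dataset2D
# (not a pandas DataFrame), so the same check returns True.
```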
+ """ + + @pytest.fixture + def sdata_with_table(self) -> SpatialData: + """Create a SpatialData object with a simple table for testing.""" + from spatialdata.models import TableModel + + rng = default_rng(42) + table = TableModel.parse( + AnnData( + X=rng.random((100, 50)), + obs=pd.DataFrame( + { + "region": pd.Categorical(["region1"] * 100), + "instance": np.arange(100), + } + ), + ), + region_key="region", + instance_key="instance", + region="region1", + ) + return SpatialData(tables={"test_table": table}) + + def test_lazy_read_basic(self, sdata_with_table: SpatialData) -> None: + """Test that lazy=True reads tables without loading into memory.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Read with lazy=True + try: + sdata_lazy = SpatialData.read(path, lazy=True) + + # Table should be present + assert "test_table" in sdata_lazy.tables + + # Check that X is a lazy array (dask or similar) + # Lazy AnnData from read_lazy uses dask arrays + table = sdata_lazy.tables["test_table"] + assert hasattr(table, "X") + + except ImportError: + # If anndata.experimental.read_lazy is not available, skip + pytest.skip("anndata.experimental.read_lazy not available") + + def test_lazy_false_loads_normally(self, sdata_with_table: SpatialData) -> None: + """Test that lazy=False (default) loads tables into memory normally.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Read with lazy=False (default) + sdata_normal = SpatialData.read(path, lazy=False) + + # Table should be present and loaded normally + assert "test_table" in sdata_normal.tables + table = sdata_normal.tables["test_table"] + + # X should be a numpy array or scipy sparse matrix (in-memory) + import scipy.sparse as sp + + assert isinstance(table.X, np.ndarray | sp.spmatrix) + + def test_read_zarr_lazy_parameter(self, sdata_with_table: SpatialData) -> None: + """Test that read_zarr function accepts lazy parameter.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Test read_zarr directly with lazy parameter + try: + sdata = read_zarr(path, lazy=True) + assert "test_table" in sdata.tables + except ImportError: + pytest.skip("anndata.experimental.read_lazy not available")