65 changes: 29 additions & 36 deletions src/spatialdata/_core/spatialdata.py
@@ -52,10 +52,7 @@

if TYPE_CHECKING:
from spatialdata._core.query.spatial_query import BaseSpatialRequest
-from spatialdata._io.format import (
-    SpatialDataContainerFormatType,
-    SpatialDataFormatType,
-)
+from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormatType

# schema for elements
Label2D_s = Labels2DModel()
@@ -241,9 +238,7 @@ def get_annotated_regions(table: AnnData) -> list[str]:
-------
The annotated regions.
"""
-from spatialdata.models.models import (
-    _get_region_metadata_from_region_key_column,
-)
+from spatialdata.models.models import _get_region_metadata_from_region_key_column

return _get_region_metadata_from_region_key_column(table)

@@ -695,18 +690,14 @@ def _filter_tables(
continue
# each mode here requires paths or elements, using assert here to avoid mypy errors.
if by == "cs":
-from spatialdata._core.query.relational_query import (
-    _filter_table_by_element_names,
-)
+from spatialdata._core.query.relational_query import _filter_table_by_element_names

assert element_names is not None
table = _filter_table_by_element_names(table, element_names)
if table is not None and len(table) != 0:
tables[table_name] = table
elif by == "elements":
-from spatialdata._core.query.relational_query import (
-    _filter_table_by_elements,
-)
+from spatialdata._core.query.relational_query import _filter_table_by_elements

assert elements_dict is not None
table = _filter_table_by_elements(table, elements_dict=elements_dict)
@@ -731,10 +722,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None:
The method does not allow to rename a coordinate system into an existing one, unless the existing one is also
renamed in the same call.
"""
-from spatialdata.transformations.operations import (
-    get_transformation,
-    set_transformation,
-)
+from spatialdata.transformations.operations import get_transformation, set_transformation

# check that the rename_dict is valid
old_names = self.coordinate_systems
@@ -1110,7 +1098,7 @@ def write(
overwrite: bool = False,
consolidate_metadata: bool = True,
update_sdata_path: bool = True,
-sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
+sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None,
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
) -> None:
"""
@@ -1215,15 +1203,12 @@ def _write_element(
)

root_group, element_type_group, element_group = _get_groups_for_element(
-    zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, use_consolidated=False
-)
-from spatialdata._io import (
-    write_image,
-    write_labels,
-    write_points,
-    write_shapes,
-    write_table,
+    zarr_path=zarr_container_path,
+    element_type=element_type,
+    element_name=element_name,
+    use_consolidated=False,
)
+from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table
from spatialdata._io.format import _parse_formats

if parsed_formats is None:
@@ -1270,7 +1255,7 @@ def write_element(
self,
element_name: str | list[str],
overwrite: bool = False,
-sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
+sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None,
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
) -> None:
"""
@@ -1548,7 +1533,10 @@ def write_channel_names(self, element_name: str | None = None) -> None:
# Mypy does not understand that path is not None so we have the check in the conditional
if element_type == "images" and self.path is not None:
_, _, element_group = _get_groups_for_element(
-    zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False
+    zarr_path=Path(self.path),
+    element_type=element_type,
+    element_name=element_name,
+    use_consolidated=False,
)

from spatialdata._io._utils import overwrite_channel_names
@@ -1599,19 +1587,18 @@ def write_transformations(self, element_name: str | None = None) -> None:
)
axes = get_axes_names(element)
if isinstance(element, DataArray | DataTree):
-from spatialdata._io._utils import (
-    overwrite_coordinate_transformations_raster,
-)
+from spatialdata._io._utils import overwrite_coordinate_transformations_raster
from spatialdata._io.format import RasterFormats

raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]]
overwrite_coordinate_transformations_raster(
-    group=element_group, axes=axes, transformations=transformations, raster_format=raster_format
+    group=element_group,
+    axes=axes,
+    transformations=transformations,
+    raster_format=raster_format,
)
elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData):
-from spatialdata._io._utils import (
-    overwrite_coordinate_transformations_non_raster,
-)
+from spatialdata._io._utils import overwrite_coordinate_transformations_non_raster

overwrite_coordinate_transformations_non_raster(
group=element_group,
@@ -1830,6 +1817,7 @@ def read(
file_path: str | Path | UPath | zarr.Group,
selection: tuple[str] | None = None,
reconsolidate_metadata: bool = False,
+lazy: bool = False,
) -> SpatialData:
"""
Read a SpatialData object from a Zarr storage (on-disk or remote).
@@ -1842,6 +1830,11 @@
The elements to read (images, labels, points, shapes, table). If None, all elements are read.
reconsolidate_metadata
If the consolidated metadata store got corrupted this can lead to errors when trying to read the data.
+lazy
+    If True, read tables lazily using anndata.experimental.read_lazy.
+    This keeps large tables out of memory until needed. Requires anndata >= 0.12.
+    Note: Images, labels, and points are always read lazily (using Dask).
+    This parameter only affects tables, which are normally loaded into memory.

Returns
-------
@@ -1854,7 +1847,7 @@

_write_consolidated_metadata(file_path)

-return read_zarr(file_path, selection=selection)
+return read_zarr(file_path, selection=selection, lazy=lazy)

@property
def images(self) -> Images:
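A minimal usage sketch of the new flag as exposed through `SpatialData.read` (the `data.zarr` path is a placeholder; the lazy branch assumes anndata >= 0.12 is installed):

```python
from spatialdata import SpatialData

# Default behavior: tables are read eagerly into memory.
sdata = SpatialData.read("data.zarr")

# Opt-in lazy behavior: tables stay on disk until accessed.
# Images, labels, and points are Dask-backed in both cases.
sdata_lazy = SpatialData.read("data.zarr", lazy=True)
```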
45 changes: 36 additions & 9 deletions src/spatialdata/_io/io_table.py
@@ -7,18 +7,45 @@
from anndata._io.specs import write_elem as write_adata
from ome_zarr.format import Format

-from spatialdata._io.format import (
-    CurrentTablesFormat,
-    TablesFormats,
-    TablesFormatV01,
-    TablesFormatV02,
-    _parse_version,
-)
+from spatialdata._io.format import CurrentTablesFormat, TablesFormats, TablesFormatV01, TablesFormatV02, _parse_version
from spatialdata.models import TableModel, get_table_keys


-def _read_table(store: str | Path) -> AnnData:
-    table = read_anndata_zarr(str(store))
+def _read_table(store: str | Path, lazy: bool = False) -> AnnData:
+    """
+    Read a table from a zarr store.
+
+    Parameters
+    ----------
+    store
+        Path to the zarr store containing the table.
+    lazy
+        If True, read the table lazily using anndata.experimental.read_lazy.
+        This requires anndata >= 0.12. If the installed version does not support
+        lazy reading, a warning is emitted and the table is read eagerly.
+
+    Returns
+    -------
+    The AnnData table, either lazily loaded or in-memory.
+    """
+    if lazy:
+        try:
+            from anndata.experimental import read_lazy
+
+            table = read_lazy(str(store))
+        except ImportError:
+            import warnings
+
+            warnings.warn(
+                "Lazy reading of tables requires anndata >= 0.12. "
+                "Falling back to eager reading. To enable lazy reading, "
+                "upgrade anndata with: pip install 'anndata>=0.12'",
+                UserWarning,
+                stacklevel=2,
+            )
+            table = read_anndata_zarr(str(store))
+    else:
+        table = read_anndata_zarr(str(store))

f = zarr.open(store, mode="r")
version = _parse_version(f, expect_attrs_key=False)
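Since `_read_table` silently falls back to eager reading on older anndata, a caller that needs to know whether lazy mode will actually take effect can probe for the feature first. A sketch under the same assumption the PR makes, namely that `anndata.experimental.read_lazy` exists only on anndata >= 0.12 (the helper name is hypothetical):

```python
def supports_lazy_tables() -> bool:
    """Return True if the installed anndata can read tables lazily."""
    try:
        from anndata.experimental import read_lazy  # noqa: F401
    except ImportError:
        return False
    return True

# Request lazy tables only when the feature is available.
lazy = supports_lazy_tables()
```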
21 changes: 11 additions & 10 deletions src/spatialdata/_io/io_zarr.py
@@ -1,6 +1,7 @@
import os
import warnings
from collections.abc import Callable
+from functools import partial
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Literal, cast
@@ -15,11 +16,7 @@
from zarr.errors import ArrayNotFoundError

from spatialdata._core.spatialdata import SpatialData
-from spatialdata._io._utils import (
-    BadFileHandleMethod,
-    _resolve_zarr_store,
-    handle_read_errors,
-)
+from spatialdata._io._utils import BadFileHandleMethod, _resolve_zarr_store, handle_read_errors
from spatialdata._io.io_points import _read_points
from spatialdata._io.io_raster import _read_multiscale
from spatialdata._io.io_shapes import _read_shapes
@@ -104,10 +101,7 @@ def get_raster_format_for_read(
-------
The ome-zarr format to use for reading the raster element.
"""
-from spatialdata._io.format import (
-    sdata_zarr_version_to_ome_zarr_format,
-    sdata_zarr_version_to_raster_format,
-)
+from spatialdata._io.format import sdata_zarr_version_to_ome_zarr_format, sdata_zarr_version_to_raster_format

if sdata_version == "0.1":
group_version = group.metadata.attributes["multiscales"][0]["version"]
@@ -124,6 +118,7 @@ def read_zarr(
store: str | Path | UPath | zarr.Group,
selection: None | tuple[str] = None,
on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR,
+lazy: bool = False,
) -> SpatialData:
"""
Read a SpatialData dataset from a zarr store (on-disk or remote).
@@ -147,6 +142,12 @@
object is returned containing only elements that could be read. Failures can only be
determined from the warnings.

+lazy
+    If True, read tables lazily using anndata.experimental.read_lazy.
+    This keeps large tables out of memory until needed. Requires anndata >= 0.12.
+    Note: Images, labels, and points are always read lazily (using Dask).
+    This parameter only affects tables, which are normally loaded into memory.

Returns
-------
A SpatialData object.
@@ -193,7 +194,7 @@
"labels": (_read_multiscale, "labels", labels),
"points": (_read_points, "points", points),
"shapes": (_read_shapes, "shapes", shapes),
"tables": (_read_table, "tables", tables),
"tables": (partial(_read_table, lazy=lazy), "tables", tables),
}
for group_name, (
read_func,
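The dispatch-table change works because `functools.partial` binds `lazy` up front, so the tables entry keeps the same call shape as the other readers. A self-contained toy illustrating the pattern (all names here are hypothetical, not the module's):

```python
from functools import partial

def read_table(store: str, lazy: bool = False) -> str:
    return f"read {store} (lazy={lazy})"

def read_shapes(store: str) -> str:
    return f"read {store}"

# Every entry is callable as read_func(store), regardless of extra options.
readers = {
    "tables": partial(read_table, lazy=True),
    "shapes": read_shapes,
}
print(readers["tables"]("data.zarr/tables/table"))  # read data.zarr/tables/table (lazy=True)
```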
54 changes: 40 additions & 14 deletions src/spatialdata/models/models.py
@@ -23,12 +23,7 @@
from shapely.io import from_geojson, from_ragged_array
from spatial_image import to_spatial_image
from xarray import DataArray, DataTree
-from xarray_schema.components import (
-    ArrayTypeSchema,
-    AttrSchema,
-    AttrsSchema,
-    DimsSchema,
-)
+from xarray_schema.components import ArrayTypeSchema, AttrSchema, AttrsSchema, DimsSchema
from xarray_schema.dataarray import DataArraySchema

from spatialdata._core.validation import validate_table_attr_keys
@@ -45,11 +40,7 @@
_validate_mapping_to_coordinate_system_type,
convert_region_column_to_categorical,
)
-from spatialdata.transformations._utils import (
-    _get_transformations,
-    _set_transformations,
-    compute_coordinates,
-)
+from spatialdata.transformations._utils import _get_transformations, _set_transformations, compute_coordinates
from spatialdata.transformations.transformations import BaseTransformation, Identity

# Types
Expand All @@ -60,6 +51,25 @@
ATTRS_KEY = "spatialdata_attrs"


+def _is_lazy_anndata(adata: AnnData) -> bool:
+    """Check if an AnnData object is lazily loaded.
+
+    Lazy AnnData objects (from anndata.experimental.read_lazy) have obs/var
+    stored as xarray Dataset2D instead of pandas DataFrame.
+
+    Parameters
+    ----------
+    adata
+        The AnnData object to check.
+
+    Returns
+    -------
+    True if the AnnData is lazily loaded, False otherwise.
+    """
+    # Check if obs is not a pandas DataFrame (lazy AnnData uses xarray Dataset2D)
+    return not isinstance(adata.obs, pd.DataFrame)


def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None:
_validate_mapping_to_coordinate_system_type(transformations)
transformations_in_element = _get_transformations(element)
@@ -1036,6 +1046,13 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None:
raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.")
if attr[self.INSTANCE_KEY] not in data.obs:
raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.")

+# Skip detailed dtype/value validation for lazy-loaded AnnData.
+# These checks would trigger data loading, defeating the purpose of lazy loading.
+# Validation will occur when data is actually computed/accessed.
+if _is_lazy_anndata(data):
+    return

if (
(dtype := data.obs[attr[self.INSTANCE_KEY]].dtype)
not in [
@@ -1064,7 +1081,10 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None:
# Warning for object/string columns with NaN in region_key or instance_key
instance_key = attr[self.INSTANCE_KEY]
region_key = attr[self.REGION_KEY_KEY]
for key_name, key_value in [("region_key", region_key), ("instance_key", instance_key)]:
for key_name, key_value in [
("region_key", region_key),
("instance_key", instance_key),
]:
if key_value in data.obs:
col = data.obs[key_value]
col_dtype = col.dtype
@@ -1099,14 +1119,19 @@ def validate(
if ATTRS_KEY not in data.uns:
return data

+# Check if this is a lazy-loaded AnnData (from anndata.experimental.read_lazy).
+# Lazy AnnData has xarray-based obs/var, which requires different validation.
+is_lazy = _is_lazy_anndata(data)

_, region_key, instance_key = get_table_keys(data)
if region_key is not None:
if region_key not in data.obs:
raise ValueError(
f"Region key `{region_key}` not in `adata.obs`. Please create the column and parse "
f"using TableModel.parse(adata)."
)
-if not isinstance(data.obs[region_key].dtype, CategoricalDtype):
+# Skip dtype validation for lazy tables (would require loading data)
+if not is_lazy and not isinstance(data.obs[region_key].dtype, CategoricalDtype):
raise ValueError(
f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`."
)
@@ -1116,7 +1141,8 @@
f"Instance key `{instance_key}` not in `adata.obs`. Please create the column and parse"
f" using TableModel.parse(adata)."
)
-if data.obs[instance_key].isnull().values.any():
+# Skip null check for lazy tables (would require loading data)
+if not is_lazy and data.obs[instance_key].isnull().values.any():
raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.")

self._validate_table_annotation_metadata(data)
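The duck-typing check in `_is_lazy_anndata` can be exercised without anndata >= 0.12, since an in-memory AnnData always carries a pandas `obs`. A small sketch (the lazy case follows the docstring above: `read_lazy` swaps `obs` for an xarray-backed Dataset2D, so the `isinstance` test returns False):

```python
import numpy as np
import pandas as pd
from anndata import AnnData

adata = AnnData(X=np.zeros((2, 1)), obs=pd.DataFrame(index=["cell_0", "cell_1"]))

# In-memory AnnData: obs is a pandas DataFrame, so the table is not lazy.
assert isinstance(adata.obs, pd.DataFrame)  # _is_lazy_anndata(adata) -> False

# After anndata.experimental.read_lazy (anndata >= 0.12), obs is an
# xarray Dataset2D, the isinstance check fails, and validation that would
# force data loading (dtype/null checks) is deferred until compute time.
```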