From 6821df4504d4417a79e3c5b67e891068a9e9a1dd Mon Sep 17 00:00:00 2001 From: Tomatokeftes <129113023+Tomatokeftes@users.noreply.github.com> Date: Tue, 27 Jan 2026 11:58:07 +0100 Subject: [PATCH 1/3] feat: add lazy table loading via anndata.experimental.read_lazy Add a `lazy` parameter to `SpatialData.read()` and `read_zarr()` that enables lazy loading of tables using anndata's experimental `read_lazy()` function. This is particularly useful for large datasets (e.g., Mass Spectrometry Imaging with millions of pixels) where loading tables into memory is not feasible. Changes: - Add `lazy: bool = False` parameter to `read_zarr()` in io_zarr.py - Add `lazy: bool = False` parameter to `_read_table()` in io_table.py - Add `lazy: bool = False` parameter to `SpatialData.read()` in spatialdata.py - Add `_is_lazy_anndata()` helper to detect lazy AnnData objects - Skip eager validation for lazy tables to preserve lazy loading benefits - Add tests for lazy loading functionality Requires anndata >= 0.12 for lazy loading support. Falls back to eager loading with a warning if anndata version does not support read_lazy. --- src/spatialdata/_core/spatialdata.py | 401 +++++++++++++++++---------- src/spatialdata/_io/io_table.py | 45 ++- src/spatialdata/_io/io_zarr.py | 59 ++-- src/spatialdata/models/models.py | 238 ++++++++++++---- tests/io/test_readwrite.py | 254 ++++++++++++++--- 5 files changed, 726 insertions(+), 271 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index c06e62b7..869683ba 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -22,40 +22,19 @@ from zarr.errors import GroupNotFoundError from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables -from spatialdata._core.validation import ( - check_all_keys_case_insensitively_unique, - check_target_region_column_symmetry, - check_valid_name, - raise_validation_errors, - validate_table_attr_keys, -) +from spatialdata._core.validation import (check_all_keys_case_insensitively_unique, check_target_region_column_symmetry, + check_valid_name, raise_validation_errors, validate_table_attr_keys) from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T from spatialdata._utils import _deprecation_alias -from spatialdata.models import ( - Image2DModel, - Image3DModel, - Labels2DModel, - Labels3DModel, - PointsModel, - ShapesModel, - TableModel, - get_model, - get_table_keys, -) -from spatialdata.models._utils import ( - SpatialElement, - convert_region_column_to_categorical, - get_axes_names, - set_channel_names, -) +from spatialdata.models import (Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel, + TableModel, get_model, get_table_keys) +from spatialdata.models._utils import (SpatialElement, convert_region_column_to_categorical, get_axes_names, + set_channel_names) if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest - from spatialdata._io.format import ( - SpatialDataContainerFormatType, - SpatialDataFormatType, - ) + from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormatType # schema for elements Label2D_s = Labels2DModel() @@ -140,7 +119,11 @@ def __init__( self._tables: Tables = Tables(shared_keys=self._shared_keys) self.attrs = attrs if attrs else {} # type: ignore[assignment] - element_names = list(chain.from_iterable([e.keys() for e in [images, labels, points, shapes] if e is not None])) + element_names 
= list( + chain.from_iterable( + [e.keys() for e in [images, labels, points, shapes] if e is not None] + ) + ) if len(element_names) != len(set(element_names)): duplicates = {x for x in element_names if element_names.count(x) > 1} @@ -241,9 +224,7 @@ def get_annotated_regions(table: AnnData) -> list[str]: ------- The annotated regions. """ - from spatialdata.models.models import ( - _get_region_metadata_from_region_key_column, - ) + from spatialdata.models.models import _get_region_metadata_from_region_key_column return _get_region_metadata_from_region_key_column(table) @@ -268,7 +249,9 @@ def get_region_key_column(table: AnnData) -> pd.Series: _, region_key, _ = get_table_keys(table) if table.obs.get(region_key) is not None: return table.obs[region_key] - raise KeyError(f"{region_key} is set as region key column. However the column is not found in table.obs.") + raise KeyError( + f"{region_key} is set as region key column. However the column is not found in table.obs." + ) @staticmethod def get_instance_key_column(table: AnnData) -> pd.Series: @@ -293,9 +276,13 @@ def get_instance_key_column(table: AnnData) -> pd.Series: _, _, instance_key = get_table_keys(table) if table.obs.get(instance_key) is not None: return table.obs[instance_key] - raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.") + raise KeyError( + f"{instance_key} is set as instance key column. However the column is not found in table.obs." + ) - def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None: + def set_channel_names( + self, element_name: str, channel_names: str | list[str], write: bool = False + ) -> None: """Set the channel names for an image `SpatialElement` in the `SpatialData` object. This method will overwrite the element in memory with the same element, but with new channel names. @@ -313,7 +300,9 @@ def set_channel_names(self, element_name: str, channel_names: str | list[str], w Whether to overwrite the channel metadata on disk (lightweight operation). This will not rewrite the pixel data itself (heavy operation). """ - self.images[element_name] = set_channel_names(self.images[element_name], channel_names) + self.images[element_name] = set_channel_names( + self.images[element_name], channel_names + ) if write: self.write_channel_names(element_name) @@ -391,7 +380,9 @@ def _change_table_annotation_target( If provided region_key is not present in table.obs. """ attrs = table.uns[TableModel.ATTRS_KEY] - table_region_key = region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) + table_region_key = ( + region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) + ) TableModel()._validate_set_region_key(table, region_key) TableModel()._validate_set_instance_key(table, instance_key) @@ -399,7 +390,9 @@ def _change_table_annotation_target( attrs[TableModel.REGION_KEY] = region @staticmethod - def update_annotated_regions_metadata(table: AnnData, region_key: str | None = None) -> AnnData: + def update_annotated_regions_metadata( + table: AnnData, region_key: str | None = None + ) -> AnnData: """ Update the annotation target of the table using the region_key column in table.obs. @@ -422,7 +415,9 @@ def update_annotated_regions_metadata(table: AnnData, region_key: str | None = N """ attrs = table.uns.get(TableModel.ATTRS_KEY) if attrs is None: - raise ValueError("The table has no annotation metadata. 
Please parse the table using `TableModel.parse`.") + raise ValueError( + "The table has no annotation metadata. Please parse the table using `TableModel.parse`." + ) region_key = region_key if region_key else attrs[TableModel.REGION_KEY_KEY] if attrs[TableModel.REGION_KEY_KEY] != region_key: attrs[TableModel.REGION_KEY_KEY] = region_key @@ -470,14 +465,20 @@ def set_table_annotates_spatialelement( isinstance(region, list | pd.Series) and not all(region_element in element_names for region_element in region) ): - raise ValueError(f"Annotation target '{region}' not present as SpatialElement in SpatialData object.") + raise ValueError( + f"Annotation target '{region}' not present as SpatialElement in SpatialData object." + ) if table.uns.get(TableModel.ATTRS_KEY): - self._change_table_annotation_target(table, region, region_key, instance_key) + self._change_table_annotation_target( + table, region, region_key, instance_key + ) elif isinstance(region_key, str) and isinstance(instance_key, str): self._set_table_annotation_target(table, region, region_key, instance_key) else: - raise TypeError("No current annotation metadata found. Please specify both region_key and instance_key.") + raise TypeError( + "No current annotation metadata found. Please specify both region_key and instance_key." + ) convert_region_column_to_categorical(table) @property @@ -588,9 +589,17 @@ def locate_element(self, element: SpatialElement) -> list[str]: found_element_name.append(element_name) if len(found) == 0: return [] - if any("/" in found_element_name[i] or "/" in found_element_type[i] for i in range(len(found))): - raise ValueError("Found an element name with a '/' character. This is not allowed.") - return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))] + if any( + "/" in found_element_name[i] or "/" in found_element_type[i] + for i in range(len(found)) + ): + raise ValueError( + "Found an element name with a '/' character. This is not allowed." + ) + return [ + f"{found_element_type[i]}/{found_element_name[i]}" + for i in range(len(found)) + ] def filter_by_coordinate_system( self, @@ -688,28 +697,28 @@ def _filter_tables( if include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): tables[table_name] = table continue - if not include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): + if not include_orphan_tables and not table.uns.get( + TableModel.ATTRS_KEY + ): continue if table_name in names_tables_to_keep: tables[table_name] = table continue # each mode here requires paths or elements, using assert here to avoid mypy errors. 
if by == "cs": - from spatialdata._core.query.relational_query import ( - _filter_table_by_element_names, - ) + from spatialdata._core.query.relational_query import _filter_table_by_element_names assert element_names is not None table = _filter_table_by_element_names(table, element_names) if table is not None and len(table) != 0: tables[table_name] = table elif by == "elements": - from spatialdata._core.query.relational_query import ( - _filter_table_by_elements, - ) + from spatialdata._core.query.relational_query import _filter_table_by_elements assert elements_dict is not None - table = _filter_table_by_elements(table, elements_dict=elements_dict) + table = _filter_table_by_elements( + table, elements_dict=elements_dict + ) if table is not None and len(table) != 0: tables[table_name] = table else: @@ -731,10 +740,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: The method does not allow to rename a coordinate system into an existing one, unless the existing one is also renamed in the same call. """ - from spatialdata.transformations.operations import ( - get_transformation, - set_transformation, - ) + from spatialdata.transformations.operations import get_transformation, set_transformation # check that the rename_dict is valid old_names = self.coordinate_systems @@ -760,7 +766,9 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: for old_cs, new_cs in rename_dict.items(): if old_cs in transformations: random_suffix = hashlib.sha1(os.urandom(128)).hexdigest()[:8] - transformations[new_cs + random_suffix] = transformations.pop(old_cs) + transformations[new_cs + random_suffix] = transformations.pop( + old_cs + ) suffixes_to_replace.add(new_cs + random_suffix) # remove the random suffixes @@ -771,10 +779,14 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: new_transformations[cs] = transformations[cs_with_suffix] suffixes_to_replace.remove(cs_with_suffix) else: - new_transformations[cs_with_suffix] = transformations[cs_with_suffix] + new_transformations[cs_with_suffix] = transformations[ + cs_with_suffix + ] # set the new transformations - set_transformation(element=element, transformation=new_transformations, set_all=True) + set_transformation( + element=element, transformation=new_transformations, set_all=True + ) def transform_element_to_coordinate_system( self, @@ -802,17 +814,18 @@ def transform_element_to_coordinate_system( """ from spatialdata import transform from spatialdata.transformations import Sequence - from spatialdata.transformations.operations import ( - get_transformation, - get_transformation_between_coordinate_systems, - remove_transformation, - set_transformation, - ) + from spatialdata.transformations.operations import (get_transformation, + get_transformation_between_coordinate_systems, + remove_transformation, set_transformation) element = self.get(element_name) - t = get_transformation_between_coordinate_systems(self, element, target_coordinate_system) + t = get_transformation_between_coordinate_systems( + self, element, target_coordinate_system + ) if maintain_positioning: - transformed = transform(element, transformation=t, maintain_positioning=maintain_positioning) + transformed = transform( + element, transformation=t, maintain_positioning=maintain_positioning + ) else: d = get_transformation(element, get_all=True) assert isinstance(d, dict) @@ -848,7 +861,9 @@ def transform_element_to_coordinate_system( # since target_coordinate_system is in d, we have that t is a Sequence with 
only one transformation. assert isinstance(t, Sequence) assert len(t.transformations) == 1 - seq = get_transformation(transformed, to_coordinate_system=target_coordinate_system) + seq = get_transformation( + transformed, to_coordinate_system=target_coordinate_system + ) assert isinstance(seq, Sequence) assert len(seq.transformations) == 2 assert seq.transformations[1] is t.transformations[0] @@ -877,7 +892,9 @@ def transform_to_coordinate_system( ------- The transformed SpatialData. """ - sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_tables=False) + sdata = self.filter_by_coordinate_system( + target_coordinate_system, filter_tables=False + ) elements: dict[str, dict[str, SpatialElement]] = {} for element_type, element_name, _ in sdata.gen_elements(): if element_type != "tables": @@ -913,7 +930,9 @@ def elements_are_self_contained(self) -> dict[str, bool]: description = {} for element_type, element_name, element in self.gen_elements(): element_path = self.path / element_type / element_name - description[element_name] = _is_element_self_contained(element, element_path) + description[element_name] = _is_element_self_contained( + element, element_path + ) return description def is_self_contained(self, element_name: str | None = None) -> bool: @@ -1030,8 +1049,12 @@ def _symmetric_difference_with_zarr_store(self) -> tuple[list[str], list[str]]: elements_in_sdata = self.elements_paths_in_memory() elements_in_zarr = self.elements_paths_on_disk() - elements_only_in_sdata = list(set(elements_in_sdata).difference(set(elements_in_zarr))) - elements_only_in_zarr = list(set(elements_in_zarr).difference(set(elements_in_sdata))) + elements_only_in_sdata = list( + set(elements_in_sdata).difference(set(elements_in_zarr)) + ) + elements_only_in_zarr = list( + set(elements_in_zarr).difference(set(elements_in_sdata)) + ) return elements_only_in_sdata, elements_only_in_zarr def _validate_can_safely_write_to_path( @@ -1046,7 +1069,9 @@ def _validate_can_safely_write_to_path( file_path = Path(file_path) if not isinstance(file_path, Path): - raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") + raise ValueError( + f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}." + ) # TODO: add test for this if os.path.exists(file_path): @@ -1068,24 +1093,28 @@ def _validate_can_safely_write_to_path( "Cannot overwrite. The target path of the write operation is in use. Please save the data to a " "different location. " ) - WORKAROUND = ( - "\nWorkaround: please see discussion here https://github.com/scverse/spatialdata/discussions/520 ." - ) + WORKAROUND = "\nWorkaround: please see discussion here https://github.com/scverse/spatialdata/discussions/520 ." if any(_backed_elements_contained_in_path(path=file_path, object=self)): raise ValueError( - ERROR_MSG + "\nDetails: the target path contains one or more files that Dask use for " + ERROR_MSG + + "\nDetails: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object." 
+ WORKAROUND ) if self.path is not None and ( - _is_subfolder(parent=self.path, child=file_path) or _is_subfolder(parent=file_path, child=self.path) + _is_subfolder(parent=self.path, child=file_path) + or _is_subfolder(parent=file_path, child=self.path) ): - if saving_an_element and _is_subfolder(parent=self.path, child=file_path): + if saving_an_element and _is_subfolder( + parent=self.path, child=file_path + ): raise ValueError( - ERROR_MSG + "\nDetails: the target path in which to save an element is a subfolder " + ERROR_MSG + + "\nDetails: the target path in which to save an element is a subfolder " "of the current Zarr store." + WORKAROUND ) raise ValueError( - ERROR_MSG + "\nDetails: the target path either contains, coincides or is contained in" + ERROR_MSG + + "\nDetails: the target path either contains, coincides or is contained in" " the current Zarr store." + WORKAROUND ) @@ -1110,7 +1139,9 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, update_sdata_path: bool = True, - sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + sdata_formats: ( + SpatialDataFormatType | list[SpatialDataFormatType] | None + ) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1172,7 +1203,9 @@ def write( store = _resolve_zarr_store(file_path) zarr_format = parsed["SpatialData"].zarr_format - zarr_group = zarr.create_group(store=store, overwrite=overwrite, zarr_format=zarr_format) + zarr_group = zarr.create_group( + store=store, overwrite=overwrite, zarr_format=zarr_format + ) self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) store.close() @@ -1215,15 +1248,12 @@ def _write_element( ) root_group, element_type_group, element_group = _get_groups_for_element( - zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, use_consolidated=False - ) - from spatialdata._io import ( - write_image, - write_labels, - write_points, - write_shapes, - write_table, + zarr_path=zarr_container_path, + element_type=element_type, + element_name=element_name, + use_consolidated=False, ) + from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table from spatialdata._io.format import _parse_formats if parsed_formats is None: @@ -1270,7 +1300,9 @@ def write_element( self, element_name: str | list[str], overwrite: bool = False, - sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + sdata_formats: ( + SpatialDataFormatType | list[SpatialDataFormatType] | None + ) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1315,7 +1347,9 @@ def write_element( self._validate_element_names_are_unique() element = self.get(element_name) if element is None: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." + ) if self.path is None: raise ValueError( @@ -1329,11 +1363,15 @@ def write_element( element_type = _element_type break if element_type is None: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." 
+ ) if element_type == "tables": validate_table_attr_keys(element) - self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) + self._check_element_not_on_disk_with_different_type( + element_type=element_type, element_name=element_name + ) self._write_element( element=element, @@ -1393,10 +1431,16 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: raise ValueError("The SpatialData object is not backed by a Zarr store.") on_disk = self.elements_paths_on_disk() - one_disk_names = [self._element_type_and_name_from_element_path(path)[1] for path in on_disk] + one_disk_names = [ + self._element_type_and_name_from_element_path(path)[1] for path in on_disk + ] in_memory = self.elements_paths_in_memory() - in_memory_names = [self._element_type_and_name_from_element_path(path)[1] for path in in_memory] - only_in_memory_names = list(set(in_memory_names).difference(set(one_disk_names))) + in_memory_names = [ + self._element_type_and_name_from_element_path(path)[1] for path in in_memory + ] + only_in_memory_names = list( + set(in_memory_names).difference(set(one_disk_names)) + ) only_on_disk_names = list(set(one_disk_names).difference(set(in_memory_names))) ERROR_MESSAGE = f"Element {element_name} is not found in the Zarr store associated with the SpatialData object." @@ -1409,19 +1453,25 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: if found: _element_type = self._element_type_from_element_name(element_name) - self._check_element_not_on_disk_with_different_type(element_type=_element_type, element_name=element_name) + self._check_element_not_on_disk_with_different_type( + element_type=_element_type, element_name=element_name + ) element_type = None on_disk = self.elements_paths_on_disk() for path in on_disk: - _element_type, _element_name = self._element_type_and_name_from_element_path(path) + _element_type, _element_name = ( + self._element_type_and_name_from_element_path(path) + ) if _element_name == element_name: element_type = _element_type break assert element_type is not None file_path_of_element = self.path / element_type / element_name - if any(_backed_elements_contained_in_path(path=file_path_of_element, object=self)): + if any( + _backed_elements_contained_in_path(path=file_path_of_element, object=self) + ): raise ValueError( "The file path specified is a parent directory of one or more files used for backing for one or " "more elements in the SpatialData object. Deleting the data would corrupt the SpatialData object." 
@@ -1438,10 +1488,14 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: if self.has_consolidated_metadata(): self.write_consolidated_metadata() - def _check_element_not_on_disk_with_different_type(self, element_type: str, element_name: str) -> None: + def _check_element_not_on_disk_with_different_type( + self, element_type: str, element_name: str + ) -> None: only_on_disk = self.elements_paths_on_disk() for disk_path in only_on_disk: - disk_element_type, disk_element_name = self._element_type_and_name_from_element_path(disk_path) + disk_element_type, disk_element_name = ( + self._element_type_and_name_from_element_path(disk_path) + ) if disk_element_name == element_name and disk_element_type != element_type: raise ValueError( f"Element {element_name} is found in the Zarr store as a {disk_element_type}, but it is found " @@ -1466,7 +1520,9 @@ def has_consolidated_metadata(self) -> bool: store.close() return return_value - def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[str, SpatialElement | AnnData] | None: + def _validate_can_write_metadata_on_element( + self, element_name: str + ) -> tuple[str, SpatialElement | AnnData] | None: """Validate if metadata can be written on an element, returns None if it cannot be written.""" from spatialdata._io._utils import _is_element_self_contained from spatialdata._io.io_zarr import _group_for_element_exists @@ -1489,7 +1545,9 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st element_type = self._element_type_from_element_name(element_name) - self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) + self._check_element_not_on_disk_with_different_type( + element_type=element_type, element_name=element_name + ) # check if the element exists in the Zarr storage if not _group_for_element_exists( @@ -1508,7 +1566,9 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st # warn the users if the element is not self-contained, that is, it is Dask-backed by files outside the Zarr # group for the element element_zarr_path = Path(self.path) / element_type / element_name - if not _is_element_self_contained(element=element, element_path=element_zarr_path): + if not _is_element_self_contained( + element=element, element_path=element_zarr_path + ): logger.info( f"Element {element_type}/{element_name} is not self-contained. The metadata will be" " saved to the Zarr group of the element in the SpatialData Zarr store. The data outside the element " @@ -1531,7 +1591,9 @@ def write_channel_names(self, element_name: str | None = None) -> None: if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." 
+ ) # recursively write the transformation for all the SpatialElement if element_name is None: @@ -1548,14 +1610,19 @@ def write_channel_names(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have the check in the conditional if element_type == "images" and self.path is not None: _, _, element_group = _get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False + zarr_path=Path(self.path), + element_type=element_type, + element_name=element_name, + use_consolidated=False, ) from spatialdata._io._utils import overwrite_channel_names overwrite_channel_names(element_group, element) else: - raise ValueError(f"Can't set channel names for element of type '{element_type}'.") + raise ValueError( + f"Can't set channel names for element of type '{element_type}'." + ) def write_transformations(self, element_name: str | None = None) -> None: """ @@ -1571,7 +1638,9 @@ def write_transformations(self, element_name: str | None = None) -> None: if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." + ) # recursively write the transformation for all the SpatialElement if element_name is None: @@ -1599,19 +1668,20 @@ def write_transformations(self, element_name: str | None = None) -> None: ) axes = get_axes_names(element) if isinstance(element, DataArray | DataTree): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_raster, - ) + from spatialdata._io._utils import overwrite_coordinate_transformations_raster from spatialdata._io.format import RasterFormats - raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]] + raster_format = RasterFormats[ + element_group.metadata.attributes["spatialdata_attrs"]["version"] + ] overwrite_coordinate_transformations_raster( - group=element_group, axes=axes, transformations=transformations, raster_format=raster_format + group=element_group, + axes=axes, + transformations=transformations, + raster_format=raster_format, ) elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_non_raster, - ) + from spatialdata._io._utils import overwrite_coordinate_transformations_non_raster overwrite_coordinate_transformations_non_raster( group=element_group, @@ -1625,7 +1695,9 @@ def _element_type_from_element_name(self, element_name: str) -> str: self._validate_element_names_are_unique() element = self.get(element_name) if element is None: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." 
+ ) located = self.locate_element(element) element_type = None @@ -1639,7 +1711,9 @@ def _element_type_from_element_name(self, element_name: str) -> str: assert element_type is not None return element_type - def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[str, str]: + def _element_type_and_name_from_element_path( + self, element_path: str + ) -> tuple[str, str]: element_type, element_name = element_path.split("/") return element_type, element_name @@ -1652,19 +1726,27 @@ def write_attrs( from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import CurrentSpatialDataContainerFormat, SpatialDataContainerFormatType - sdata_format = sdata_format if sdata_format is not None else CurrentSpatialDataContainerFormat() + sdata_format = ( + sdata_format + if sdata_format is not None + else CurrentSpatialDataContainerFormat() + ) assert isinstance(sdata_format, SpatialDataContainerFormatType) store = None if zarr_group is None: - assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." + assert ( + self.is_backed() + ), "The SpatialData object must be backed by a Zarr store to write attrs." store = _resolve_zarr_store(self.path) zarr_group = zarr.open_group(store=store, mode="r+") version = sdata_format.spatialdata_format_version version_specific_attrs = sdata_format.attrs_to_dict() - attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs + attrs_to_write = { + "spatialdata_attrs": {"version": version} | version_specific_attrs + } | self.attrs try: zarr_group.attrs.put(attrs_to_write) @@ -1713,7 +1795,9 @@ def write_metadata( if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." + ) if write_attrs: self.write_attrs(sdata_format=sdata_format) @@ -1756,7 +1840,9 @@ def get_attrs( the value of `return_as`. """ - def _flatten_mapping(m: Mapping[str, Any], parent_key: str = "", sep: str = "_") -> dict[str, Any]: + def _flatten_mapping( + m: Mapping[str, Any], parent_key: str = "", sep: str = "_" + ) -> dict[str, Any]: items: list[tuple[str, Any]] = [] for k, v in m.items(): new_key = f"{parent_key}{sep}{k}" if parent_key else k @@ -1801,7 +1887,9 @@ def _flatten_mapping(m: Mapping[str, Any], parent_key: str = "", sep: str = "_") except Exception as e: raise ValueError(f"Failed to convert data to DataFrame: {e}") from e - raise ValueError(f"Invalid 'return_as' value: {return_as}. Expected 'dict', 'json', 'df', or None.") + raise ValueError( + f"Invalid 'return_as' value: {return_as}. Expected 'dict', 'json', 'df', or None." + ) @property def tables(self) -> Tables: @@ -1830,6 +1918,7 @@ def read( file_path: str | Path | UPath | zarr.Group, selection: tuple[str] | None = None, reconsolidate_metadata: bool = False, + lazy: bool = False, ) -> SpatialData: """ Read a SpatialData object from a Zarr storage (on-disk or remote). @@ -1842,6 +1931,11 @@ def read( The elements to read (images, labels, points, shapes, table). If None, all elements are read. reconsolidate_metadata If the consolidated metadata store got corrupted this can lead to errors when trying to read the data. + lazy + If True, read tables lazily using anndata.experimental.read_lazy. + This keeps large tables out of memory until needed. Requires anndata >= 0.12. 
+ Note: Images, labels, and points are always read lazily (using Dask). + This parameter only affects tables, which are normally loaded into memory. Returns ------- @@ -1854,7 +1948,7 @@ def read( _write_consolidated_metadata(file_path) - return read_zarr(file_path, selection=selection) + return read_zarr(file_path, selection=selection, lazy=lazy) @property def images(self) -> Images: @@ -1933,7 +2027,8 @@ def _non_empty_elements(self) -> list[str]: return [ element for element in all_elements - if (getattr(self, element) is not None) and (len(getattr(self, element)) > 0) + if (getattr(self, element) is not None) + and (len(getattr(self, element)) > 0) ] def __repr__(self) -> str: @@ -1971,7 +2066,9 @@ def h(s: str) -> str: descr += f"\n{h('level0')}{attr.capitalize()}" unsorted_elements = attribute.items() - sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) + sorted_elements = sorted( + unsorted_elements, key=lambda x: _natural_keys(x[0]) + ) for k, v in sorted_elements: descr += f"{h('empty_line')}" descr_class = v.__class__.__name__ @@ -2003,7 +2100,16 @@ def h(s: str) -> str: else: shape_str = ( "(" - + ", ".join([(str(dim) if not isinstance(dim, Scalar) else "") for dim in v.shape]) + + ", ".join( + [ + ( + str(dim) + if not isinstance(dim, Scalar) + else "" + ) + for dim in v.shape + ] + ) + ")" ) descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} with shape: {shape_str} {dim_string}" @@ -2079,7 +2185,9 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: if not self.is_self_contained(): assert self.path is not None - descr += "\nwith the following Dask-backed elements not being self-contained:" + descr += ( + "\nwith the following Dask-backed elements not being self-contained:" + ) description = self.elements_are_self_contained() for _, element_name, element in self.gen_elements(): if not description[element_name]: @@ -2087,7 +2195,9 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: descr += f"\n ▸ {element_name}: {backing_files}" if self.path is not None: - elements_only_in_sdata, elements_only_in_zarr = self._symmetric_difference_with_zarr_store() + elements_only_in_sdata, elements_only_in_zarr = ( + self._symmetric_difference_with_zarr_store() + ) if len(elements_only_in_sdata) > 0: descr += "\nwith the following elements not in the Zarr store:" for element_path in elements_only_in_sdata: @@ -2174,9 +2284,13 @@ def _validate_element_names_are_unique(self) -> None: ValueError If the element names are not unique. """ - check_all_keys_case_insensitively_unique([name for _, name, _ in self.gen_elements()], location=()) + check_all_keys_case_insensitively_unique( + [name for _, name, _ in self.gen_elements()], location=() + ) - def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | AnnData]: + def _find_element( + self, element_name: str + ) -> tuple[str, str, SpatialElement | AnnData]: """ Retrieve SpatialElement or Table from the SpatialData instance matching element_name. 
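A minimal usage sketch of the new `lazy` flag on `SpatialData.read` (the store path and table name are hypothetical; anndata >= 0.12 is assumed, otherwise the call emits a UserWarning and falls back to eager loading):

    from spatialdata import SpatialData

    # Only tables change behavior with lazy=True; images, labels and points
    # are Dask-backed (lazy) either way.
    sdata = SpatialData.read("msi_dataset.zarr", lazy=True)
    table = sdata.tables["table"]  # hypothetical table name; kept out of memory until used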
@@ -2275,7 +2389,9 @@ def subset( """ elements_dict: dict[str, SpatialElement] = {} names_tables_to_keep: set[str] = set() - for element_type, element_name, element in self._gen_elements(include_tables=True): + for element_type, element_name, element in self._gen_elements( + include_tables=True + ): if element_name in element_names: if element_type != "tables": elements_dict.setdefault(element_type, {})[element_name] = element @@ -2308,11 +2424,16 @@ def __getitem__(self, item: str) -> SpatialElement | AnnData: def __contains__(self, key: str) -> bool: element_dict = { - element_name: element_value for _, element_name, element_value in self._gen_elements(include_tables=True) + element_name: element_value + for _, element_name, element_value in self._gen_elements( + include_tables=True + ) } return key in element_dict - def get(self, key: str, default_value: SpatialElement | AnnData | None = None) -> SpatialElement | AnnData | None: + def get( + self, key: str, default_value: SpatialElement | AnnData | None = None + ) -> SpatialElement | AnnData | None: """ Get element from SpatialData object based on corresponding name. @@ -2420,7 +2541,9 @@ def filter_by_table_query( obs_names_expr: Predicates | None = None, var_names_expr: Predicates | None = None, layer: str | None = None, - how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", + how: Literal[ + "left", "left_exclusive", "inner", "right", "right_exclusive" + ] = "right", ) -> SpatialData: """ Filter the SpatialData object based on a set of table queries. diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 719c9c1a..8bb1de0d 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -7,18 +7,45 @@ from anndata._io.specs import write_elem as write_adata from ome_zarr.format import Format -from spatialdata._io.format import ( - CurrentTablesFormat, - TablesFormats, - TablesFormatV01, - TablesFormatV02, - _parse_version, -) +from spatialdata._io.format import CurrentTablesFormat, TablesFormats, TablesFormatV01, TablesFormatV02, _parse_version from spatialdata.models import TableModel, get_table_keys -def _read_table(store: str | Path) -> AnnData: - table = read_anndata_zarr(str(store)) +def _read_table(store: str | Path, lazy: bool = False) -> AnnData: + """ + Read a table from a zarr store. + + Parameters + ---------- + store + Path to the zarr store containing the table. + lazy + If True, read the table lazily using anndata.experimental.read_lazy. + This requires anndata >= 0.12. If the installed version does not support + lazy reading, a warning is raised and the table is read eagerly. + + Returns + ------- + The AnnData table, either lazily loaded or in-memory. + """ + if lazy: + try: + from anndata.experimental import read_lazy + + table = read_lazy(str(store)) + except ImportError: + import warnings + + warnings.warn( + "Lazy reading of tables requires anndata >= 0.12. " + "Falling back to eager reading. 
To enable lazy reading, " + "upgrade anndata with: pip install 'anndata>=0.12'", + UserWarning, + stacklevel=2, + ) + table = read_anndata_zarr(str(store)) + else: + table = read_anndata_zarr(str(store)) f = zarr.open(store, mode="r") version = _parse_version(f, expect_attrs_key=False) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 0312dc96..c24c257c 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,6 +1,7 @@ import os import warnings from collections.abc import Callable +from functools import partial from json import JSONDecodeError from pathlib import Path from typing import Any, Literal, cast @@ -15,11 +16,7 @@ from zarr.errors import ArrayNotFoundError from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import ( - BadFileHandleMethod, - _resolve_zarr_store, - handle_read_errors, -) +from spatialdata._io._utils import BadFileHandleMethod, _resolve_zarr_store, handle_read_errors from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes @@ -36,7 +33,12 @@ def _read_zarr_group_spatialdata_element( read_func: Callable[..., Any], group_name: Literal["images", "labels", "shapes", "points", "tables"], element_type: Literal["image", "labels", "shapes", "points", "tables"], - element_container: (dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData]), + element_container: ( + dict[str, Raster_T] + | dict[str, DaskDataFrame] + | dict[str, GeoDataFrame] + | dict[str, AnnData] + ), on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], ) -> None: with handle_read_errors( @@ -66,7 +68,9 @@ def _read_zarr_group_spatialdata_element( ), ): if element_type in ["image", "labels"]: - reader_format = get_raster_format_for_read(elem_group, sdata_version) + reader_format = get_raster_format_for_read( + elem_group, sdata_version + ) element = read_func( elem_group_path, cast(Literal["image", "labels"], element_type), @@ -104,10 +108,7 @@ def get_raster_format_for_read( ------- The ome-zarr format to use for reading the raster element. """ - from spatialdata._io.format import ( - sdata_zarr_version_to_ome_zarr_format, - sdata_zarr_version_to_raster_format, - ) + from spatialdata._io.format import sdata_zarr_version_to_ome_zarr_format, sdata_zarr_version_to_raster_format if sdata_version == "0.1": group_version = group.metadata.attributes["multiscales"][0]["version"] @@ -123,7 +124,10 @@ def get_raster_format_for_read( def read_zarr( store: str | Path | UPath | zarr.Group, selection: None | tuple[str] = None, - on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, + on_bad_files: Literal[ + BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN + ] = BadFileHandleMethod.ERROR, + lazy: bool = False, ) -> SpatialData: """ Read a SpatialData dataset from a zarr store (on-disk or remote). @@ -147,6 +151,12 @@ def read_zarr( object is returned containing only elements that could be read. Failures can only be determined from the warnings. + lazy + If True, read tables lazily using anndata.experimental.read_lazy. + This keeps large tables out of memory until needed. Requires anndata >= 0.12. + Note: Images, labels, and points are always read lazily (using Dask). + This parameter only affects tables, which are normally loaded into memory. + Returns ------- A SpatialData object. 
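Only tables change behavior under `lazy=True`; a lazily read table can be recognized without materializing it, mirroring the `_is_lazy_anndata` heuristic added in models.py further below (a sketch; `table` is any AnnData returned by a lazy read):

    import pandas as pd

    def table_is_lazy(table) -> bool:
        # AnnData produced by anndata.experimental.read_lazy stores obs/var as
        # xarray-based Dataset2D objects instead of pandas DataFrames.
        return not isinstance(table.obs, pd.DataFrame)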
@@ -176,7 +186,11 @@ def read_zarr( shapes: dict[str, GeoDataFrame] = {} tables: dict[str, AnnData] = {} - selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) + selector = ( + {"images", "labels", "points", "shapes", "tables"} + if not selection + else set(selection or []) + ) logger.debug(f"Reading selection {selector}") # we could make this more readable. One can get lost when looking at this dict and iteration over the items @@ -185,7 +199,10 @@ def read_zarr( tuple[ Callable[..., Any], Literal["image", "labels", "shapes", "points", "tables"], - dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData], + dict[str, Raster_T] + | dict[str, DaskDataFrame] + | dict[str, GeoDataFrame] + | dict[str, AnnData], ], ] = { # ome-zarr-py needs a kwargs that has "image" has key. So here we have "image" and not "images" @@ -193,7 +210,7 @@ def read_zarr( "labels": (_read_multiscale, "labels", labels), "points": (_read_points, "points", points), "shapes": (_read_shapes, "shapes", shapes), - "tables": (_read_table, "tables", tables), + "tables": (partial(_read_table, lazy=lazy), "tables", tables), } for group_name, ( read_func, @@ -279,15 +296,21 @@ def _get_groups_for_element( # When writing, use_consolidated must be set to False. Otherwise, the metadata store # can get out of sync with newly added elements (e.g., labels), leading to errors. - root_group = zarr.open_group(store=resolved_store, mode="r+", use_consolidated=use_consolidated) + root_group = zarr.open_group( + store=resolved_store, mode="r+", use_consolidated=use_consolidated + ) element_type_group = root_group.require_group(element_type) - element_type_group = zarr.open_group(element_type_group.store_path, mode="a", use_consolidated=use_consolidated) + element_type_group = zarr.open_group( + element_type_group.store_path, mode="a", use_consolidated=use_consolidated + ) element_name_group = element_type_group.require_group(element_name) return root_group, element_type_group, element_name_group -def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: str) -> bool: +def _group_for_element_exists( + zarr_path: Path, element_type: str, element_name: str +) -> bool: """ Check if the group for an element exists. 
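The `lazy` flag is threaded only into the table reader via `functools.partial` in the dispatch dict above, so the generic per-element read loop stays unchanged. A minimal sketch of that design with hypothetical reader functions:

    from functools import partial

    def read_points_fn(path):
        ...  # hypothetical reader: takes only the element path

    def read_table_fn(path, lazy=False):
        ...  # hypothetical reader: the only one aware of the lazy option

    def build_readers(lazy=False):
        # Every value is callable with just the path, so the caller can loop
        # over element types uniformly while still honoring `lazy` for tables.
        return {
            "points": read_points_fn,
            "tables": partial(read_table_fn, lazy=lazy),
        }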
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index e834ad78..fa3920c2 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -23,12 +23,7 @@ from shapely.io import from_geojson, from_ragged_array from spatial_image import to_spatial_image from xarray import DataArray, DataTree -from xarray_schema.components import ( - ArrayTypeSchema, - AttrSchema, - AttrsSchema, - DimsSchema, -) +from xarray_schema.components import ArrayTypeSchema, AttrSchema, AttrsSchema, DimsSchema from xarray_schema.dataarray import DataArraySchema from spatialdata._core.validation import validate_table_attr_keys @@ -37,30 +32,49 @@ from spatialdata._utils import _check_match_length_channels_c_dim from spatialdata.config import settings from spatialdata.models import C, X, Y, Z, get_axes_names -from spatialdata.models._utils import ( - DEFAULT_COORDINATE_SYSTEM, - TRANSFORM_KEY, - MappingToCoordinateSystem_t, - SpatialElement, - _validate_mapping_to_coordinate_system_type, - convert_region_column_to_categorical, -) -from spatialdata.transformations._utils import ( - _get_transformations, - _set_transformations, - compute_coordinates, -) +from spatialdata.models._utils import (DEFAULT_COORDINATE_SYSTEM, TRANSFORM_KEY, MappingToCoordinateSystem_t, + SpatialElement, _validate_mapping_to_coordinate_system_type, + convert_region_column_to_categorical) +from spatialdata.transformations._utils import _get_transformations, _set_transformations, compute_coordinates from spatialdata.transformations.transformations import BaseTransformation, Identity # Types -Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]] +Chunks_t: TypeAlias = ( + int + | tuple[int, ...] + | tuple[tuple[int, ...], ...] + | Mapping[Any, None | int | tuple[int, ...]] +) ScaleFactors_t = Sequence[dict[str, int] | int] Transform_s = AttrSchema(BaseTransformation, None) ATTRS_KEY = "spatialdata_attrs" -def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None: +def _is_lazy_anndata(adata: AnnData) -> bool: + """Check if an AnnData object is lazily loaded. + + Lazy AnnData objects (from anndata.experimental.read_lazy) have obs/var + stored as xarray Dataset2D instead of pandas DataFrame. + + Parameters + ---------- + adata + The AnnData object to check. + + Returns + ------- + True if the AnnData is lazily loaded, False otherwise. + """ + # Check if obs is not a pandas DataFrame (lazy AnnData uses xarray Dataset2D) + if not isinstance(adata.obs, pd.DataFrame): + return True + return False + + +def _parse_transformations( + element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None +) -> None: _validate_mapping_to_coordinate_system_type(transformations) transformations_in_element = _get_transformations(element) if ( @@ -166,7 +180,9 @@ def parse( if transformations: transformations = transformations.copy() if "name" in kwargs: - raise ValueError("The `name` argument is not (yet) supported for raster data.") + raise ValueError( + "The `name` argument is not (yet) supported for raster data." 
+ ) # if dims is specified inside the data, get the value of dims from the data if isinstance(data, DataArray): if not isinstance(data.data, DaskArray): # numpy -> dask @@ -214,13 +230,18 @@ def parse( if c_coords is not None: c_coords = _check_match_length_channels_c_dim(data, c_coords, cls.dims.dims) - if c_coords is not None and len(c_coords) != data.shape[cls.dims.dims.index("c")]: + if ( + c_coords is not None + and len(c_coords) != data.shape[cls.dims.dims.index("c")] + ): raise ValueError( f"The number of channel names `{len(c_coords)}` does not match the length of dimension 'c'" f" with length {data.shape[cls.dims.dims.index('c')]}." ) - data = to_spatial_image(array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs) + data = to_spatial_image( + array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs + ) # parse transformations _parse_transformations(data, transformations) # convert to multiscale if needed @@ -275,12 +296,18 @@ def _(self, data: DataArray) -> None: @validate.register(DataTree) def _(self, data: DataTree) -> None: - for j, k in zip(data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True): + for j, k in zip( + data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True + ): if j != k: - raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.") + raise ValueError( + f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`." + ) name = {list(data[i].data_vars.keys())[0] for i in data} if len(name) != 1: - raise ValueError(f"Expected exactly one data variable for the datatree: found `{name}`.") + raise ValueError( + f"Expected exactly one data variable for the datatree: found `{name}`." + ) name = list(name)[0] for d in data: super().validate(data[d][name]) @@ -448,9 +475,14 @@ def validate(cls, data: GeoDataFrame) -> None: """ SUGGESTION = " Please use ShapesModel.parse() to construct data that is guaranteed to be valid." if cls.GEOMETRY_KEY not in data: - raise KeyError(f"GeoDataFrame must have a column named `{cls.GEOMETRY_KEY}`." + SUGGESTION) + raise KeyError( + f"GeoDataFrame must have a column named `{cls.GEOMETRY_KEY}`." + + SUGGESTION + ) if not isinstance(data[cls.GEOMETRY_KEY], GeoSeries): - raise ValueError(f"Column `{cls.GEOMETRY_KEY}` must be a GeoSeries." + SUGGESTION) + raise ValueError( + f"Column `{cls.GEOMETRY_KEY}` must be a GeoSeries." + SUGGESTION + ) if len(data[cls.GEOMETRY_KEY]) == 0: raise ValueError(f"Column `{cls.GEOMETRY_KEY}` is empty." + SUGGESTION) geom_ = data[cls.GEOMETRY_KEY].values[0] @@ -475,7 +507,10 @@ def validate(cls, data: GeoDataFrame) -> None: "please correct the radii of the circles before calling the parser function.", ) if cls.TRANSFORM_KEY not in data.attrs: - raise ValueError(f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + SUGGESTION) + raise ValueError( + f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + + SUGGESTION + ) if len(data) > 0: n = data.geometry.iloc[0]._ndim if n != 2: @@ -572,7 +607,9 @@ def parse(cls, data: Any, **kwargs: Any) -> GeoDataFrame: def _( cls, data: np.ndarray, # type: ignore[type-arg] - geometry: Literal[0, 3, 6], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON] + geometry: Literal[ + 0, 3, 6 + ], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON] offsets: tuple[ArrayLike, ...] 
| None = None, radius: float | ArrayLike | None = None, index: ArrayLike | None = None, @@ -583,7 +620,9 @@ def _( geo_df = GeoDataFrame({"geometry": data}) if GeometryType(geometry).name == "POINT": if radius is None: - raise ValueError("If `geometry` is `Circles`, `radius` must be provided.") + raise ValueError( + "If `geometry` is `Circles`, `radius` must be provided." + ) geo_df[cls.RADIUS_KEY] = radius if index is not None: geo_df.index = index @@ -610,7 +649,9 @@ def _( geo_df = GeoDataFrame({"geometry": gc.geoms}) if isinstance(geo_df["geometry"].iloc[0], Point): if radius is None: - raise ValueError("If `geometry` is `Circles`, `radius` must be provided.") + raise ValueError( + "If `geometry` is `Circles`, `radius` must be provided." + ) geo_df[cls.RADIUS_KEY] = radius if index is not None: geo_df.index = index @@ -627,7 +668,10 @@ def _( ) -> GeoDataFrame: if "geometry" not in data.columns: raise ValueError("`geometry` column not found in `GeoDataFrame`.") - if isinstance(data["geometry"].iloc[0], Point) and cls.RADIUS_KEY not in data.columns: + if ( + isinstance(data["geometry"].iloc[0], Point) + and cls.RADIUS_KEY not in data.columns + ): raise ValueError(f"Column `{cls.RADIUS_KEY}` not found.") _parse_transformations(data, transformations) cls.validate(data) @@ -667,7 +711,8 @@ def validate(cls, data: DaskDataFrame) -> None: raise ValueError(f"Column `{ax}` must be of type `int` or `float`.") if cls.TRANSFORM_KEY not in data.attrs: raise ValueError( - f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + SUGGESTION + f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + + SUGGESTION ) if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]: feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY] @@ -750,11 +795,15 @@ def _( if annotation is not None: if feature_key is not None: - df_dict[feature_key] = annotation[feature_key].astype(str).astype("category") + df_dict[feature_key] = ( + annotation[feature_key].astype(str).astype("category") + ) if instance_key is not None: df_dict[instance_key] = annotation[instance_key] if Z not in axes and Z in annotation.columns: - logger.info(f"Column `{Z}` in `annotation` will be ignored since the data is 2D.") + logger.info( + f"Column `{Z}` in `annotation` will be ignored since the data is 2D." + ) for c in set(annotation.columns) - {feature_key, instance_key, X, Y, Z}: df_dict[c] = annotation[c] @@ -793,7 +842,9 @@ def _( if "sort" not in kwargs: index_monotonically_increasing = data.index.is_monotonic_increasing if not isinstance(index_monotonically_increasing, bool): - index_monotonically_increasing = index_monotonically_increasing.compute() + index_monotonically_increasing = ( + index_monotonically_increasing.compute() + ) sort = index_monotonically_increasing else: sort = kwargs["sort"] @@ -831,7 +882,9 @@ def _( if data[feature_key].dtype.name == "category": table[feature_key] = data[feature_key] else: - table[feature_key] = data[feature_key].astype(str).astype("category") + table[feature_key] = ( + data[feature_key].astype(str).astype("category") + ) if instance_key is not None: table[instance_key] = data[instance_key] for c in [X, Y, Z]: @@ -891,9 +944,13 @@ def _add_metadata_and_validate( # It also just changes the state of the series, so it is not a big deal. 
if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known: try: - data[c] = data[c].cat.set_categories(data[c].compute().cat.categories) + data[c] = data[c].cat.set_categories( + data[c].compute().cat.categories + ) except ValueError: - logger.info(f"Column `{c}` contains unknown categories. Consider casting it.") + logger.info( + f"Column `{c}` contains unknown categories. Consider casting it." + ) _parse_transformations(data, transformations) cls.validate(data) @@ -907,7 +964,9 @@ class TableModel: INSTANCE_KEY = "instance_key" ATTRS_KEY = ATTRS_KEY - def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) -> None: + def _validate_set_region_key( + self, data: AnnData, region_key: str | None = None + ) -> None: """ Validate the region key in table.uns or set a new region key as the region key column. @@ -947,7 +1006,9 @@ def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) raise ValueError(f"'{region_key}' column not present in table.obs") attrs[self.REGION_KEY_KEY] = region_key - def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = None) -> None: + def _validate_set_instance_key( + self, data: AnnData, instance_key: str | None = None + ) -> None: """ Validate the instance_key in table.uns or set a new instance_key as the instance_key column. @@ -991,7 +1052,9 @@ def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = N if instance_key in data.obs: attrs[self.INSTANCE_KEY] = instance_key else: - raise ValueError(f"Instance key column '{instance_key}' not found in table.obs.") + raise ValueError( + f"Instance key column '{instance_key}' not found in table.obs." + ) def _validate_table_annotation_metadata(self, data: AnnData) -> None: """ @@ -1026,16 +1089,33 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: attr = data.uns[ATTRS_KEY] if "region" not in attr: - raise ValueError(f"`region` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) + raise ValueError( + f"`region` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION + ) if "region_key" not in attr: - raise ValueError(f"`region_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) + raise ValueError( + f"`region_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION + ) if "instance_key" not in attr: - raise ValueError(f"`instance_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) + raise ValueError( + f"`instance_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION + ) if attr[self.REGION_KEY_KEY] not in data.obs: - raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.") + raise ValueError( + f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column." + ) if attr[self.INSTANCE_KEY] not in data.obs: - raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.") + raise ValueError( + f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column." 
+ ) + + # Skip detailed dtype/value validation for lazy-loaded AnnData + # These checks would trigger data loading, defeating the purpose of lazy loading + # Validation will occur when data is actually computed/accessed + if _is_lazy_anndata(data): + return + if ( (dtype := data.obs[attr[self.INSTANCE_KEY]].dtype) not in [ @@ -1049,26 +1129,41 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: "O", ] and not pd.api.types.is_string_dtype(data.obs[attr[self.INSTANCE_KEY]]) - or (dtype == "O" and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) is not str) + or ( + dtype == "O" + and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) + is not str + ) ): dtype = dtype if dtype != "O" else val_dtype raise TypeError( f"Only int, np.int16, np.int32, np.int64, uint equivalents or string allowed as dtype for " f"instance_key column in obs. Dtype found to be {dtype}" ) - expected_regions = attr[self.REGION_KEY] if isinstance(attr[self.REGION_KEY], list) else [attr[self.REGION_KEY]] + expected_regions = ( + attr[self.REGION_KEY] + if isinstance(attr[self.REGION_KEY], list) + else [attr[self.REGION_KEY]] + ) found_regions = data.obs[attr[self.REGION_KEY_KEY]].unique().tolist() if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0: - raise ValueError(f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match.") + raise ValueError( + f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match." + ) # Warning for object/string columns with NaN in region_key or instance_key instance_key = attr[self.INSTANCE_KEY] region_key = attr[self.REGION_KEY_KEY] - for key_name, key_value in [("region_key", region_key), ("instance_key", instance_key)]: + for key_name, key_value in [ + ("region_key", region_key), + ("instance_key", instance_key), + ]: if key_value in data.obs: col = data.obs[key_value] col_dtype = col.dtype - if (col_dtype == "object" or pd.api.types.is_string_dtype(col_dtype)) and col.isna().any(): + if ( + col_dtype == "object" or pd.api.types.is_string_dtype(col_dtype) + ) and col.isna().any(): logger.warning( f"The {key_name} column '{key_value}' is of {col_dtype} type and contains NaN values. " "After writing and reading with AnnData, NaN values may (depending on the AnnData version) " @@ -1099,6 +1194,10 @@ def validate( if ATTRS_KEY not in data.uns: return data + # Check if this is a lazy-loaded AnnData (from anndata.experimental.read_lazy) + # Lazy AnnData has xarray-based obs/var, which requires different validation + is_lazy = _is_lazy_anndata(data) + _, region_key, instance_key = get_table_keys(data) if region_key is not None: if region_key not in data.obs: @@ -1106,7 +1205,10 @@ def validate( f"Region key `{region_key}` not in `adata.obs`. Please create the column and parse " f"using TableModel.parse(adata)." ) - if not isinstance(data.obs[region_key].dtype, CategoricalDtype): + # Skip dtype validation for lazy tables (would require loading data) + if not is_lazy and not isinstance( + data.obs[region_key].dtype, CategoricalDtype + ): raise ValueError( f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`." ) @@ -1116,8 +1218,11 @@ def validate( f"Instance key `{instance_key}` not in `adata.obs`. Please create the column and parse" f" using TableModel.parse(adata)." 
) - if data.obs[instance_key].isnull().values.any(): - raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") + # Skip null check for lazy tables (would require loading data) + if not is_lazy and data.obs[instance_key].isnull().values.any(): + raise ValueError( + "`table.obs[instance_key]` must not contain null values, but it does." + ) self._validate_table_annotation_metadata(data) @@ -1154,7 +1259,9 @@ def parse( """ validate_table_attr_keys(adata) # either all live in adata.uns or all be passed in as argument - n_args = sum([region is not None, region_key is not None, instance_key is not None]) + n_args = sum( + [region is not None, region_key is not None, instance_key is not None] + ) if n_args == 0: if cls.ATTRS_KEY not in adata.uns: # table not annotating any element @@ -1183,7 +1290,9 @@ def parse( region = region.tolist() region_: list[str] = region if isinstance(region, list) else [region] if not adata.obs[region_key].isin(region_).all(): - raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.") + raise ValueError( + f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values." + ) adata.uns[cls.ATTRS_KEY][cls.REGION_KEY] = region adata.uns[cls.ATTRS_KEY][cls.REGION_KEY_KEY] = region_key @@ -1194,7 +1303,9 @@ def parse( grouped = adata.obs.groupby(region_key, observed=True) grouped_size = grouped.size() grouped_nunique = grouped.nunique() - not_unique = grouped_size[grouped_size != grouped_nunique[instance_key]].index.tolist() + not_unique = grouped_size[ + grouped_size != grouped_nunique[instance_key] + ].index.tolist() if not_unique: raise ValueError( f"Instance key column for region(s) `{', '.join(not_unique)}` does not contain only unique values" @@ -1305,6 +1416,11 @@ def _get_region_metadata_from_region_key_column(table: AnnData) -> list[str]: ) annotated_regions = region_key_column.unique().tolist() else: - annotated_regions = table.obs[region_key].cat.remove_unused_categories().cat.categories.unique().tolist() + annotated_regions = ( + table.obs[region_key] + .cat.remove_unused_categories() + .cat.categories.unique() + .tolist() + ) assert isinstance(annotated_regions, list) return annotated_regions diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index af028d29..241336df 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -94,7 +94,11 @@ def test_shapes( # add a mixed Polygon + MultiPolygon element shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]]) - shapes.write(tmpdir, sdata_formats=sdata_container_format, shapes_geometry_encoding=geometry_encoding) + shapes.write( + tmpdir, + sdata_formats=sdata_container_format, + shapes_geometry_encoding=geometry_encoding, + ) sdata = SpatialData.read(tmpdir) if geometry_encoding == "WKB": @@ -102,7 +106,9 @@ def test_shapes( else: # convert each Polygon to a MultiPolygon mixed_multipolygon = shapes["mixed"].assign( - geometry=lambda df: df.geometry.apply(lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g) + geometry=lambda df: df.geometry.apply( + lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g + ) ) assert sdata["mixed"].equals(mixed_multipolygon) assert not sdata["mixed"].equals(shapes["mixed"]) @@ -139,7 +145,9 @@ def test_shapes_geometry_encoding_write_element( # Write each shape element - should use global setting for shape_name in shapes.shapes: - empty_sdata.write_element(shape_name, sdata_formats=sdata_container_format) + empty_sdata.write_element( 
+ shape_name, sdata_formats=sdata_container_format + ) # Verify the encoding metadata in the parquet file parquet_file = tmpdir / "shapes" / shape_name / "shapes.parquet" @@ -220,8 +228,12 @@ def test_multiple_tables( tables: list[AnnData], sdata_container_format: SpatialDataContainerFormatType, ) -> None: - sdata_tables = SpatialData(tables={str(i): tables[i] for i in range(len(tables))}) - self._test_table(tmp_path, sdata_tables, sdata_container_format=sdata_container_format) + sdata_tables = SpatialData( + tables={str(i): tables[i] for i in range(len(tables))} + ) + self._test_table( + tmp_path, sdata_tables, sdata_container_format=sdata_container_format + ) def test_roundtrip( self, @@ -252,7 +264,9 @@ def test_incremental_io_list_of_elements( assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() - shapes.write_element(["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format) + shapes.write_element( + ["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format + ) assert "shapes/new_shapes0" in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" in shapes.elements_paths_on_disk() @@ -367,7 +381,9 @@ def test_incremental_io_on_disk( ValueError, match=match, ): - sdata.write_element(name, overwrite=True, sdata_formats=sdata_container_format) + sdata.write_element( + name, overwrite=True, sdata_formats=sdata_container_format + ) if workaround == 1: new_name = f"{name}_new_place" @@ -398,7 +414,9 @@ def test_incremental_io_on_disk( sdata.delete_element_from_disk(name) sdata.write_element(name, sdata_formats=sdata_container_format) - def test_io_and_lazy_loading_points(self, points, sdata_container_format: SpatialDataContainerFormatType): + def test_io_and_lazy_loading_points( + self, points, sdata_container_format: SpatialDataContainerFormatType + ): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") points.write(f, sdata_formats=sdata_container_format) @@ -407,7 +425,9 @@ def test_io_and_lazy_loading_points(self, points, sdata_container_format: Spatia sdata2 = SpatialData.read(f) assert len(get_dask_backing_files(sdata2)) > 0 - def test_io_and_lazy_loading_raster(self, images, labels, sdata_container_format: SpatialDataContainerFormatType): + def test_io_and_lazy_loading_raster( + self, images, labels, sdata_container_format: SpatialDataContainerFormatType + ): sdatas = {"images": images, "labels": labels} for k, sdata in sdatas.items(): d = getattr(sdata, k) @@ -457,9 +477,13 @@ def test_replace_transformation_on_disk_non_raster( with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") sdata.write(f, sdata_formats=sdata_container_format) - t0 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name]) + t0 = get_transformation( + SpatialData.read(f).__getattribute__(k)[elem_name] + ) assert isinstance(t0, Identity) - set_transformation(sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata) + set_transformation( + sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata + ) t1 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t1, Scale) @@ -470,10 +494,16 @@ def test_write_overwrite_fails_when_no_zarr_store( f = Path(tmpdir) / "data.zarr" f.mkdir() old_data = SpatialData() - with pytest.raises(ValueError, match="The target file path specified already exists"): + with pytest.raises( + ValueError, match="The target file path specified already exists" + ): old_data.write(f, 
sdata_formats=sdata_container_format) - with pytest.raises(ValueError, match="The target file path specified already exists"): - full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) + with pytest.raises( + ValueError, match="The target file path specified already exists" + ): + full_sdata.write( + f, overwrite=True, sdata_formats=sdata_container_format + ) def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( self, @@ -506,7 +536,9 @@ def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( match=r"Details: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object", ): - full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) + full_sdata.write( + f, overwrite=True, sdata_formats=sdata_container_format + ) def test_overwrite_fails_when_zarr_store_present( self, full_sdata, sdata_container_format: SpatialDataContainerFormatType @@ -526,7 +558,9 @@ def test_overwrite_fails_when_zarr_store_present( ValueError, match=r"Details: the target path either contains, coincides or is contained in the current Zarr store", ): - full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) + full_sdata.write( + f, overwrite=True, sdata_formats=sdata_container_format + ) # support for overwriting backed sdata has been temporarily removed # with tempfile.TemporaryDirectory() as tmpdir: @@ -547,9 +581,7 @@ def test_overwrite_fails_when_zarr_store_present( def test_overwrite_fails_onto_non_zarr_file( self, full_sdata, sdata_container_format: SpatialDataContainerFormatType ): - ERROR_MESSAGE = ( - "The target file path specified already exists, and it has been detected to not be a Zarr store." - ) + ERROR_MESSAGE = "The target file path specified already exists, and it has been detected to not be a Zarr store." 
with tempfile.TemporaryDirectory() as tmpdir: f0 = os.path.join(tmpdir, "test.txt") with open(f0, "w"): @@ -562,13 +594,17 @@ def test_overwrite_fails_onto_non_zarr_file( ValueError, match=ERROR_MESSAGE, ): - full_sdata.write(f0, overwrite=True, sdata_formats=sdata_container_format) + full_sdata.write( + f0, overwrite=True, sdata_formats=sdata_container_format + ) f1 = os.path.join(tmpdir, "test.zarr") os.mkdir(f1) with pytest.raises(ValueError, match=ERROR_MESSAGE): full_sdata.write(f1, sdata_formats=sdata_container_format) with pytest.raises(ValueError, match=ERROR_MESSAGE): - full_sdata.write(f1, overwrite=True, sdata_formats=sdata_container_format) + full_sdata.write( + f1, overwrite=True, sdata_formats=sdata_container_format + ) def test_incremental_io_in_memory( @@ -606,7 +642,9 @@ def test_bug_rechunking_after_queried_raster(): # https://github.com/scverse/spatialdata-io/issues/117 ## single_scale = Image2DModel.parse(RNG.random((100, 10, 10)), chunks=(5, 5, 5)) - multi_scale = Image2DModel.parse(RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5)) + multi_scale = Image2DModel.parse( + RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5) + ) images = {"single_scale": single_scale, "multi_scale": multi_scale} sdata = SpatialData(images=images) queried = sdata.query.bounding_box( @@ -621,7 +659,9 @@ def test_bug_rechunking_after_queried_raster(): @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: +def test_self_contained( + full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType +) -> None: # data only in-memory, so the SpatialData object and all its elements are self-contained assert full_sdata.is_self_contained() description = full_sdata.elements_are_self_contained() @@ -645,7 +685,10 @@ def test_self_contained(full_sdata: SpatialData, sdata_container_format: Spatial # because of the images, labels and points description = sdata2.elements_are_self_contained() for element_name, self_contained in description.items(): - if any(element_name.startswith(prefix) for prefix in ["image", "labels", "points"]): + if any( + element_name.startswith(prefix) + for prefix in ["image", "labels", "points"] + ): assert not self_contained else: assert self_contained @@ -678,7 +721,11 @@ def test_self_contained(full_sdata: SpatialData, sdata_container_format: Spatial assert not sdata2.is_self_contained() description = sdata2.elements_are_self_contained() assert description["combined"] is False - assert all(description[element_name] for element_name in description if element_name != "combined") + assert all( + description[element_name] + for element_name in description + if element_name != "combined" + ) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) @@ -690,7 +737,9 @@ def test_symmetric_difference_with_zarr_store( full_sdata.write(f, sdata_formats=sdata_container_format) # the list of element on-disk and in-memory is the same - only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() + only_in_memory, only_on_disk = ( + full_sdata._symmetric_difference_with_zarr_store() + ) assert len(only_in_memory) == 0 assert len(only_on_disk) == 0 @@ -706,7 +755,9 @@ def test_symmetric_difference_with_zarr_store( del full_sdata.tables["table"] # now the list of element on-disk and in-memory is different - only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() + 
only_in_memory, only_on_disk = ( + full_sdata._symmetric_difference_with_zarr_store() + ) assert set(only_in_memory) == { "images/new_image2d", "labels/new_labels2d", @@ -724,13 +775,17 @@ def test_symmetric_difference_with_zarr_store( @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_change_path_of_subset(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: +def test_change_path_of_subset( + full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType +) -> None: """A subset SpatialData object has not Zarr path associated, show that we can reassign the path""" with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") full_sdata.write(f, sdata_formats=sdata_container_format) - subset = full_sdata.subset(["image2d", "labels2d", "points_0", "circles", "table"]) + subset = full_sdata.subset( + ["image2d", "labels2d", "points_0", "circles", "table"] + ) assert subset.path is None subset.path = Path(f) @@ -795,7 +850,9 @@ def test_incremental_io_valid_name(full_sdata: SpatialData) -> None: @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_incremental_io_attrs(points: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: +def test_incremental_io_attrs( + points: SpatialData, sdata_container_format: SpatialDataContainerFormatType +) -> None: with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") my_attrs = {"a": "b", "c": 1} @@ -822,7 +879,9 @@ def test_incremental_io_attrs(points: SpatialData, sdata_container_format: Spati cached_sdata_blobs = blobs() -@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) +@pytest.mark.parametrize( + "element_name", ["image2d", "labels2d", "points_0", "circles", "table"] +) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_delete_element_from_disk( full_sdata, @@ -830,7 +889,9 @@ def test_delete_element_from_disk( sdata_container_format: SpatialDataContainerFormatType, ) -> None: # can't delete an element for a SpatialData object without associated Zarr store - with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): + with pytest.raises( + ValueError, match="The SpatialData object is not backed by a Zarr store." 
+ ): full_sdata.delete_element_from_disk("image2d") with tempfile.TemporaryDirectory() as tmpdir: @@ -858,7 +919,9 @@ def test_delete_element_from_disk( # can delete an element present both in-memory and on-disk full_sdata.delete_element_from_disk(element_name) - only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() + only_in_memory, only_on_disk = ( + full_sdata._symmetric_difference_with_zarr_store() + ) element_type = full_sdata._element_type_from_element_name(element_name) element_path = f"{element_type}/{element_name}" assert element_path in only_in_memory @@ -873,7 +936,9 @@ def test_delete_element_from_disk( assert element_path not in on_disk -@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) +@pytest.mark.parametrize( + "element_name", ["image2d", "labels2d", "points_0", "circles", "table"] +) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_element_already_on_disk_different_type( full_sdata, @@ -927,7 +992,9 @@ def test_writing_invalid_name(tmp_path: Path): invalid_sdata.images.data[""] = next(iter(_get_images().values())) invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) - invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) + invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next( + iter(_get_shapes().values()) + ) invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") with pytest.raises(ValueError, match="Name (must|cannot)"): @@ -938,7 +1005,9 @@ def test_writing_valid_table_name_invalid_table(tmp_path: Path): # also try with a valid table name but invalid table # testing just one case, all the cases are in test_table_model_invalid_names() invalid_sdata = SpatialData() - invalid_sdata.tables.data["valid_name"] = AnnData(np.array([[0]]), layers={"invalid name": np.array([[0]])}) + invalid_sdata.tables.data["valid_name"] = AnnData( + np.array([[0]]), layers={"invalid name": np.array([[0]])} + ) with pytest.raises(ValueError, match="Name (must|cannot)"): invalid_sdata.write(tmp_path / "data.zarr") @@ -951,7 +1020,9 @@ def test_incremental_writing_invalid_name(tmp_path: Path): invalid_sdata.images.data[""] = next(iter(_get_images().values())) invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) - invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) + invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next( + iter(_get_shapes().values()) + ) invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") for element_type in ["images", "labels", "points", "shapes", "tables"]: @@ -966,7 +1037,9 @@ def test_incremental_writing_valid_table_name_invalid_table(tmp_path: Path): # testing just one case, all the cases are in test_table_model_invalid_names() invalid_sdata = SpatialData() invalid_sdata.write(tmp_path / "data2.zarr") - invalid_sdata.tables.data["valid_name"] = AnnData(np.array([[0]]), layers={"invalid name": np.array([[0]])}) + invalid_sdata.tables.data["valid_name"] = AnnData( + np.array([[0]]), layers={"invalid name": np.array([[0]])} + ) with pytest.raises(ValueError, match="Name (must|cannot)"): invalid_sdata.write_element("valid_name") @@ -986,13 +1059,19 @@ def test_reading_invalid_name(tmp_path: Path): ) valid_sdata.write(tmp_path / "data.zarr") # Circumvent validation at 
construction time and check validation happens again at writing time. - (tmp_path / "data.zarr/points" / points_name).rename(tmp_path / "data.zarr/points" / "has whitespace") + (tmp_path / "data.zarr/points" / points_name).rename( + tmp_path / "data.zarr/points" / "has whitespace" + ) # This one is not allowed on windows - (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@") + (tmp_path / "data.zarr/shapes" / shapes_name).rename( + tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@" + ) # We do this as the key of the element is otherwise not in the consolidated metadata, leading to an error. valid_sdata.write_consolidated_metadata() - with pytest.raises(ValidationError, match="Cannot construct SpatialData") as exc_info: + with pytest.raises( + ValidationError, match="Cannot construct SpatialData" + ) as exc_info: read_zarr(tmp_path / "data.zarr") actual_message = str(exc_info.value) @@ -1005,10 +1084,14 @@ def test_reading_invalid_name(tmp_path: Path): @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_write_store_unconsolidated_and_read(full_sdata, sdata_container_format: SpatialDataContainerFormatType): +def test_write_store_unconsolidated_and_read( + full_sdata, sdata_container_format: SpatialDataContainerFormatType +): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" - full_sdata.write(path, consolidate_metadata=False, sdata_formats=sdata_container_format) + full_sdata.write( + path, consolidate_metadata=False, sdata_formats=sdata_container_format + ) group = zarr.open_group(path, mode="r") assert group.metadata.consolidated_metadata is None @@ -1017,7 +1100,9 @@ def test_write_store_unconsolidated_and_read(full_sdata, sdata_container_format: @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format: SpatialDataContainerFormatType): +def test_can_read_sdata_with_reconsolidation( + full_sdata, sdata_container_format: SpatialDataContainerFormatType +): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" full_sdata.write(path, sdata_formats=sdata_container_format) @@ -1107,3 +1192,84 @@ def test_sdata_with_nan_in_obs() -> None: # After round-trip, NaN in object-dtype column becomes string "nan" assert sdata2["table"].obs["column_only_region1"].iloc[1] == "nan" assert np.isnan(sdata2["table"].obs["column_only_region2"].iloc[0]) + + +class TestLazyTableLoading: + """Tests for lazy table loading functionality. + + Lazy loading uses anndata.experimental.read_lazy() to keep large tables + out of memory until needed. This is particularly useful for MSI data + where tables can contain millions of pixels. 
+ """ + + @pytest.fixture + def sdata_with_table(self) -> SpatialData: + """Create a SpatialData object with a simple table for testing.""" + table = TableModel.parse( + AnnData( + X=np.random.rand(100, 50), + obs=pd.DataFrame( + { + "region": pd.Categorical(["region1"] * 100), + "instance": np.arange(100), + } + ), + ), + region_key="region", + instance_key="instance", + region="region1", + ) + return SpatialData(tables={"test_table": table}) + + def test_lazy_read_basic(self, sdata_with_table: SpatialData) -> None: + """Test that lazy=True reads tables without loading into memory.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Read with lazy=True + try: + sdata_lazy = SpatialData.read(path, lazy=True) + + # Table should be present + assert "test_table" in sdata_lazy.tables + + # Check that X is a lazy array (dask or similar) + # Lazy AnnData from read_lazy uses dask arrays + table = sdata_lazy.tables["test_table"] + assert hasattr(table, "X") + + except ImportError: + # If anndata.experimental.read_lazy is not available, skip + pytest.skip("anndata.experimental.read_lazy not available") + + def test_lazy_false_loads_normally(self, sdata_with_table: SpatialData) -> None: + """Test that lazy=False (default) loads tables into memory normally.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Read with lazy=False (default) + sdata_normal = SpatialData.read(path, lazy=False) + + # Table should be present and loaded normally + assert "test_table" in sdata_normal.tables + table = sdata_normal.tables["test_table"] + + # X should be a numpy array or scipy sparse matrix (in-memory) + import scipy.sparse as sp + + assert isinstance(table.X, np.ndarray | sp.spmatrix) + + def test_read_zarr_lazy_parameter(self, sdata_with_table: SpatialData) -> None: + """Test that read_zarr function accepts lazy parameter.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Test read_zarr directly with lazy parameter + try: + sdata = read_zarr(path, lazy=True) + assert "test_table" in sdata.tables + except ImportError: + pytest.skip("anndata.experimental.read_lazy not available") From 44d4f45045f362b46cfcbabf26df51c012442247 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:59:20 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata/_core/spatialdata.py | 340 +++++++++------------------ src/spatialdata/_io/io_zarr.py | 38 +-- src/spatialdata/models/models.py | 184 ++++----------- tests/io/test_readwrite.py | 167 ++++--------- 4 files changed, 204 insertions(+), 525 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 869683ba..079ca28b 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -22,15 +22,33 @@ from zarr.errors import GroupNotFoundError from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables -from spatialdata._core.validation import (check_all_keys_case_insensitively_unique, check_target_region_column_symmetry, - check_valid_name, raise_validation_errors, validate_table_attr_keys) +from spatialdata._core.validation import ( + check_all_keys_case_insensitively_unique, + 
check_target_region_column_symmetry, + check_valid_name, + raise_validation_errors, + validate_table_attr_keys, +) from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T from spatialdata._utils import _deprecation_alias -from spatialdata.models import (Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel, - TableModel, get_model, get_table_keys) -from spatialdata.models._utils import (SpatialElement, convert_region_column_to_categorical, get_axes_names, - set_channel_names) +from spatialdata.models import ( + Image2DModel, + Image3DModel, + Labels2DModel, + Labels3DModel, + PointsModel, + ShapesModel, + TableModel, + get_model, + get_table_keys, +) +from spatialdata.models._utils import ( + SpatialElement, + convert_region_column_to_categorical, + get_axes_names, + set_channel_names, +) if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest @@ -119,11 +137,7 @@ def __init__( self._tables: Tables = Tables(shared_keys=self._shared_keys) self.attrs = attrs if attrs else {} # type: ignore[assignment] - element_names = list( - chain.from_iterable( - [e.keys() for e in [images, labels, points, shapes] if e is not None] - ) - ) + element_names = list(chain.from_iterable([e.keys() for e in [images, labels, points, shapes] if e is not None])) if len(element_names) != len(set(element_names)): duplicates = {x for x in element_names if element_names.count(x) > 1} @@ -249,9 +263,7 @@ def get_region_key_column(table: AnnData) -> pd.Series: _, region_key, _ = get_table_keys(table) if table.obs.get(region_key) is not None: return table.obs[region_key] - raise KeyError( - f"{region_key} is set as region key column. However the column is not found in table.obs." - ) + raise KeyError(f"{region_key} is set as region key column. However the column is not found in table.obs.") @staticmethod def get_instance_key_column(table: AnnData) -> pd.Series: @@ -276,13 +288,9 @@ def get_instance_key_column(table: AnnData) -> pd.Series: _, _, instance_key = get_table_keys(table) if table.obs.get(instance_key) is not None: return table.obs[instance_key] - raise KeyError( - f"{instance_key} is set as instance key column. However the column is not found in table.obs." - ) + raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.") - def set_channel_names( - self, element_name: str, channel_names: str | list[str], write: bool = False - ) -> None: + def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None: """Set the channel names for an image `SpatialElement` in the `SpatialData` object. This method will overwrite the element in memory with the same element, but with new channel names. @@ -300,9 +308,7 @@ def set_channel_names( Whether to overwrite the channel metadata on disk (lightweight operation). This will not rewrite the pixel data itself (heavy operation). """ - self.images[element_name] = set_channel_names( - self.images[element_name], channel_names - ) + self.images[element_name] = set_channel_names(self.images[element_name], channel_names) if write: self.write_channel_names(element_name) @@ -380,9 +386,7 @@ def _change_table_annotation_target( If provided region_key is not present in table.obs. 
""" attrs = table.uns[TableModel.ATTRS_KEY] - table_region_key = ( - region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) - ) + table_region_key = region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) TableModel()._validate_set_region_key(table, region_key) TableModel()._validate_set_instance_key(table, instance_key) @@ -390,9 +394,7 @@ def _change_table_annotation_target( attrs[TableModel.REGION_KEY] = region @staticmethod - def update_annotated_regions_metadata( - table: AnnData, region_key: str | None = None - ) -> AnnData: + def update_annotated_regions_metadata(table: AnnData, region_key: str | None = None) -> AnnData: """ Update the annotation target of the table using the region_key column in table.obs. @@ -415,9 +417,7 @@ def update_annotated_regions_metadata( """ attrs = table.uns.get(TableModel.ATTRS_KEY) if attrs is None: - raise ValueError( - "The table has no annotation metadata. Please parse the table using `TableModel.parse`." - ) + raise ValueError("The table has no annotation metadata. Please parse the table using `TableModel.parse`.") region_key = region_key if region_key else attrs[TableModel.REGION_KEY_KEY] if attrs[TableModel.REGION_KEY_KEY] != region_key: attrs[TableModel.REGION_KEY_KEY] = region_key @@ -465,20 +465,14 @@ def set_table_annotates_spatialelement( isinstance(region, list | pd.Series) and not all(region_element in element_names for region_element in region) ): - raise ValueError( - f"Annotation target '{region}' not present as SpatialElement in SpatialData object." - ) + raise ValueError(f"Annotation target '{region}' not present as SpatialElement in SpatialData object.") if table.uns.get(TableModel.ATTRS_KEY): - self._change_table_annotation_target( - table, region, region_key, instance_key - ) + self._change_table_annotation_target(table, region, region_key, instance_key) elif isinstance(region_key, str) and isinstance(instance_key, str): self._set_table_annotation_target(table, region, region_key, instance_key) else: - raise TypeError( - "No current annotation metadata found. Please specify both region_key and instance_key." - ) + raise TypeError("No current annotation metadata found. Please specify both region_key and instance_key.") convert_region_column_to_categorical(table) @property @@ -589,17 +583,9 @@ def locate_element(self, element: SpatialElement) -> list[str]: found_element_name.append(element_name) if len(found) == 0: return [] - if any( - "/" in found_element_name[i] or "/" in found_element_type[i] - for i in range(len(found)) - ): - raise ValueError( - "Found an element name with a '/' character. This is not allowed." - ) - return [ - f"{found_element_type[i]}/{found_element_name[i]}" - for i in range(len(found)) - ] + if any("/" in found_element_name[i] or "/" in found_element_type[i] for i in range(len(found))): + raise ValueError("Found an element name with a '/' character. 
This is not allowed.") + return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))] def filter_by_coordinate_system( self, @@ -697,9 +683,7 @@ def _filter_tables( if include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): tables[table_name] = table continue - if not include_orphan_tables and not table.uns.get( - TableModel.ATTRS_KEY - ): + if not include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): continue if table_name in names_tables_to_keep: tables[table_name] = table @@ -716,9 +700,7 @@ def _filter_tables( from spatialdata._core.query.relational_query import _filter_table_by_elements assert elements_dict is not None - table = _filter_table_by_elements( - table, elements_dict=elements_dict - ) + table = _filter_table_by_elements(table, elements_dict=elements_dict) if table is not None and len(table) != 0: tables[table_name] = table else: @@ -766,9 +748,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: for old_cs, new_cs in rename_dict.items(): if old_cs in transformations: random_suffix = hashlib.sha1(os.urandom(128)).hexdigest()[:8] - transformations[new_cs + random_suffix] = transformations.pop( - old_cs - ) + transformations[new_cs + random_suffix] = transformations.pop(old_cs) suffixes_to_replace.add(new_cs + random_suffix) # remove the random suffixes @@ -779,14 +759,10 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: new_transformations[cs] = transformations[cs_with_suffix] suffixes_to_replace.remove(cs_with_suffix) else: - new_transformations[cs_with_suffix] = transformations[ - cs_with_suffix - ] + new_transformations[cs_with_suffix] = transformations[cs_with_suffix] # set the new transformations - set_transformation( - element=element, transformation=new_transformations, set_all=True - ) + set_transformation(element=element, transformation=new_transformations, set_all=True) def transform_element_to_coordinate_system( self, @@ -814,18 +790,17 @@ def transform_element_to_coordinate_system( """ from spatialdata import transform from spatialdata.transformations import Sequence - from spatialdata.transformations.operations import (get_transformation, - get_transformation_between_coordinate_systems, - remove_transformation, set_transformation) + from spatialdata.transformations.operations import ( + get_transformation, + get_transformation_between_coordinate_systems, + remove_transformation, + set_transformation, + ) element = self.get(element_name) - t = get_transformation_between_coordinate_systems( - self, element, target_coordinate_system - ) + t = get_transformation_between_coordinate_systems(self, element, target_coordinate_system) if maintain_positioning: - transformed = transform( - element, transformation=t, maintain_positioning=maintain_positioning - ) + transformed = transform(element, transformation=t, maintain_positioning=maintain_positioning) else: d = get_transformation(element, get_all=True) assert isinstance(d, dict) @@ -861,9 +836,7 @@ def transform_element_to_coordinate_system( # since target_coordinate_system is in d, we have that t is a Sequence with only one transformation. 
assert isinstance(t, Sequence) assert len(t.transformations) == 1 - seq = get_transformation( - transformed, to_coordinate_system=target_coordinate_system - ) + seq = get_transformation(transformed, to_coordinate_system=target_coordinate_system) assert isinstance(seq, Sequence) assert len(seq.transformations) == 2 assert seq.transformations[1] is t.transformations[0] @@ -892,9 +865,7 @@ def transform_to_coordinate_system( ------- The transformed SpatialData. """ - sdata = self.filter_by_coordinate_system( - target_coordinate_system, filter_tables=False - ) + sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_tables=False) elements: dict[str, dict[str, SpatialElement]] = {} for element_type, element_name, _ in sdata.gen_elements(): if element_type != "tables": @@ -930,9 +901,7 @@ def elements_are_self_contained(self) -> dict[str, bool]: description = {} for element_type, element_name, element in self.gen_elements(): element_path = self.path / element_type / element_name - description[element_name] = _is_element_self_contained( - element, element_path - ) + description[element_name] = _is_element_self_contained(element, element_path) return description def is_self_contained(self, element_name: str | None = None) -> bool: @@ -1049,12 +1018,8 @@ def _symmetric_difference_with_zarr_store(self) -> tuple[list[str], list[str]]: elements_in_sdata = self.elements_paths_in_memory() elements_in_zarr = self.elements_paths_on_disk() - elements_only_in_sdata = list( - set(elements_in_sdata).difference(set(elements_in_zarr)) - ) - elements_only_in_zarr = list( - set(elements_in_zarr).difference(set(elements_in_sdata)) - ) + elements_only_in_sdata = list(set(elements_in_sdata).difference(set(elements_in_zarr))) + elements_only_in_zarr = list(set(elements_in_zarr).difference(set(elements_in_sdata))) return elements_only_in_sdata, elements_only_in_zarr def _validate_can_safely_write_to_path( @@ -1069,9 +1034,7 @@ def _validate_can_safely_write_to_path( file_path = Path(file_path) if not isinstance(file_path, Path): - raise ValueError( - f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}." - ) + raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") # TODO: add test for this if os.path.exists(file_path): @@ -1093,28 +1056,24 @@ def _validate_can_safely_write_to_path( "Cannot overwrite. The target path of the write operation is in use. Please save the data to a " "different location. " ) - WORKAROUND = "\nWorkaround: please see discussion here https://github.com/scverse/spatialdata/discussions/520 ." + WORKAROUND = ( + "\nWorkaround: please see discussion here https://github.com/scverse/spatialdata/discussions/520 ." + ) if any(_backed_elements_contained_in_path(path=file_path, object=self)): raise ValueError( - ERROR_MSG - + "\nDetails: the target path contains one or more files that Dask use for " + ERROR_MSG + "\nDetails: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object." 
+ WORKAROUND ) if self.path is not None and ( - _is_subfolder(parent=self.path, child=file_path) - or _is_subfolder(parent=file_path, child=self.path) + _is_subfolder(parent=self.path, child=file_path) or _is_subfolder(parent=file_path, child=self.path) ): - if saving_an_element and _is_subfolder( - parent=self.path, child=file_path - ): + if saving_an_element and _is_subfolder(parent=self.path, child=file_path): raise ValueError( - ERROR_MSG - + "\nDetails: the target path in which to save an element is a subfolder " + ERROR_MSG + "\nDetails: the target path in which to save an element is a subfolder " "of the current Zarr store." + WORKAROUND ) raise ValueError( - ERROR_MSG - + "\nDetails: the target path either contains, coincides or is contained in" + ERROR_MSG + "\nDetails: the target path either contains, coincides or is contained in" " the current Zarr store." + WORKAROUND ) @@ -1139,9 +1098,7 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, update_sdata_path: bool = True, - sdata_formats: ( - SpatialDataFormatType | list[SpatialDataFormatType] | None - ) = None, + sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1203,9 +1160,7 @@ def write( store = _resolve_zarr_store(file_path) zarr_format = parsed["SpatialData"].zarr_format - zarr_group = zarr.create_group( - store=store, overwrite=overwrite, zarr_format=zarr_format - ) + zarr_group = zarr.create_group(store=store, overwrite=overwrite, zarr_format=zarr_format) self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) store.close() @@ -1300,9 +1255,7 @@ def write_element( self, element_name: str | list[str], overwrite: bool = False, - sdata_formats: ( - SpatialDataFormatType | list[SpatialDataFormatType] | None - ) = None, + sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1347,9 +1300,7 @@ def write_element( self._validate_element_names_are_unique() element = self.get(element_name) if element is None: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." - ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") if self.path is None: raise ValueError( @@ -1363,15 +1314,11 @@ def write_element( element_type = _element_type break if element_type is None: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." 
- ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") if element_type == "tables": validate_table_attr_keys(element) - self._check_element_not_on_disk_with_different_type( - element_type=element_type, element_name=element_name - ) + self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) self._write_element( element=element, @@ -1431,16 +1378,10 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: raise ValueError("The SpatialData object is not backed by a Zarr store.") on_disk = self.elements_paths_on_disk() - one_disk_names = [ - self._element_type_and_name_from_element_path(path)[1] for path in on_disk - ] + one_disk_names = [self._element_type_and_name_from_element_path(path)[1] for path in on_disk] in_memory = self.elements_paths_in_memory() - in_memory_names = [ - self._element_type_and_name_from_element_path(path)[1] for path in in_memory - ] - only_in_memory_names = list( - set(in_memory_names).difference(set(one_disk_names)) - ) + in_memory_names = [self._element_type_and_name_from_element_path(path)[1] for path in in_memory] + only_in_memory_names = list(set(in_memory_names).difference(set(one_disk_names))) only_on_disk_names = list(set(one_disk_names).difference(set(in_memory_names))) ERROR_MESSAGE = f"Element {element_name} is not found in the Zarr store associated with the SpatialData object." @@ -1453,25 +1394,19 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: if found: _element_type = self._element_type_from_element_name(element_name) - self._check_element_not_on_disk_with_different_type( - element_type=_element_type, element_name=element_name - ) + self._check_element_not_on_disk_with_different_type(element_type=_element_type, element_name=element_name) element_type = None on_disk = self.elements_paths_on_disk() for path in on_disk: - _element_type, _element_name = ( - self._element_type_and_name_from_element_path(path) - ) + _element_type, _element_name = self._element_type_and_name_from_element_path(path) if _element_name == element_name: element_type = _element_type break assert element_type is not None file_path_of_element = self.path / element_type / element_name - if any( - _backed_elements_contained_in_path(path=file_path_of_element, object=self) - ): + if any(_backed_elements_contained_in_path(path=file_path_of_element, object=self)): raise ValueError( "The file path specified is a parent directory of one or more files used for backing for one or " "more elements in the SpatialData object. Deleting the data would corrupt the SpatialData object." 
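The checks above guard `delete_element_from_disk` against removing data that other elements are Dask-backed by; a minimal sketch of the intended call under the constraints this method enforces (the store path "data.zarr" and the element name "image2d" are hypothetical, and the object must already be backed by a Zarr store):

from spatialdata import SpatialData

# assumes "data.zarr" is an existing SpatialData store containing images/image2d
sdata = SpatialData.read("data.zarr")
sdata.delete_element_from_disk("image2d")  # removes the element's group from the Zarr store only
assert "images/image2d" not in sdata.elements_paths_on_disk()  # the in-memory element is untouched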
@@ -1488,14 +1423,10 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: if self.has_consolidated_metadata(): self.write_consolidated_metadata() - def _check_element_not_on_disk_with_different_type( - self, element_type: str, element_name: str - ) -> None: + def _check_element_not_on_disk_with_different_type(self, element_type: str, element_name: str) -> None: only_on_disk = self.elements_paths_on_disk() for disk_path in only_on_disk: - disk_element_type, disk_element_name = ( - self._element_type_and_name_from_element_path(disk_path) - ) + disk_element_type, disk_element_name = self._element_type_and_name_from_element_path(disk_path) if disk_element_name == element_name and disk_element_type != element_type: raise ValueError( f"Element {element_name} is found in the Zarr store as a {disk_element_type}, but it is found " @@ -1520,9 +1451,7 @@ def has_consolidated_metadata(self) -> bool: store.close() return return_value - def _validate_can_write_metadata_on_element( - self, element_name: str - ) -> tuple[str, SpatialElement | AnnData] | None: + def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[str, SpatialElement | AnnData] | None: """Validate if metadata can be written on an element, returns None if it cannot be written.""" from spatialdata._io._utils import _is_element_self_contained from spatialdata._io.io_zarr import _group_for_element_exists @@ -1545,9 +1474,7 @@ def _validate_can_write_metadata_on_element( element_type = self._element_type_from_element_name(element_name) - self._check_element_not_on_disk_with_different_type( - element_type=element_type, element_name=element_name - ) + self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) # check if the element exists in the Zarr storage if not _group_for_element_exists( @@ -1566,9 +1493,7 @@ def _validate_can_write_metadata_on_element( # warn the users if the element is not self-contained, that is, it is Dask-backed by files outside the Zarr # group for the element element_zarr_path = Path(self.path) / element_type / element_name - if not _is_element_self_contained( - element=element, element_path=element_zarr_path - ): + if not _is_element_self_contained(element=element, element_path=element_zarr_path): logger.info( f"Element {element_type}/{element_name} is not self-contained. The metadata will be" " saved to the Zarr group of the element in the SpatialData Zarr store. The data outside the element " @@ -1591,9 +1516,7 @@ def write_channel_names(self, element_name: str | None = None) -> None: if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." - ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") # recursively write the transformation for all the SpatialElement if element_name is None: @@ -1620,9 +1543,7 @@ def write_channel_names(self, element_name: str | None = None) -> None: overwrite_channel_names(element_group, element) else: - raise ValueError( - f"Can't set channel names for element of type '{element_type}'." 
- ) + raise ValueError(f"Can't set channel names for element of type '{element_type}'.") def write_transformations(self, element_name: str | None = None) -> None: """ @@ -1638,9 +1559,7 @@ def write_transformations(self, element_name: str | None = None) -> None: if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." - ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") # recursively write the transformation for all the SpatialElement if element_name is None: @@ -1671,9 +1590,7 @@ def write_transformations(self, element_name: str | None = None) -> None: from spatialdata._io._utils import overwrite_coordinate_transformations_raster from spatialdata._io.format import RasterFormats - raster_format = RasterFormats[ - element_group.metadata.attributes["spatialdata_attrs"]["version"] - ] + raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]] overwrite_coordinate_transformations_raster( group=element_group, axes=axes, @@ -1695,9 +1612,7 @@ def _element_type_from_element_name(self, element_name: str) -> str: self._validate_element_names_are_unique() element = self.get(element_name) if element is None: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." - ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") located = self.locate_element(element) element_type = None @@ -1711,9 +1626,7 @@ def _element_type_from_element_name(self, element_name: str) -> str: assert element_type is not None return element_type - def _element_type_and_name_from_element_path( - self, element_path: str - ) -> tuple[str, str]: + def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[str, str]: element_type, element_name = element_path.split("/") return element_type, element_name @@ -1726,27 +1639,19 @@ def write_attrs( from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import CurrentSpatialDataContainerFormat, SpatialDataContainerFormatType - sdata_format = ( - sdata_format - if sdata_format is not None - else CurrentSpatialDataContainerFormat() - ) + sdata_format = sdata_format if sdata_format is not None else CurrentSpatialDataContainerFormat() assert isinstance(sdata_format, SpatialDataContainerFormatType) store = None if zarr_group is None: - assert ( - self.is_backed() - ), "The SpatialData object must be backed by a Zarr store to write attrs." + assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." store = _resolve_zarr_store(self.path) zarr_group = zarr.open_group(store=store, mode="r+") version = sdata_format.spatialdata_format_version version_specific_attrs = sdata_format.attrs_to_dict() - attrs_to_write = { - "spatialdata_attrs": {"version": version} | version_specific_attrs - } | self.attrs + attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs try: zarr_group.attrs.put(attrs_to_write) @@ -1795,9 +1700,7 @@ def write_metadata( if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." 
- ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") if write_attrs: self.write_attrs(sdata_format=sdata_format) @@ -1840,9 +1743,7 @@ def get_attrs( the value of `return_as`. """ - def _flatten_mapping( - m: Mapping[str, Any], parent_key: str = "", sep: str = "_" - ) -> dict[str, Any]: + def _flatten_mapping(m: Mapping[str, Any], parent_key: str = "", sep: str = "_") -> dict[str, Any]: items: list[tuple[str, Any]] = [] for k, v in m.items(): new_key = f"{parent_key}{sep}{k}" if parent_key else k @@ -1887,9 +1788,7 @@ def _flatten_mapping( except Exception as e: raise ValueError(f"Failed to convert data to DataFrame: {e}") from e - raise ValueError( - f"Invalid 'return_as' value: {return_as}. Expected 'dict', 'json', 'df', or None." - ) + raise ValueError(f"Invalid 'return_as' value: {return_as}. Expected 'dict', 'json', 'df', or None.") @property def tables(self) -> Tables: @@ -2027,8 +1926,7 @@ def _non_empty_elements(self) -> list[str]: return [ element for element in all_elements - if (getattr(self, element) is not None) - and (len(getattr(self, element)) > 0) + if (getattr(self, element) is not None) and (len(getattr(self, element)) > 0) ] def __repr__(self) -> str: @@ -2066,9 +1964,7 @@ def h(s: str) -> str: descr += f"\n{h('level0')}{attr.capitalize()}" unsorted_elements = attribute.items() - sorted_elements = sorted( - unsorted_elements, key=lambda x: _natural_keys(x[0]) - ) + sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) for k, v in sorted_elements: descr += f"{h('empty_line')}" descr_class = v.__class__.__name__ @@ -2100,16 +1996,7 @@ def h(s: str) -> str: else: shape_str = ( "(" - + ", ".join( - [ - ( - str(dim) - if not isinstance(dim, Scalar) - else "" - ) - for dim in v.shape - ] - ) + + ", ".join([(str(dim) if not isinstance(dim, Scalar) else "") for dim in v.shape]) + ")" ) descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} with shape: {shape_str} {dim_string}" @@ -2185,9 +2072,7 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: if not self.is_self_contained(): assert self.path is not None - descr += ( - "\nwith the following Dask-backed elements not being self-contained:" - ) + descr += "\nwith the following Dask-backed elements not being self-contained:" description = self.elements_are_self_contained() for _, element_name, element in self.gen_elements(): if not description[element_name]: @@ -2195,9 +2080,7 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: descr += f"\n ▸ {element_name}: {backing_files}" if self.path is not None: - elements_only_in_sdata, elements_only_in_zarr = ( - self._symmetric_difference_with_zarr_store() - ) + elements_only_in_sdata, elements_only_in_zarr = self._symmetric_difference_with_zarr_store() if len(elements_only_in_sdata) > 0: descr += "\nwith the following elements not in the Zarr store:" for element_path in elements_only_in_sdata: @@ -2284,13 +2167,9 @@ def _validate_element_names_are_unique(self) -> None: ValueError If the element names are not unique. 
""" - check_all_keys_case_insensitively_unique( - [name for _, name, _ in self.gen_elements()], location=() - ) + check_all_keys_case_insensitively_unique([name for _, name, _ in self.gen_elements()], location=()) - def _find_element( - self, element_name: str - ) -> tuple[str, str, SpatialElement | AnnData]: + def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | AnnData]: """ Retrieve SpatialElement or Table from the SpatialData instance matching element_name. @@ -2389,9 +2268,7 @@ def subset( """ elements_dict: dict[str, SpatialElement] = {} names_tables_to_keep: set[str] = set() - for element_type, element_name, element in self._gen_elements( - include_tables=True - ): + for element_type, element_name, element in self._gen_elements(include_tables=True): if element_name in element_names: if element_type != "tables": elements_dict.setdefault(element_type, {})[element_name] = element @@ -2424,16 +2301,11 @@ def __getitem__(self, item: str) -> SpatialElement | AnnData: def __contains__(self, key: str) -> bool: element_dict = { - element_name: element_value - for _, element_name, element_value in self._gen_elements( - include_tables=True - ) + element_name: element_value for _, element_name, element_value in self._gen_elements(include_tables=True) } return key in element_dict - def get( - self, key: str, default_value: SpatialElement | AnnData | None = None - ) -> SpatialElement | AnnData | None: + def get(self, key: str, default_value: SpatialElement | AnnData | None = None) -> SpatialElement | AnnData | None: """ Get element from SpatialData object based on corresponding name. @@ -2541,9 +2413,7 @@ def filter_by_table_query( obs_names_expr: Predicates | None = None, var_names_expr: Predicates | None = None, layer: str | None = None, - how: Literal[ - "left", "left_exclusive", "inner", "right", "right_exclusive" - ] = "right", + how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", ) -> SpatialData: """ Filter the SpatialData object based on a set of table queries. 
diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index c24c257c..1cb73388 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -33,12 +33,7 @@ def _read_zarr_group_spatialdata_element( read_func: Callable[..., Any], group_name: Literal["images", "labels", "shapes", "points", "tables"], element_type: Literal["image", "labels", "shapes", "points", "tables"], - element_container: ( - dict[str, Raster_T] - | dict[str, DaskDataFrame] - | dict[str, GeoDataFrame] - | dict[str, AnnData] - ), + element_container: (dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData]), on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], ) -> None: with handle_read_errors( @@ -68,9 +63,7 @@ def _read_zarr_group_spatialdata_element( ), ): if element_type in ["image", "labels"]: - reader_format = get_raster_format_for_read( - elem_group, sdata_version - ) + reader_format = get_raster_format_for_read(elem_group, sdata_version) element = read_func( elem_group_path, cast(Literal["image", "labels"], element_type), @@ -124,9 +117,7 @@ def get_raster_format_for_read( def read_zarr( store: str | Path | UPath | zarr.Group, selection: None | tuple[str] = None, - on_bad_files: Literal[ - BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN - ] = BadFileHandleMethod.ERROR, + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, lazy: bool = False, ) -> SpatialData: """ @@ -186,11 +177,7 @@ def read_zarr( shapes: dict[str, GeoDataFrame] = {} tables: dict[str, AnnData] = {} - selector = ( - {"images", "labels", "points", "shapes", "tables"} - if not selection - else set(selection or []) - ) + selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") # we could make this more readable. One can get lost when looking at this dict and iteration over the items @@ -199,10 +186,7 @@ def read_zarr( tuple[ Callable[..., Any], Literal["image", "labels", "shapes", "points", "tables"], - dict[str, Raster_T] - | dict[str, DaskDataFrame] - | dict[str, GeoDataFrame] - | dict[str, AnnData], + dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData], ], ] = { # ome-zarr-py needs a kwargs that has "image" has key. So here we have "image" and not "images" @@ -296,21 +280,15 @@ def _get_groups_for_element( # When writing, use_consolidated must be set to False. Otherwise, the metadata store # can get out of sync with newly added elements (e.g., labels), leading to errors. - root_group = zarr.open_group( - store=resolved_store, mode="r+", use_consolidated=use_consolidated - ) + root_group = zarr.open_group(store=resolved_store, mode="r+", use_consolidated=use_consolidated) element_type_group = root_group.require_group(element_type) - element_type_group = zarr.open_group( - element_type_group.store_path, mode="a", use_consolidated=use_consolidated - ) + element_type_group = zarr.open_group(element_type_group.store_path, mode="a", use_consolidated=use_consolidated) element_name_group = element_type_group.require_group(element_name) return root_group, element_type_group, element_name_group -def _group_for_element_exists( - zarr_path: Path, element_type: str, element_name: str -) -> bool: +def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: str) -> bool: """ Check if the group for an element exists. 
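The `lazy` flag in the `read_zarr` signature shown above is the public entry point of this patch; a minimal usage sketch, assuming an anndata installation that provides `anndata.experimental.read_lazy` (the store path "data.zarr" is hypothetical):

from spatialdata import SpatialData
from spatialdata._io.io_zarr import read_zarr

# lazy=True backs tables with anndata.experimental.read_lazy; the dtype/null checks
# that would force a load are deferred until the data is actually computed
sdata_lazy = read_zarr("data.zarr", lazy=True)

# the default keeps the current eager behaviour; both entry points accept the flag
sdata_eager = SpatialData.read("data.zarr", lazy=False)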
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index fa3920c2..42be9cdd 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -32,19 +32,19 @@ from spatialdata._utils import _check_match_length_channels_c_dim from spatialdata.config import settings from spatialdata.models import C, X, Y, Z, get_axes_names -from spatialdata.models._utils import (DEFAULT_COORDINATE_SYSTEM, TRANSFORM_KEY, MappingToCoordinateSystem_t, - SpatialElement, _validate_mapping_to_coordinate_system_type, - convert_region_column_to_categorical) +from spatialdata.models._utils import ( + DEFAULT_COORDINATE_SYSTEM, + TRANSFORM_KEY, + MappingToCoordinateSystem_t, + SpatialElement, + _validate_mapping_to_coordinate_system_type, + convert_region_column_to_categorical, +) from spatialdata.transformations._utils import _get_transformations, _set_transformations, compute_coordinates from spatialdata.transformations.transformations import BaseTransformation, Identity # Types -Chunks_t: TypeAlias = ( - int - | tuple[int, ...] - | tuple[tuple[int, ...], ...] - | Mapping[Any, None | int | tuple[int, ...]] -) +Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]] ScaleFactors_t = Sequence[dict[str, int] | int] Transform_s = AttrSchema(BaseTransformation, None) @@ -72,9 +72,7 @@ def _is_lazy_anndata(adata: AnnData) -> bool: return False -def _parse_transformations( - element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None -) -> None: +def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None: _validate_mapping_to_coordinate_system_type(transformations) transformations_in_element = _get_transformations(element) if ( @@ -180,9 +178,7 @@ def parse( if transformations: transformations = transformations.copy() if "name" in kwargs: - raise ValueError( - "The `name` argument is not (yet) supported for raster data." - ) + raise ValueError("The `name` argument is not (yet) supported for raster data.") # if dims is specified inside the data, get the value of dims from the data if isinstance(data, DataArray): if not isinstance(data.data, DaskArray): # numpy -> dask @@ -230,18 +226,13 @@ def parse( if c_coords is not None: c_coords = _check_match_length_channels_c_dim(data, c_coords, cls.dims.dims) - if ( - c_coords is not None - and len(c_coords) != data.shape[cls.dims.dims.index("c")] - ): + if c_coords is not None and len(c_coords) != data.shape[cls.dims.dims.index("c")]: raise ValueError( f"The number of channel names `{len(c_coords)}` does not match the length of dimension 'c'" f" with length {data.shape[cls.dims.dims.index('c')]}." ) - data = to_spatial_image( - array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs - ) + data = to_spatial_image(array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs) # parse transformations _parse_transformations(data, transformations) # convert to multiscale if needed @@ -296,18 +287,12 @@ def _(self, data: DataArray) -> None: @validate.register(DataTree) def _(self, data: DataTree) -> None: - for j, k in zip( - data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True - ): + for j, k in zip(data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True): if j != k: - raise ValueError( - f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`." 
- ) + raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.") name = {list(data[i].data_vars.keys())[0] for i in data} if len(name) != 1: - raise ValueError( - f"Expected exactly one data variable for the datatree: found `{name}`." - ) + raise ValueError(f"Expected exactly one data variable for the datatree: found `{name}`.") name = list(name)[0] for d in data: super().validate(data[d][name]) @@ -475,14 +460,9 @@ def validate(cls, data: GeoDataFrame) -> None: """ SUGGESTION = " Please use ShapesModel.parse() to construct data that is guaranteed to be valid." if cls.GEOMETRY_KEY not in data: - raise KeyError( - f"GeoDataFrame must have a column named `{cls.GEOMETRY_KEY}`." - + SUGGESTION - ) + raise KeyError(f"GeoDataFrame must have a column named `{cls.GEOMETRY_KEY}`." + SUGGESTION) if not isinstance(data[cls.GEOMETRY_KEY], GeoSeries): - raise ValueError( - f"Column `{cls.GEOMETRY_KEY}` must be a GeoSeries." + SUGGESTION - ) + raise ValueError(f"Column `{cls.GEOMETRY_KEY}` must be a GeoSeries." + SUGGESTION) if len(data[cls.GEOMETRY_KEY]) == 0: raise ValueError(f"Column `{cls.GEOMETRY_KEY}` is empty." + SUGGESTION) geom_ = data[cls.GEOMETRY_KEY].values[0] @@ -507,10 +487,7 @@ def validate(cls, data: GeoDataFrame) -> None: "please correct the radii of the circles before calling the parser function.", ) if cls.TRANSFORM_KEY not in data.attrs: - raise ValueError( - f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." - + SUGGESTION - ) + raise ValueError(f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + SUGGESTION) if len(data) > 0: n = data.geometry.iloc[0]._ndim if n != 2: @@ -607,9 +584,7 @@ def parse(cls, data: Any, **kwargs: Any) -> GeoDataFrame: def _( cls, data: np.ndarray, # type: ignore[type-arg] - geometry: Literal[ - 0, 3, 6 - ], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON] + geometry: Literal[0, 3, 6], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON] offsets: tuple[ArrayLike, ...] | None = None, radius: float | ArrayLike | None = None, index: ArrayLike | None = None, @@ -620,9 +595,7 @@ def _( geo_df = GeoDataFrame({"geometry": data}) if GeometryType(geometry).name == "POINT": if radius is None: - raise ValueError( - "If `geometry` is `Circles`, `radius` must be provided." - ) + raise ValueError("If `geometry` is `Circles`, `radius` must be provided.") geo_df[cls.RADIUS_KEY] = radius if index is not None: geo_df.index = index @@ -649,9 +622,7 @@ def _( geo_df = GeoDataFrame({"geometry": gc.geoms}) if isinstance(geo_df["geometry"].iloc[0], Point): if radius is None: - raise ValueError( - "If `geometry` is `Circles`, `radius` must be provided." 
- ) + raise ValueError("If `geometry` is `Circles`, `radius` must be provided.") geo_df[cls.RADIUS_KEY] = radius if index is not None: geo_df.index = index @@ -668,10 +639,7 @@ def _( ) -> GeoDataFrame: if "geometry" not in data.columns: raise ValueError("`geometry` column not found in `GeoDataFrame`.") - if ( - isinstance(data["geometry"].iloc[0], Point) - and cls.RADIUS_KEY not in data.columns - ): + if isinstance(data["geometry"].iloc[0], Point) and cls.RADIUS_KEY not in data.columns: raise ValueError(f"Column `{cls.RADIUS_KEY}` not found.") _parse_transformations(data, transformations) cls.validate(data) @@ -711,8 +679,7 @@ def validate(cls, data: DaskDataFrame) -> None: raise ValueError(f"Column `{ax}` must be of type `int` or `float`.") if cls.TRANSFORM_KEY not in data.attrs: raise ValueError( - f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." - + SUGGESTION + f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + SUGGESTION ) if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]: feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY] @@ -795,15 +762,11 @@ def _( if annotation is not None: if feature_key is not None: - df_dict[feature_key] = ( - annotation[feature_key].astype(str).astype("category") - ) + df_dict[feature_key] = annotation[feature_key].astype(str).astype("category") if instance_key is not None: df_dict[instance_key] = annotation[instance_key] if Z not in axes and Z in annotation.columns: - logger.info( - f"Column `{Z}` in `annotation` will be ignored since the data is 2D." - ) + logger.info(f"Column `{Z}` in `annotation` will be ignored since the data is 2D.") for c in set(annotation.columns) - {feature_key, instance_key, X, Y, Z}: df_dict[c] = annotation[c] @@ -842,9 +805,7 @@ def _( if "sort" not in kwargs: index_monotonically_increasing = data.index.is_monotonic_increasing if not isinstance(index_monotonically_increasing, bool): - index_monotonically_increasing = ( - index_monotonically_increasing.compute() - ) + index_monotonically_increasing = index_monotonically_increasing.compute() sort = index_monotonically_increasing else: sort = kwargs["sort"] @@ -882,9 +843,7 @@ def _( if data[feature_key].dtype.name == "category": table[feature_key] = data[feature_key] else: - table[feature_key] = ( - data[feature_key].astype(str).astype("category") - ) + table[feature_key] = data[feature_key].astype(str).astype("category") if instance_key is not None: table[instance_key] = data[instance_key] for c in [X, Y, Z]: @@ -944,13 +903,9 @@ def _add_metadata_and_validate( # It also just changes the state of the series, so it is not a big deal. if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known: try: - data[c] = data[c].cat.set_categories( - data[c].compute().cat.categories - ) + data[c] = data[c].cat.set_categories(data[c].compute().cat.categories) except ValueError: - logger.info( - f"Column `{c}` contains unknown categories. Consider casting it." - ) + logger.info(f"Column `{c}` contains unknown categories. Consider casting it.") _parse_transformations(data, transformations) cls.validate(data) @@ -964,9 +919,7 @@ class TableModel: INSTANCE_KEY = "instance_key" ATTRS_KEY = ATTRS_KEY - def _validate_set_region_key( - self, data: AnnData, region_key: str | None = None - ) -> None: + def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) -> None: """ Validate the region key in table.uns or set a new region key as the region key column. 
@@ -1006,9 +959,7 @@ def _validate_set_region_key( raise ValueError(f"'{region_key}' column not present in table.obs") attrs[self.REGION_KEY_KEY] = region_key - def _validate_set_instance_key( - self, data: AnnData, instance_key: str | None = None - ) -> None: + def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = None) -> None: """ Validate the instance_key in table.uns or set a new instance_key as the instance_key column. @@ -1052,9 +1003,7 @@ def _validate_set_instance_key( if instance_key in data.obs: attrs[self.INSTANCE_KEY] = instance_key else: - raise ValueError( - f"Instance key column '{instance_key}' not found in table.obs." - ) + raise ValueError(f"Instance key column '{instance_key}' not found in table.obs.") def _validate_table_annotation_metadata(self, data: AnnData) -> None: """ @@ -1089,26 +1038,16 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: attr = data.uns[ATTRS_KEY] if "region" not in attr: - raise ValueError( - f"`region` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION - ) + raise ValueError(f"`region` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) if "region_key" not in attr: - raise ValueError( - f"`region_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION - ) + raise ValueError(f"`region_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) if "instance_key" not in attr: - raise ValueError( - f"`instance_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION - ) + raise ValueError(f"`instance_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) if attr[self.REGION_KEY_KEY] not in data.obs: - raise ValueError( - f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column." - ) + raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.") if attr[self.INSTANCE_KEY] not in data.obs: - raise ValueError( - f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column." - ) + raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.") # Skip detailed dtype/value validation for lazy-loaded AnnData # These checks would trigger data loading, defeating the purpose of lazy loading @@ -1129,27 +1068,17 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: "O", ] and not pd.api.types.is_string_dtype(data.obs[attr[self.INSTANCE_KEY]]) - or ( - dtype == "O" - and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) - is not str - ) + or (dtype == "O" and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) is not str) ): dtype = dtype if dtype != "O" else val_dtype raise TypeError( f"Only int, np.int16, np.int32, np.int64, uint equivalents or string allowed as dtype for " f"instance_key column in obs. Dtype found to be {dtype}" ) - expected_regions = ( - attr[self.REGION_KEY] - if isinstance(attr[self.REGION_KEY], list) - else [attr[self.REGION_KEY]] - ) + expected_regions = attr[self.REGION_KEY] if isinstance(attr[self.REGION_KEY], list) else [attr[self.REGION_KEY]] found_regions = data.obs[attr[self.REGION_KEY_KEY]].unique().tolist() if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0: - raise ValueError( - f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match." 
- ) + raise ValueError(f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match.") # Warning for object/string columns with NaN in region_key or instance_key instance_key = attr[self.INSTANCE_KEY] @@ -1161,9 +1090,7 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: if key_value in data.obs: col = data.obs[key_value] col_dtype = col.dtype - if ( - col_dtype == "object" or pd.api.types.is_string_dtype(col_dtype) - ) and col.isna().any(): + if (col_dtype == "object" or pd.api.types.is_string_dtype(col_dtype)) and col.isna().any(): logger.warning( f"The {key_name} column '{key_value}' is of {col_dtype} type and contains NaN values. " "After writing and reading with AnnData, NaN values may (depending on the AnnData version) " @@ -1206,9 +1133,7 @@ def validate( f"using TableModel.parse(adata)." ) # Skip dtype validation for lazy tables (would require loading data) - if not is_lazy and not isinstance( - data.obs[region_key].dtype, CategoricalDtype - ): + if not is_lazy and not isinstance(data.obs[region_key].dtype, CategoricalDtype): raise ValueError( f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`." ) @@ -1220,9 +1145,7 @@ def validate( ) # Skip null check for lazy tables (would require loading data) if not is_lazy and data.obs[instance_key].isnull().values.any(): - raise ValueError( - "`table.obs[instance_key]` must not contain null values, but it does." - ) + raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") self._validate_table_annotation_metadata(data) @@ -1259,9 +1182,7 @@ def parse( """ validate_table_attr_keys(adata) # either all live in adata.uns or all be passed in as argument - n_args = sum( - [region is not None, region_key is not None, instance_key is not None] - ) + n_args = sum([region is not None, region_key is not None, instance_key is not None]) if n_args == 0: if cls.ATTRS_KEY not in adata.uns: # table not annotating any element @@ -1290,9 +1211,7 @@ def parse( region = region.tolist() region_: list[str] = region if isinstance(region, list) else [region] if not adata.obs[region_key].isin(region_).all(): - raise ValueError( - f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values." 
- ) + raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.") adata.uns[cls.ATTRS_KEY][cls.REGION_KEY] = region adata.uns[cls.ATTRS_KEY][cls.REGION_KEY_KEY] = region_key @@ -1303,9 +1222,7 @@ def parse( grouped = adata.obs.groupby(region_key, observed=True) grouped_size = grouped.size() grouped_nunique = grouped.nunique() - not_unique = grouped_size[ - grouped_size != grouped_nunique[instance_key] - ].index.tolist() + not_unique = grouped_size[grouped_size != grouped_nunique[instance_key]].index.tolist() if not_unique: raise ValueError( f"Instance key column for region(s) `{', '.join(not_unique)}` does not contain only unique values" @@ -1416,11 +1333,6 @@ def _get_region_metadata_from_region_key_column(table: AnnData) -> list[str]: ) annotated_regions = region_key_column.unique().tolist() else: - annotated_regions = ( - table.obs[region_key] - .cat.remove_unused_categories() - .cat.categories.unique() - .tolist() - ) + annotated_regions = table.obs[region_key].cat.remove_unused_categories().cat.categories.unique().tolist() assert isinstance(annotated_regions, list) return annotated_regions diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 241336df..80e9b5da 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -106,9 +106,7 @@ def test_shapes( else: # convert each Polygon to a MultiPolygon mixed_multipolygon = shapes["mixed"].assign( - geometry=lambda df: df.geometry.apply( - lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g - ) + geometry=lambda df: df.geometry.apply(lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g) ) assert sdata["mixed"].equals(mixed_multipolygon) assert not sdata["mixed"].equals(shapes["mixed"]) @@ -145,9 +143,7 @@ def test_shapes_geometry_encoding_write_element( # Write each shape element - should use global setting for shape_name in shapes.shapes: - empty_sdata.write_element( - shape_name, sdata_formats=sdata_container_format - ) + empty_sdata.write_element(shape_name, sdata_formats=sdata_container_format) # Verify the encoding metadata in the parquet file parquet_file = tmpdir / "shapes" / shape_name / "shapes.parquet" @@ -228,12 +224,8 @@ def test_multiple_tables( tables: list[AnnData], sdata_container_format: SpatialDataContainerFormatType, ) -> None: - sdata_tables = SpatialData( - tables={str(i): tables[i] for i in range(len(tables))} - ) - self._test_table( - tmp_path, sdata_tables, sdata_container_format=sdata_container_format - ) + sdata_tables = SpatialData(tables={str(i): tables[i] for i in range(len(tables))}) + self._test_table(tmp_path, sdata_tables, sdata_container_format=sdata_container_format) def test_roundtrip( self, @@ -264,9 +256,7 @@ def test_incremental_io_list_of_elements( assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() - shapes.write_element( - ["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format - ) + shapes.write_element(["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format) assert "shapes/new_shapes0" in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" in shapes.elements_paths_on_disk() @@ -381,9 +371,7 @@ def test_incremental_io_on_disk( ValueError, match=match, ): - sdata.write_element( - name, overwrite=True, sdata_formats=sdata_container_format - ) + sdata.write_element(name, overwrite=True, sdata_formats=sdata_container_format) if workaround == 1: new_name = f"{name}_new_place" @@ -414,9 +402,7 @@ 
def test_incremental_io_on_disk( sdata.delete_element_from_disk(name) sdata.write_element(name, sdata_formats=sdata_container_format) - def test_io_and_lazy_loading_points( - self, points, sdata_container_format: SpatialDataContainerFormatType - ): + def test_io_and_lazy_loading_points(self, points, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") points.write(f, sdata_formats=sdata_container_format) @@ -425,9 +411,7 @@ def test_io_and_lazy_loading_points( sdata2 = SpatialData.read(f) assert len(get_dask_backing_files(sdata2)) > 0 - def test_io_and_lazy_loading_raster( - self, images, labels, sdata_container_format: SpatialDataContainerFormatType - ): + def test_io_and_lazy_loading_raster(self, images, labels, sdata_container_format: SpatialDataContainerFormatType): sdatas = {"images": images, "labels": labels} for k, sdata in sdatas.items(): d = getattr(sdata, k) @@ -477,13 +461,9 @@ def test_replace_transformation_on_disk_non_raster( with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") sdata.write(f, sdata_formats=sdata_container_format) - t0 = get_transformation( - SpatialData.read(f).__getattribute__(k)[elem_name] - ) + t0 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name]) assert isinstance(t0, Identity) - set_transformation( - sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata - ) + set_transformation(sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata) t1 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t1, Scale) @@ -494,16 +474,10 @@ def test_write_overwrite_fails_when_no_zarr_store( f = Path(tmpdir) / "data.zarr" f.mkdir() old_data = SpatialData() - with pytest.raises( - ValueError, match="The target file path specified already exists" - ): + with pytest.raises(ValueError, match="The target file path specified already exists"): old_data.write(f, sdata_formats=sdata_container_format) - with pytest.raises( - ValueError, match="The target file path specified already exists" - ): - full_sdata.write( - f, overwrite=True, sdata_formats=sdata_container_format - ) + with pytest.raises(ValueError, match="The target file path specified already exists"): + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( self, @@ -536,9 +510,7 @@ def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( match=r"Details: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object", ): - full_sdata.write( - f, overwrite=True, sdata_formats=sdata_container_format - ) + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) def test_overwrite_fails_when_zarr_store_present( self, full_sdata, sdata_container_format: SpatialDataContainerFormatType @@ -558,9 +530,7 @@ def test_overwrite_fails_when_zarr_store_present( ValueError, match=r"Details: the target path either contains, coincides or is contained in the current Zarr store", ): - full_sdata.write( - f, overwrite=True, sdata_formats=sdata_container_format - ) + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) # support for overwriting backed sdata has been temporarily removed # with tempfile.TemporaryDirectory() as tmpdir: @@ -581,7 +551,9 @@ def test_overwrite_fails_when_zarr_store_present( def test_overwrite_fails_onto_non_zarr_file( self, full_sdata, sdata_container_format: 
SpatialDataContainerFormatType ): - ERROR_MESSAGE = "The target file path specified already exists, and it has been detected to not be a Zarr store." + ERROR_MESSAGE = ( + "The target file path specified already exists, and it has been detected to not be a Zarr store." + ) with tempfile.TemporaryDirectory() as tmpdir: f0 = os.path.join(tmpdir, "test.txt") with open(f0, "w"): @@ -594,17 +566,13 @@ def test_overwrite_fails_onto_non_zarr_file( ValueError, match=ERROR_MESSAGE, ): - full_sdata.write( - f0, overwrite=True, sdata_formats=sdata_container_format - ) + full_sdata.write(f0, overwrite=True, sdata_formats=sdata_container_format) f1 = os.path.join(tmpdir, "test.zarr") os.mkdir(f1) with pytest.raises(ValueError, match=ERROR_MESSAGE): full_sdata.write(f1, sdata_formats=sdata_container_format) with pytest.raises(ValueError, match=ERROR_MESSAGE): - full_sdata.write( - f1, overwrite=True, sdata_formats=sdata_container_format - ) + full_sdata.write(f1, overwrite=True, sdata_formats=sdata_container_format) def test_incremental_io_in_memory( @@ -642,9 +610,7 @@ def test_bug_rechunking_after_queried_raster(): # https://github.com/scverse/spatialdata-io/issues/117 ## single_scale = Image2DModel.parse(RNG.random((100, 10, 10)), chunks=(5, 5, 5)) - multi_scale = Image2DModel.parse( - RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5) - ) + multi_scale = Image2DModel.parse(RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5)) images = {"single_scale": single_scale, "multi_scale": multi_scale} sdata = SpatialData(images=images) queried = sdata.query.bounding_box( @@ -659,9 +625,7 @@ def test_bug_rechunking_after_queried_raster(): @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_self_contained( - full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType -) -> None: +def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: # data only in-memory, so the SpatialData object and all its elements are self-contained assert full_sdata.is_self_contained() description = full_sdata.elements_are_self_contained() @@ -685,10 +649,7 @@ def test_self_contained( # because of the images, labels and points description = sdata2.elements_are_self_contained() for element_name, self_contained in description.items(): - if any( - element_name.startswith(prefix) - for prefix in ["image", "labels", "points"] - ): + if any(element_name.startswith(prefix) for prefix in ["image", "labels", "points"]): assert not self_contained else: assert self_contained @@ -721,11 +682,7 @@ def test_self_contained( assert not sdata2.is_self_contained() description = sdata2.elements_are_self_contained() assert description["combined"] is False - assert all( - description[element_name] - for element_name in description - if element_name != "combined" - ) + assert all(description[element_name] for element_name in description if element_name != "combined") @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) @@ -737,9 +694,7 @@ def test_symmetric_difference_with_zarr_store( full_sdata.write(f, sdata_formats=sdata_container_format) # the list of element on-disk and in-memory is the same - only_in_memory, only_on_disk = ( - full_sdata._symmetric_difference_with_zarr_store() - ) + only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() assert len(only_in_memory) == 0 assert len(only_on_disk) == 0 @@ -755,9 +710,7 @@ def test_symmetric_difference_with_zarr_store( del 
full_sdata.tables["table"] # now the list of element on-disk and in-memory is different - only_in_memory, only_on_disk = ( - full_sdata._symmetric_difference_with_zarr_store() - ) + only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() assert set(only_in_memory) == { "images/new_image2d", "labels/new_labels2d", @@ -775,17 +728,13 @@ def test_symmetric_difference_with_zarr_store( @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_change_path_of_subset( - full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType -) -> None: +def test_change_path_of_subset(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: """A subset SpatialData object has not Zarr path associated, show that we can reassign the path""" with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") full_sdata.write(f, sdata_formats=sdata_container_format) - subset = full_sdata.subset( - ["image2d", "labels2d", "points_0", "circles", "table"] - ) + subset = full_sdata.subset(["image2d", "labels2d", "points_0", "circles", "table"]) assert subset.path is None subset.path = Path(f) @@ -850,9 +799,7 @@ def test_incremental_io_valid_name(full_sdata: SpatialData) -> None: @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_incremental_io_attrs( - points: SpatialData, sdata_container_format: SpatialDataContainerFormatType -) -> None: +def test_incremental_io_attrs(points: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") my_attrs = {"a": "b", "c": 1} @@ -879,9 +826,7 @@ def test_incremental_io_attrs( cached_sdata_blobs = blobs() -@pytest.mark.parametrize( - "element_name", ["image2d", "labels2d", "points_0", "circles", "table"] -) +@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_delete_element_from_disk( full_sdata, @@ -889,9 +834,7 @@ def test_delete_element_from_disk( sdata_container_format: SpatialDataContainerFormatType, ) -> None: # can't delete an element for a SpatialData object without associated Zarr store - with pytest.raises( - ValueError, match="The SpatialData object is not backed by a Zarr store." 
- ): + with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): full_sdata.delete_element_from_disk("image2d") with tempfile.TemporaryDirectory() as tmpdir: @@ -919,9 +862,7 @@ def test_delete_element_from_disk( # can delete an element present both in-memory and on-disk full_sdata.delete_element_from_disk(element_name) - only_in_memory, only_on_disk = ( - full_sdata._symmetric_difference_with_zarr_store() - ) + only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() element_type = full_sdata._element_type_from_element_name(element_name) element_path = f"{element_type}/{element_name}" assert element_path in only_in_memory @@ -936,9 +877,7 @@ def test_delete_element_from_disk( assert element_path not in on_disk -@pytest.mark.parametrize( - "element_name", ["image2d", "labels2d", "points_0", "circles", "table"] -) +@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_element_already_on_disk_different_type( full_sdata, @@ -992,9 +931,7 @@ def test_writing_invalid_name(tmp_path: Path): invalid_sdata.images.data[""] = next(iter(_get_images().values())) invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) - invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next( - iter(_get_shapes().values()) - ) + invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") with pytest.raises(ValueError, match="Name (must|cannot)"): @@ -1005,9 +942,7 @@ def test_writing_valid_table_name_invalid_table(tmp_path: Path): # also try with a valid table name but invalid table # testing just one case, all the cases are in test_table_model_invalid_names() invalid_sdata = SpatialData() - invalid_sdata.tables.data["valid_name"] = AnnData( - np.array([[0]]), layers={"invalid name": np.array([[0]])} - ) + invalid_sdata.tables.data["valid_name"] = AnnData(np.array([[0]]), layers={"invalid name": np.array([[0]])}) with pytest.raises(ValueError, match="Name (must|cannot)"): invalid_sdata.write(tmp_path / "data.zarr") @@ -1020,9 +955,7 @@ def test_incremental_writing_invalid_name(tmp_path: Path): invalid_sdata.images.data[""] = next(iter(_get_images().values())) invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) - invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next( - iter(_get_shapes().values()) - ) + invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") for element_type in ["images", "labels", "points", "shapes", "tables"]: @@ -1037,9 +970,7 @@ def test_incremental_writing_valid_table_name_invalid_table(tmp_path: Path): # testing just one case, all the cases are in test_table_model_invalid_names() invalid_sdata = SpatialData() invalid_sdata.write(tmp_path / "data2.zarr") - invalid_sdata.tables.data["valid_name"] = AnnData( - np.array([[0]]), layers={"invalid name": np.array([[0]])} - ) + invalid_sdata.tables.data["valid_name"] = AnnData(np.array([[0]]), layers={"invalid name": np.array([[0]])}) with pytest.raises(ValueError, match="Name (must|cannot)"): invalid_sdata.write_element("valid_name") @@ -1059,19 +990,13 @@ def 
test_reading_invalid_name(tmp_path: Path): ) valid_sdata.write(tmp_path / "data.zarr") # Circumvent validation at construction time and check validation happens again at writing time. - (tmp_path / "data.zarr/points" / points_name).rename( - tmp_path / "data.zarr/points" / "has whitespace" - ) + (tmp_path / "data.zarr/points" / points_name).rename(tmp_path / "data.zarr/points" / "has whitespace") # This one is not allowed on windows - (tmp_path / "data.zarr/shapes" / shapes_name).rename( - tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@" - ) + (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@") # We do this as the key of the element is otherwise not in the consolidated metadata, leading to an error. valid_sdata.write_consolidated_metadata() - with pytest.raises( - ValidationError, match="Cannot construct SpatialData" - ) as exc_info: + with pytest.raises(ValidationError, match="Cannot construct SpatialData") as exc_info: read_zarr(tmp_path / "data.zarr") actual_message = str(exc_info.value) @@ -1084,14 +1009,10 @@ def test_reading_invalid_name(tmp_path: Path): @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_write_store_unconsolidated_and_read( - full_sdata, sdata_container_format: SpatialDataContainerFormatType -): +def test_write_store_unconsolidated_and_read(full_sdata, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" - full_sdata.write( - path, consolidate_metadata=False, sdata_formats=sdata_container_format - ) + full_sdata.write(path, consolidate_metadata=False, sdata_formats=sdata_container_format) group = zarr.open_group(path, mode="r") assert group.metadata.consolidated_metadata is None @@ -1100,9 +1021,7 @@ def test_write_store_unconsolidated_and_read( @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_can_read_sdata_with_reconsolidation( - full_sdata, sdata_container_format: SpatialDataContainerFormatType -): +def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" full_sdata.write(path, sdata_formats=sdata_container_format) From 120d11dc7386875f8d36f87f8eec1963b5e097b6 Mon Sep 17 00:00:00 2001 From: Tomatokeftes <129113023+Tomatokeftes@users.noreply.github.com> Date: Tue, 27 Jan 2026 12:05:03 +0100 Subject: [PATCH 3/3] fix: address pre-commit linting issues - Simplify if/return pattern in _is_lazy_anndata (SIM103) - Add missing TableModel import in test fixture (F821) - Use modern np.random.Generator instead of np.random.rand (NPY002) --- src/spatialdata/models/models.py | 4 +--- tests/io/test_readwrite.py | 5 ++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 42be9cdd..36a3a973 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -67,9 +67,7 @@ def _is_lazy_anndata(adata: AnnData) -> bool: True if the AnnData is lazily loaded, False otherwise. 
""" # Check if obs is not a pandas DataFrame (lazy AnnData uses xarray Dataset2D) - if not isinstance(adata.obs, pd.DataFrame): - return True - return False + return not isinstance(adata.obs, pd.DataFrame) def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None: diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 80e9b5da..4973a4d7 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -1124,9 +1124,12 @@ class TestLazyTableLoading: @pytest.fixture def sdata_with_table(self) -> SpatialData: """Create a SpatialData object with a simple table for testing.""" + from spatialdata.models import TableModel + + rng = default_rng(42) table = TableModel.parse( AnnData( - X=np.random.rand(100, 50), + X=rng.random((100, 50)), obs=pd.DataFrame( { "region": pd.Categorical(["region1"] * 100),