diff --git a/eitprocessing/roi_selection/gridselection.py b/eitprocessing/roi_selection/gridselection.py index f01e2a12b..f77028c49 100644 --- a/eitprocessing/roi_selection/gridselection.py +++ b/eitprocessing/roi_selection/gridselection.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from dataclasses import field from typing import Literal +from typing import get_type_hints import numpy as np from numpy.typing import NDArray from . import ROISelection @@ -40,8 +41,8 @@ class GridSelection(ROISelection): `split_columns` has the same effect on columns as `split_rows` has on rows. - Regions are ordered according to C indexing order. The `matrix_layout()` method provides a map - showing how the regions are ordered. + Regions are ordered according to C indexing order. The `matrix_layout` attribute provides a map + showing how these regions are ordered. Common grids are pre-defined: - VentralAndDorsal: vertically divided into ventral and dorsal; @@ -57,37 +58,51 @@ class GridSelection(ROISelection): split_columns: Allows columns to be split over two regions. Examples: - >>> pixel_map = array([[ 1, 2, 3], - [ 4, 5, 6], - [ 7, 8, 9], - [10, 11, 12], - [13, 14, 15], - [16, 17, 18]]) - >>> gs = GridSelection(3, 1, split_pixels=False) - >>> matrices = gs.find_grid(pixel_map) - >>> matrices[0] * pixel_map + >>> pixel_map = np.array([[ 1, 2, 3], + [ 4, 5, 6], + [ 7, 8, 9], + [10, 11, 12], + [13, 14, 15], + [16, 17, 18]]) + >>> gs = GridSelection(3, 1, split_rows=False) + >>> rois = gs.find_grid(pixel_map) + >>> gs.matrix_layout + array([[0], + [1], + [2]]) + >>> rois[0] * pixel_map array([[1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]) - >>> gs.matrix_layout() - array([[0], - [1], - [2]]) - >>> gs2 = GridSelection(2, 2, split_pixels=True) - >>> matrices2 = gs.find_grid(pixel_map) - >>> gs2.matrix_layout() + >>> rois[1] * pixel_map + array([[0, 0, 0], + [0, 0, 0], + [7, 8, 9], + [10, 11, 12], + [0, 0, 0], + [0, 0, 0]]) + >>> gs2 = GridSelection(2, 2, split_columns=True) + >>> rois2 = gs.find_grid(pixel_map) + >>> gs2.matrix_layout array([[0, 1], [2, 3]]) - >>> matrices2[2] + >>> rois2[2] array([[0. , 0. , 0. ], [0. , 0. , 0. ], [0. , 0. , 0. ], [1. , 0.5, 0. ], [1. , 0.5, 0. ], [1. , 0.5, 0. ]]) + >>> rois2[3] + array([[0. , 0. , 0. ], + [0. , 0. , 0. ], + [0. , 0. , 0. ], + [0. , 0.5, 1. ], + [0. , 0.5, 1. ], + [0. , 0.5, 1. ]]) """ v_split: int @@ -97,29 +112,24 @@ class GridSelection(ROISelection): ignore_nan_rows: bool = True ignore_nan_columns: bool = True - def _check_attribute_type(self, name, type_): - """Checks whether an attribute is an instance of the given type.""" - attr = getattr(self, name) - if not isinstance(attr, type_): - message = f"Invalid type for `{name}`." - message += f"Should be {type_}, not {type(attr)}." - raise TypeError(message) - def __post_init__(self): - self._check_attribute_type("v_split", int) - self._check_attribute_type("h_split", int) + try: + if self.v_split == int(self.v_split): + self.v_split = int(self.v_split) + if self.h_split == int(self.h_split): + self.h_split = int(self.h_split) + finally: + for attr, type_ in get_type_hints(self).items(): + if not isinstance(getattr(self, attr), type_): + raise TypeError( + f"Invalid type for `{attr}`. Should be {type_}, not {type(attr)}." + ) if self.v_split < 1: raise InvalidVerticalDivision("`v_split` can't be smaller than 1.") - if self.h_split < 1: raise InvalidHorizontalDivision("`h_split` can't be smaller than 1.") - self._check_attribute_type("split_columns", bool) - self._check_attribute_type("split_rows", bool) - self._check_attribute_type("ignore_nan_columns", bool) - self._check_attribute_type("ignore_nan_rows", bool) - def find_grid(self, data: NDArray) -> list[NDArray]: """ Create 2D arrays to split a grid into regions. @@ -306,11 +316,15 @@ def _create_grouping_vector_split_pixels( # pylint: disable=too-many-locals return final - def matrix_layout(self) -> NDArray: + @property + def _matrix_layout(self) -> NDArray: """Returns a 2D array showing the layout of the matrices returned by `find_grid`.""" n_regions = self.v_split * self.h_split return np.reshape(np.arange(n_regions), (self.v_split, self.h_split)) + @_matrix_layout.getter # private attribute with getter avoids users overriding this property + def matrix_layout(self): + return self._matrix_layout class InvalidDivision(Exception): diff --git a/tests/test_gridselection.py b/tests/test_gridselection.py index 24f98bdb4..21b5a866b 100644 --- a/tests/test_gridselection.py +++ b/tests/test_gridselection.py @@ -465,7 +465,7 @@ def test_split_pixels_nans(data_string, split_vh, result): ) def test_matrix_layout(split_vh: tuple[int, int], result: list[list[int]]): """ - Test `matrix_layout()` method. + Test `matrix_layout` method. Args: split_vh (tuple[int, int]): `v_split` and `h_split`. @@ -474,6 +474,5 @@ def test_matrix_layout(split_vh: tuple[int, int], result: list[list[int]]): """ gs = GridSelection(*split_vh) - layout = gs.matrix_layout() - assert np.array_equal(layout, np.array(result)) + assert np.array_equal(gs.matrix_layout, np.array(result))