Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 125 additions & 11 deletions eitprocessing/roi_selection/gridselection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,27 @@
import warnings
from dataclasses import dataclass
from dataclasses import field
from enum import auto
from typing import Literal
import numpy as np
from numpy.typing import NDArray
from strenum import LowercaseStrEnum
from . import ROISelection


class DivisionMethod(LowercaseStrEnum):
geometrical = auto()
geometrical_split_pixels = auto()
physiological = auto()


# DivisionMethod.geometrical_split_pixels == "geometrical_split_pixels"

@dataclass
class GridSelection(ROISelection):
"""Create regions of interest by division into a grid.

GridSelection allows for the creation a list of 2D arrays that can be used to divide a two- or
GridSelection allows for the creation of a list of 2D arrays that can be used to divide a two- or
higher-dimensional array into several regions structured in a grid. An instance of
GridSelection contains information about how to subdivide an input matrix. Calling
`find_grid(data)`, where data is a 2D array, results in a list of arrays with the same
Expand All @@ -26,18 +36,22 @@ class GridSelection(ROISelection):
only contain NaN are split as if it is a (28, 28) array. The resulting arrays have the shape
(32, 32) with the same cells as the input data containing NaN values.

If the number of rows or columns can not split evenly, a row or column can be split among two
If the number of rows or columns can not be split evenly, a row or column can be split among two
regions. This behaviour is controlled by `split_rows` and `split_columns`.

If `split_rows` is `False` (default), rows will not be split between two groups. A warning will
If `split_rows` is `geometrical` (default), rows will not be split between two groups. A warning will
be shown stating regions don't contain equal numbers of rows. The regions towards the top will
be larger. E.g., when a (5, 2) array is split in two vertical regions, the first region will
contain the first three rows, and the second region the last two rows.

If `split_rows` is `True`, e.g. a (5, 2) array that is split in two vertical regions, the first
region will contain the first two rows and half of each pixel of the third row. The second
If `split_rows` is `geometrical_split_pixels`, e.g. a (5, 2) array that is split in two vertical regions,
the first region will contain the first two rows and half of each pixel of the third row. The second
region contains half of each pixel in the third row, and the last two rows.

If `split_rows` is `physiological`, an array is split so that the cumulative sum of the rows in each region
is exactly equal. For instance, if an array is split in two regions, a pixel row can contribute for 20% to
the first region and for 80% to the second region.

`split_columns` has the same effect on columns as `split_rows` has on rows.

Regions are ordered according to C indexing order. The `matrix_layout()` method provides a map
Expand Down Expand Up @@ -92,8 +106,8 @@ class GridSelection(ROISelection):

v_split: int
h_split: int
split_rows: bool = False
split_columns: bool = False
split_rows: DivisionMethod = DivisionMethod.geometrical
split_columns: DivisionMethod = DivisionMethod.geometrical
ignore_nan_rows: bool = True
ignore_nan_columns: bool = True

Expand All @@ -115,8 +129,8 @@ def __post_init__(self):
if self.h_split < 1:
raise InvalidHorizontalDivision("`h_split` can't be smaller than 1.")

self._check_attribute_type("split_columns", bool)
self._check_attribute_type("split_rows", bool)
self._check_attribute_type("split_columns", DivisionMethod)
self._check_attribute_type("split_rows", DivisionMethod)
self._check_attribute_type("ignore_nan_columns", bool)
self._check_attribute_type("ignore_nan_rows", bool)

Expand All @@ -136,8 +150,9 @@ def find_grid(self, data: NDArray) -> list[NDArray]:
h_split`.
"""
grouping_method = {
True: self._create_grouping_vector_split_pixels,
False: self._create_grouping_vector_no_split_pixels,
DivisionMethod.geometrical: self._create_grouping_vector_no_split_pixels,
DivisionMethod.geometrical_split_pixels: self._create_grouping_vector_split_pixels,
DivisionMethod.physiological: self._create_grouping_vector_physiological
}

horizontal_grouping_vectors = grouping_method[self.split_columns](
Expand Down Expand Up @@ -306,6 +321,105 @@ def _create_grouping_vector_split_pixels( # pylint: disable=too-many-locals

return final

def _create_grouping_vector_physiological( # pylint: disable=too-many-locals
self,
matrix: NDArray,
horizontal: bool,
n_groups: int,
) -> list[NDArray]:

"""Create a grouping vector to split vector into `n` groups allowing
split elements."""

axis = 0 if horizontal else 1

# create a vector that is nan if the entire column/row is nan, 1 otherwise
vector_is_nan = np.all(np.isnan(matrix), axis=axis)
vector = np.ones(vector_is_nan.shape)

if (horizontal and self.ignore_nan_columns) or (
not horizontal and self.ignore_nan_rows
):
vector[vector_is_nan] = np.nan

# remove non-numeric (nan) elements at vector ends
# nan elements between numeric elements are kept
numeric_element_indices = np.argwhere(~np.isnan(vector))
first_num_element = numeric_element_indices.min()
last_num_element = numeric_element_indices.max()
else:
first_num_element = 0
last_num_element = len(vector) - 1

n_elements = last_num_element - first_num_element + 1

group_size = n_elements / n_groups

if group_size < 1:
if horizontal:
warnings.warn(
f"The number horizontal regions ({n_groups}) is larger than the "
f"number of available columns ({n_elements}).",
MoreHorizontalGroupsThanColumns,
)
else:
warnings.warn(
f"The number vertical regions ({n_groups}) is larger than the "
f"number of available rows ({n_elements}).",
MoreVerticalGroupsThanRows,
)

sum_along_axis = np.nansum(matrix, axis=axis)
relative_sum_along_axis = sum_along_axis / np.nansum(matrix)
relative_cumsum_along_axis = np.cumsum(relative_sum_along_axis)

lower_bounds = np.arange(n_groups) / n_groups
upper_bounds = (np.arange(n_groups) + 1) / n_groups

# Otherwise the first row will not fall in the first region (because they are 0)
# and last rows will not fall in the last region, because they reach 1.0
lower_bounds[0] = -np.inf
upper_bounds[-1] = np.inf

row_in_region = []

for lower_bound, upper_bound in zip(lower_bounds, upper_bounds):
row_in_region.append(
np.logical_and(
relative_cumsum_along_axis > lower_bound, relative_cumsum_along_axis <= upper_bound
)
)

row_in_region = np.array(row_in_region).T
final = row_in_region.astype(float)

# find initial region for each row
initial_regions = np.apply_along_axis(np.flatnonzero, 1, row_in_region).flatten()

# find transitions between regions
region_borders = np.flatnonzero(np.diff(initial_regions))

# finds overlap in transition region
for previous_region, (ventral_row, upper_bound) in enumerate(
zip(region_borders, upper_bounds)
):
dorsal_row = ventral_row + 1
next_region = previous_region + 1
a, b = relative_cumsum_along_axis[ventral_row], relative_cumsum_along_axis[dorsal_row]
diff = b - a
to_a = upper_bound - a
fraction_to_a = to_a / diff
fraction_to_b = 1 - fraction_to_a

final[dorsal_row, previous_region] = fraction_to_a
final[dorsal_row, next_region] = fraction_to_b
final = final.T
final = final * vector
# convert to list of vectors
final = [final[n, :] for n in range(final.shape[0])]

return final

def matrix_layout(self) -> NDArray:
"""Returns a 2D array showing the layout of the matrices returned by
`find_grid`."""
Expand Down
Loading