diff --git a/eitprocessing/roi_selection/__init__.py b/eitprocessing/roi_selection/__init__.py new file mode 100644 index 000000000..245a0f571 --- /dev/null +++ b/eitprocessing/roi_selection/__init__.py @@ -0,0 +1,7 @@ +class ROISelection: + def minimal_cluster( + self, + pixel_map, + ): + # TODO create minimal cluster size selector + raise NotImplementedError() diff --git a/eitprocessing/roi_selection/gridselection.py b/eitprocessing/roi_selection/gridselection.py new file mode 100644 index 000000000..f01e2a12b --- /dev/null +++ b/eitprocessing/roi_selection/gridselection.py @@ -0,0 +1,390 @@ +import bisect +import itertools +import warnings +from dataclasses import dataclass +from dataclasses import field +from typing import Literal +import numpy as np +from numpy.typing import NDArray +from . import ROISelection + + +@dataclass +class GridSelection(ROISelection): + """Create regions of interest by division into a grid. + + GridSelection allows for the creation a list of 2D arrays that can be used to divide a two- or + higher-dimensional array into several regions structured in a grid. An instance of + GridSelection contains information about how to subdivide an input matrix. Calling + `find_grid(data)`, where data is a 2D array, results in a list of arrays with the same + dimension as `data`, each representing a single region. Each resulting 2D array contains the + value 0 for pixels that do not belong to the region, and the value 1 or any number between 0 + and 1 for pixels that (partly) belong to the region. + + Rows and columns at the edges of `data` that only contain NaN (not a number) values are + ignored. E.g. a (32, 32) array where the first and last two rows and first and last two columns + only contain NaN are split as if it is a (28, 28) array. The resulting arrays have the shape + (32, 32) with the same cells as the input data containing NaN values. 
+ + If the number of rows or columns can not split evenly, a row or column can be split among two + regions. This behaviour is controlled by `split_rows` and `split_columns`. + + If `split_rows` is `False` (default), rows will not be split between two groups. A warning will + be shown stating regions don't contain equal numbers of rows. The regions towards the top will + be larger. E.g., when a (5, 2) array is split in two vertical regions, the first region will + contain the first three rows, and the second region the last two rows. + + If `split_rows` is `True`, e.g. a (5, 2) array that is split in two vertical regions, the first + region will contain the first two rows and half of each pixel of the third row. The second + region contains half of each pixel in the third row, and the last two rows. + + `split_columns` has the same effect on columns as `split_rows` has on rows. + + Regions are ordered according to C indexing order. The `matrix_layout()` method provides a map + showing how the regions are ordered. + + Common grids are pre-defined: + - VentralAndDorsal: vertically divided into ventral and dorsal; + - RightAndLeft: horizontally divided into anatomical right and left; NB: anatomical right is + the left side of the matrix; + - FourLayers: vertically divided into ventral, mid-ventral, mid-dorsal and dorsal; + - Quadrants: vertically and horizontally divided into four quadrants. + + Args: + v_split: The number of vertical regions. Must be 1 or larger. + h_split: The number of horizontal regions. Must be 1 or larger. + split_rows: Allows rows to be split over two regions. + split_columns: Allows columns to be split over two regions. 
+ + Examples: + >>> pixel_map = array([[ 1, 2, 3], + [ 4, 5, 6], + [ 7, 8, 9], + [10, 11, 12], + [13, 14, 15], + [16, 17, 18]]) + >>> gs = GridSelection(3, 1, split_pixels=False) + >>> matrices = gs.find_grid(pixel_map) + >>> matrices[0] * pixel_map + array([[1, 2, 3], + [4, 5, 6], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]]) + >>> gs.matrix_layout() + array([[0], + [1], + [2]]) + >>> gs2 = GridSelection(2, 2, split_pixels=True) + >>> matrices2 = gs.find_grid(pixel_map) + >>> gs2.matrix_layout() + array([[0, 1], + [2, 3]]) + >>> matrices2[2] + array([[0. , 0. , 0. ], + [0. , 0. , 0. ], + [0. , 0. , 0. ], + [1. , 0.5, 0. ], + [1. , 0.5, 0. ], + [1. , 0.5, 0. ]]) + """ + + v_split: int + h_split: int + split_rows: bool = False + split_columns: bool = False + ignore_nan_rows: bool = True + ignore_nan_columns: bool = True + + def _check_attribute_type(self, name, type_): + """Checks whether an attribute is an instance of the given type.""" + attr = getattr(self, name) + if not isinstance(attr, type_): + message = f"Invalid type for `{name}`." + message += f"Should be {type_}, not {type(attr)}." + raise TypeError(message) + + def __post_init__(self): + self._check_attribute_type("v_split", int) + self._check_attribute_type("h_split", int) + + if self.v_split < 1: + raise InvalidVerticalDivision("`v_split` can't be smaller than 1.") + + if self.h_split < 1: + raise InvalidHorizontalDivision("`h_split` can't be smaller than 1.") + + self._check_attribute_type("split_columns", bool) + self._check_attribute_type("split_rows", bool) + self._check_attribute_type("ignore_nan_columns", bool) + self._check_attribute_type("ignore_nan_rows", bool) + + def find_grid(self, data: NDArray) -> list[NDArray]: + """ + Create 2D arrays to split a grid into regions. + + Create 2D arrays to split the given data into regions. The number of 2D + arrays will equal the number regions, which is the multiplicaiton of + `v_split` and `h_split`. 
+ + Args: + data (NDArray): a 2D array containing any numeric or np.nan data. + + Returns: + list[NDArray]: a list of `n` 2D arrays where `n` is `v_split * + h_split`. + """ + grouping_method = { + True: self._create_grouping_vector_split_pixels, + False: self._create_grouping_vector_no_split_pixels, + } + + horizontal_grouping_vectors = grouping_method[self.split_columns]( + data, horizontal=True, n_groups=self.h_split + ) + + vertical_grouping_vectors = grouping_method[self.split_rows]( + data, horizontal=False, n_groups=self.v_split + ) + + matrices = [] + for vertical, horizontal in itertools.product( + vertical_grouping_vectors, horizontal_grouping_vectors + ): + matrix = np.outer(vertical, horizontal) + matrix[np.isnan(data)] = np.nan + matrices.append(matrix) + + return matrices + + def _create_grouping_vector_no_split_pixels( # pylint: disable=too-many-locals + self, + data: NDArray, + horizontal: bool, + n_groups: int, + ) -> list[NDArray]: + """Create a grouping vector to split vector into `n` groups not + allowing split elements.""" + + axis = 0 if horizontal else 1 + + if (horizontal and self.ignore_nan_columns) or ( + not horizontal and self.ignore_nan_rows + ): + is_numeric = ~np.isnan(data) + numeric_vector_indices = np.argwhere(is_numeric.sum(axis) > 0) + first_numeric_vector = numeric_vector_indices.min() + last_vector_numeric = numeric_vector_indices.max() + else: + first_numeric_vector = 0 + last_vector_numeric = data.shape[1 - axis] - 1 + + n_vectors = last_vector_numeric - first_numeric_vector + 1 + + if n_groups > n_vectors: + if horizontal: # pylint: disable=no-else-raise + raise InvalidHorizontalDivision( + "The number horizontal regions is larger than the " + f"number of available columns ({n_vectors})." + ) + else: + raise InvalidVerticalDivision( + "The number vertical regions is larger than the " + f"number of available rows ({n_vectors})." 
+ ) + + n_vectors_per_region = n_vectors / n_groups + + if n_vectors_per_region % 1 > 0: + if horizontal: + warnings.warn( + "The horizontal regions will not have an equal number of " + f"columns. {n_vectors} is not equally divisible by {n_groups}.", + UnevenHorizontalDivision, + ) + else: + warnings.warn( + "The vertical regions will not have an equal number of " + f"columns. {n_vectors} is not equally divisible by {n_groups}.", + UnevenVerticalDivision, + ) + + region_boundaries = [ + first_numeric_vector + + bisect.bisect_left(np.arange(n_vectors) / n_vectors_per_region, c) + for c in range(n_groups + 1) + ] + + vectors = [] + for start, end in itertools.pairwise(region_boundaries): + vector = np.ones(data.shape[1 - axis]) + vector[:start] = 0.0 + vector[end:] = 0.0 + vectors.append(vector) + + return vectors + + def _create_grouping_vector_split_pixels( # pylint: disable=too-many-locals + self, + matrix: NDArray, + horizontal: bool, + n_groups: int, + ) -> list[NDArray]: + """Create a grouping vector to split vector into `n` groups allowing + split elements.""" + + axis = 0 if horizontal else 1 + + # create a vector that is nan if the entire column/row is nan, 1 otherwise + vector_is_nan = np.all(np.isnan(matrix), axis=axis) + vector = np.ones(vector_is_nan.shape) + + if (horizontal and self.ignore_nan_columns) or ( + not horizontal and self.ignore_nan_rows + ): + vector[vector_is_nan] = np.nan + + # remove non-numeric (nan) elements at vector ends + # nan elements between numeric elements are kept + numeric_element_indices = np.argwhere(~np.isnan(vector)) + first_num_element = numeric_element_indices.min() + last_num_element = numeric_element_indices.max() + else: + first_num_element = 0 + last_num_element = len(vector) - 1 + + n_elements = last_num_element - first_num_element + 1 + + group_size = n_elements / n_groups + + if group_size < 1: + if horizontal: + warnings.warn( + f"The number horizontal regions ({n_groups}) is larger than the " + f"number of 
available columns ({n_elements}).", + MoreHorizontalGroupsThanColumns, + ) + else: + warnings.warn( + f"The number vertical regions ({n_groups}) is larger than the " + f"number of available rows ({n_elements}).", + MoreVerticalGroupsThanRows, + ) + + # find the right boundaries (upper values) of each group + right_boundaries = (np.arange(n_groups) + 1) * group_size + right_boundaries = right_boundaries[:, np.newaxis] # converts to row vector + + # each row in the base represents one group + base = np.tile(np.arange(n_elements), (n_groups, 1)) + + # if the element number is higher than the split, it does not belong in this group + element_contribution_to_group = right_boundaries - base + element_contribution_to_group[element_contribution_to_group < 0] = 0 + + # if the element to the right is a full group size, this element is ruled out + rule_out = element_contribution_to_group[:, 1:] >= group_size + element_contribution_to_group[:, :-1][rule_out] = 0 + + # elements have a maximum value of 1 + element_contribution_to_group = np.fmin(element_contribution_to_group, 1) + + # if this element is already represented in the previous group (row), subtract that + element_contribution_to_group[1:] -= element_contribution_to_group[:-1] + element_contribution_to_group[element_contribution_to_group < 0] = 0 + + # element_contribution_to_group only represents non-nan elements + # insert into final including non-nan elements + final = np.full((n_groups, len(vector)), np.nan) + final[ + :, first_num_element : last_num_element + 1 + ] = element_contribution_to_group + + # convert to list of vectors + final = [final[n, :] for n in range(final.shape[0])] + + return final + + def matrix_layout(self) -> NDArray: + """Returns a 2D array showing the layout of the matrices returned by + `find_grid`.""" + n_regions = self.v_split * self.h_split + return np.reshape(np.arange(n_regions), (self.v_split, self.h_split)) + + +class InvalidDivision(Exception): + """Raised when the data can't be 
divided into regions.""" + + +class InvalidHorizontalDivision(InvalidDivision): + """Raised when the data can't be divided into horizontal regions.""" + + +class InvalidVerticalDivision(InvalidDivision): + """Raised when the data can't be divided into vertical regions.""" + + +class DivisionWarning(Warning): + pass + + +class UnevenDivision(DivisionWarning): + """Warning for when a grid selection results in groups of uneven size.""" + + +class UnevenHorizontalDivision(UnevenDivision): + """Warning for when a grid selection results in horizontal groups of uneven size.""" + + +class UnevenVerticalDivision(UnevenDivision): + """Warning for when a grid selection results in vertical groups of uneven size.""" + + +class MoreGroupsThanVectors(DivisionWarning): + """Warning for when the groups outnumber the available vectors.""" + + +class MoreVerticalGroupsThanRows(MoreGroupsThanVectors): + """Warning for when the vertical groups outnumber the available rows.""" + + +class MoreHorizontalGroupsThanColumns(MoreGroupsThanVectors): + """Warning for when the horizontal groups outnumber the available rows.""" + + +@dataclass +class VentralAndDorsal(GridSelection): + """Split data into a ventral and dorsal region of interest.""" + + v_split: Literal[2] = field(default=2, init=False) + h_split: Literal[1] = field(default=1, init=False) + split_rows = True + + +@dataclass +class RightAndLeft(GridSelection): + """Split data into a right and left region of interest.""" + + v_split: Literal[1] = field(default=1, init=False) + h_split: Literal[2] = field(default=2, init=False) + split_columns = False + + +@dataclass +class FourLayers(GridSelection): + """Split data vertically into four layer regions of interest.""" + + v_split: Literal[4] = field(default=4, init=False) + h_split: Literal[1] = field(default=1, init=False) + split_rows = True + + +@dataclass +class Quadrants(GridSelection): + """Split data into four quadrant regions of interest.""" + + v_split: Literal[2] = 
field(default=2, init=False) + h_split: Literal[2] = field(default=2, init=False) + split_columns = False + split_rows = True diff --git a/notebooks/find_grid.ipynb b/notebooks/find_grid.ipynb new file mode 100644 index 000000000..55ed40823 --- /dev/null +++ b/notebooks/find_grid.ipynb @@ -0,0 +1,337 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[nan, nan, nan, nan, 1.],\n", + " [nan, 1., 1., 3., nan],\n", + " [nan, 1., 2., 3., nan],\n", + " [nan, nan, nan, nan, nan]])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from types import SimpleNamespace\n", + "import numpy as np\n", + "import bisect\n", + "a = np.array([[np.nan, np.nan, np.nan, np.nan, 1],[np.nan, 1, 1, 3, np.nan], [np.nan, 1, 2, 3, np.nan], [np.nan, np.nan, np.nan, np.nan, np.nan]])\n", + "\n", + "data = a\n", + "display(data)\n", + "\n", + "self = SimpleNamespace()\n", + "self.h_split = 2\n", + "self.v_split = 2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[False, False, False, False, True],\n", + " [False, True, True, True, False],\n", + " [False, True, True, True, False],\n", + " [False, False, False, False, False]])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "is_numeric = ~np.isnan(data)\n", + "is_numeric[:]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 2, 2, 2, 1])" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "array([[1],\n", + " [2],\n", + " [3],\n", + " [4]])" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + 
"text/plain": [ + "4" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "[1, 3, 5]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "numeric_cols = (~np.isnan(data)).sum(0)\n", + "cols_with_numbers = np.argwhere(numeric_cols > 0)\n", + "first_col_with_number = cols_with_numbers.min()\n", + "last_col_with_numer = cols_with_numbers.max()\n", + "\n", + "n_columns = last_col_with_numer - first_col_with_number + 1\n", + "\n", + "n_columns_per_group = n_columns / self.h_split\n", + "\n", + "display(numeric_cols, cols_with_numbers, first_col_with_number, last_col_with_numer)\n", + "\n", + "col_splits = [\n", + " first_col_with_number + bisect.bisect_left(np.arange(n_columns), c * n_columns_per_group)\n", + " for c in range(0, self.h_split)\n", + "] + [last_col_with_numer+1]\n", + "\n", + "col_splits\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 112, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[False, False, False, False, True],\n", + " [False, True, True, True, False],\n", + " [False, True, True, True, False],\n", + " [False, False, False, False, False]])" + ] + }, + "execution_count": 112, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "A = ~np.isnan(data)\n", + "A" + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[False, False, False, False, False],\n", + " [False, True, True, False, False],\n", + " [False, True, True, False, False],\n", + " [False, False, False, False, False]])" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "array([[False, False, False, False, False],\n", + " [False, False, False, True, False],\n", + " [False, False, False, True, False],\n", + " [False, False, False, False, False]])" + ] + }, + "metadata": {}, + 
"output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "array([[False, False, False, False, True],\n", + " [False, False, False, False, False],\n", + " [False, False, False, False, False],\n", + " [False, False, False, False, False]])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import itertools\n", + "for splits in itertools.pairwise(col_splits):\n", + " matrix = A.copy()\n", + " matrix[:, :splits[0]] = False\n", + " matrix[:, splits[1]:] = False\n", + " display(matrix)" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0, 2]" + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "col_splits" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Test how to stack the resulting matrices, and check whether every cell is represented once\n", + "\n", + "from eitprocessing.roi_selection.gridselection import GridSelection\n", + "import numpy as np\n", + "\n", + "g = GridSelection(3, 4, False)\n", + "data = np.full((32, 32), 1)\n", + "\n", + "matrices = g.find_grid(data)\n", + "np.array_equal(np.sum(np.stack(matrices, axis=-1), axis=-1), np.ones(data.shape))" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[ True, False, False],\n", + " [False, True, False],\n", + " [False, False, True]]),\n", + " array([[ True, False],\n", + " [False, True]])]" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Test way to parse string to True/False matrix\n", + "\n", + "result = 'TFF,FTF,FFT;TF,FT'\n", + "array = [np.array([tuple(row) for row in 
matrix.split(',')]) == 'T' for matrix in result.split(';')] \n", + "array" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 0, 1, 2, 3],\n", + " [ 4, 5, 6, 7],\n", + " [ 8, 9, 10, 11]])" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Test way to create a layout map of the returned matrices\n", + "\n", + "v_split = 3\n", + "h_split = 4\n", + "n_groups = v_split * h_split\n", + "np.reshape(np.arange(n_groups), (v_split, h_split))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "alive", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_gridselection.py b/tests/test_gridselection.py new file mode 100644 index 000000000..24f98bdb4 --- /dev/null +++ b/tests/test_gridselection.py @@ -0,0 +1,479 @@ +import warnings +from typing import Final +import numpy as np +import pytest +from numpy.typing import NDArray +from eitprocessing.roi_selection.gridselection import GridSelection +from eitprocessing.roi_selection.gridselection import InvalidDivision +from eitprocessing.roi_selection.gridselection import InvalidHorizontalDivision +from eitprocessing.roi_selection.gridselection import InvalidVerticalDivision +from eitprocessing.roi_selection.gridselection import MoreHorizontalGroupsThanColumns +from eitprocessing.roi_selection.gridselection import MoreVerticalGroupsThanRows +from eitprocessing.roi_selection.gridselection import UnevenHorizontalDivision +from eitprocessing.roi_selection.gridselection import UnevenVerticalDivision + + +N: Final = np.nan # shorthand 
# ... for readability: `N` stands for NaN in the expected-result tables below.


def matrix_from_string(string: str) -> NDArray:
    """
    Generate a matrix from a string containing a representation of that
    matrix.

    A representation of a matrix contains one character per cell that describes
    the value of that cell. Rows are delimited by commas. Matrices are
    delimited by semi-colons (see `matrices_from_string()`).

    The following characters are transformed to these corresponding values:
    - T / 1 -> 1
    - F / 0 -> 0
    - R -> a random integer from 2 up to (not including) 100
    - N / any other character -> np.nan

    Examples:
    >>> matrix_from_string("1T1,FNR")
    array([[ 1.,  1.,  1.],
           [ 0., nan, 93.]])
    >>> matrix_from_string("RRR,RNR")
    array([[10., 68., 20.],
           [46., nan, 25.]])
    """

    str_matrix = np.array([tuple(row) for row in string.split(",")], dtype="object")
    # `float` instead of the abstract `np.floating`, which NumPy deprecated as a
    # dtype argument
    matrix = np.full(str_matrix.shape, np.nan, dtype=float)
    matrix[np.nonzero(str_matrix == "1")] = 1
    matrix[np.nonzero(str_matrix == "T")] = 1
    matrix[np.nonzero(str_matrix == "0")] = 0
    matrix[np.nonzero(str_matrix == "F")] = 0
    matrix = np.where(
        str_matrix == "R",
        np.random.default_rng().integers(2, 100, matrix.shape),
        matrix,
    )
    return matrix


def matrices_from_string(string: str) -> list[NDArray]:
    """
    Generate a list of matrices from a string representation containing
    multiple matrices.

    The input string should contain the representation of one or multiple
    matrices to be used in `matrix_from_string()`, delimited by `;`.

    Examples:
    >>> matrices_from_string("RR,RR;RRR;1R")
    [array([[64., 82.],
            [40., 65.]]),
     array([[56., 40., 88.]]),
     array([[ 1., 76.]])]
    """

    return [matrix_from_string(part) for part in string.split(";")]


@pytest.mark.parametrize(
    "split_vh,split_columns,split_rows,ign_nan_cols,ign_nan_rows,exception_type",
    [
        ((1, 1), False, False, True, True, None),
        ((1, 1), False, True, True, True, None),
        ((1, 1), True, False, True, True, None),
        ((1, 1), True, True, False, True, None),
        ((1, 1), True, True, True, False, None),
        # Vertical divider invalid
        ((0, 1), False, False, True, True, InvalidVerticalDivision),
        ((-1, 1), False, False, True, True, InvalidVerticalDivision),
        ((1.1, 1), False, False, True, True, TypeError),
        # Horizontal divider invalid
        ((1, 0), False, False, True, True, InvalidHorizontalDivision),
        ((1, -1), False, False, True, True, InvalidHorizontalDivision),
        ((1, 1.1), False, False, True, True, TypeError),
        # split_columns invalid
        ((2, 2), "not a boolean", False, True, True, TypeError),
        ((2, 2), 1, False, True, True, TypeError),
        ((2, 2), 0, False, True, True, TypeError),
        # split_rows invalid
        ((2, 2), False, "not a boolean", True, True, TypeError),
        ((2, 2), False, 1, True, True, TypeError),
        ((2, 2), False, 0, True, True, TypeError),
        # ignore_nan_columns invalid
        ((1, 1), False, False, "not a boolean", True, TypeError),
        ((1, 1), False, False, 1, True, TypeError),
        # ignore_nan_rows invalid
        ((1, 1), False, False, True, "not a boolean", TypeError),
        ((1, 1), False, False, True, 1, TypeError),
    ],
)
def test_initialisation(  # pylint: disable=too-many-arguments
    split_vh: tuple[int, int],
    split_columns: bool,
    split_rows: bool,
    ign_nan_cols: bool,
    ign_nan_rows: bool,
    exception_type: type[Exception] | None,
):
    """
    Test the initialisation of GridSelection and corresponding expected errors.

    Args:
        split_vh (tuple[int, int]): `v_split` and `h_split`.
        split_columns (bool): whether to allow splitting columns.
        split_rows (bool): whether to allow splitting rows.
        ign_nan_cols (bool): whether to ignore NaN columns.
        ign_nan_rows (bool): whether to ignore NaN rows.
        exception_type (type[Exception] | None): type of exception expected to
            be raised.
    """

    if exception_type is None:
        GridSelection(
            *split_vh,
            split_columns=split_columns,
            split_rows=split_rows,
            ignore_nan_columns=ign_nan_cols,
            ignore_nan_rows=ign_nan_rows,
        )

    else:
        with pytest.raises(exception_type):
            GridSelection(
                *split_vh,
                split_columns=split_columns,
                split_rows=split_rows,
                ignore_nan_columns=ign_nan_cols,
                ignore_nan_rows=ign_nan_rows,
            )


@pytest.mark.parametrize(
    "data_string,split_vh,split_rows,split_columns,warning_type",
    [
        ("RR,RR", (2, 2), False, False, None),
        ("RRR,RRR", (2, 2), False, False, UnevenHorizontalDivision),
        ("RRR,RRR", (1, 3), False, False, None),
        ("RRRR,RRRR", (1, 3), False, False, UnevenHorizontalDivision),
        ("RR,RR,RR", (2, 1), False, False, UnevenVerticalDivision),
        ("RR,RR,RR", (3, 1), False, False, None),
        ("NN,RR,RR", (2, 1), False, False, None),
        ("R", (2, 1), True, True, MoreVerticalGroupsThanRows),
        ("R", (1, 2), True, True, MoreHorizontalGroupsThanColumns),
        ("RRR,RRR,RRR", (4, 3), True, True, MoreVerticalGroupsThanRows),
    ],
)
def test_warnings(
    data_string: str,
    split_vh: tuple[int, int],
    split_rows: bool,
    split_columns: bool,
    warning_type: type[Warning] | None,
):
    """
    Test for warnings generated when `find_grid()` is called.

    Args:
        data_string (str): represents the input data, to be converted using
            `matrix_from_string()`.
        split_vh (tuple[int, int]): `v_split` and `h_split`.
        split_rows (bool): whether to allow splitting rows.
        split_columns (bool): whether to allow splitting columns.
        warning_type (type[Warning] | None): type of warning to be expected.
    """

    data = matrix_from_string(data_string)
    gs = GridSelection(*split_vh, split_rows=split_rows, split_columns=split_columns)

    if warning_type is None:
        # turn every warning into an error, so an unexpected warning fails the test
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            gs.find_grid(data)
    else:
        with pytest.warns(warning_type):
            gs.find_grid(data)


@pytest.mark.parametrize(
    "data_string,split_vh,split_rows,split_columns,exception_type",
    [
        ("RR,RR", (2, 2), False, False, None),
        ("RR,RR", (3, 1), False, False, InvalidVerticalDivision),
        ("RR,RR", (1, 3), False, False, InvalidHorizontalDivision),
        ("RR,RR", (3, 1), False, False, InvalidDivision),
        ("RR,RR", (1, 3), False, False, InvalidDivision),
        ("RR,RR", (3, 2), True, False, None),
        ("RR,RR", (3, 2), False, False, InvalidVerticalDivision),
        ("RR,RR", (2, 3), False, True, None),
        ("RR,RR", (2, 3), False, False, InvalidHorizontalDivision),
    ],
)
def test_exceptions(
    data_string: str,
    split_vh: tuple[int, int],
    split_rows: bool,
    split_columns: bool,
    exception_type: type[Exception] | None,
):
    """
    Test for exceptions raised when `find_grid()` is called.

    Args:
        data_string (str): represents the input data, to be converted using
            `matrix_from_string()`.
        split_vh (tuple[int, int]): `v_split` and `h_split`.
        split_rows (bool): whether to allow splitting rows.
        split_columns (bool): whether to allow splitting columns.
        exception_type (type[Exception] | None): type of exception expected to
            be raised.
    """

    data = matrix_from_string(data_string)
    gs = GridSelection(*split_vh, split_columns=split_columns, split_rows=split_rows)

    if exception_type is None:
        gs.find_grid(data)

    else:
        with pytest.raises(exception_type):
            gs.find_grid(data)


@pytest.mark.parametrize(
    "shape,split_vh,result_string",
    [
        ((2, 1), (2, 1), "T,F;F,T"),
        ((3, 1), (3, 1), "T,F,F;F,T,F;F,F,T"),
        ((1, 3), (1, 3), "TFF;FTF;FFT"),
        ((1, 3), (1, 2), "TTF;FFT"),
        (
            (4, 4),
            (2, 2),
            "TTFF,TTFF,FFFF,FFFF;"
            "FFTT,FFTT,FFFF,FFFF;"
            "FFFF,FFFF,TTFF,TTFF;"
            "FFFF,FFFF,FFTT,FFTT",
        ),
        (
            (5, 5),
            (2, 2),
            "TTTFF,TTTFF,TTTFF,FFFFF,FFFFF;"
            "FFFTT,FFFTT,FFFTT,FFFFF,FFFFF;"
            "FFFFF,FFFFF,FFFFF,TTTFF,TTTFF;"
            "FFFFF,FFFFF,FFFFF,FFFTT,FFFTT",
        ),
        ((5, 2), (1, 2), "TF,TF,TF,TF,TF;FT,FT,FT,FT,FT"),
        ((1, 9), (1, 6), "TTFFFFFFF;FFTFFFFFF;FFFTTFFFF;FFFFFTFFF;FFFFFFTTF;FFFFFFFFT"),
        ((2, 2), (1, 1), "TT,TT"),
    ],
)
def test_no_split_pixels_no_nans(
    shape: tuple[int, int], split_vh: tuple[int, int], result_string: str
):
    """
    Test `find_grid()` without split rows/columns and no NaN values.

    Args:
        shape (tuple[int, int]): shape of the input data to be generated.
        split_vh (tuple[int, int]): `v_split` and `h_split`.
        result_string (str): represents the expected result, to be converted
            using `matrices_from_string()`.
    """

    data = np.random.default_rng().integers(1, 100, shape)
    result_matrices = matrices_from_string(result_string)

    gs = GridSelection(*split_vh)
    matrices = gs.find_grid(data)

    # every cell should be represented exactly once over all regions
    num_appearances = np.sum(np.stack(matrices, axis=-1), axis=-1)

    assert len(matrices) == np.prod(split_vh)
    assert np.array_equal(num_appearances, (~np.isnan(data) * 1))
    assert np.array_equal(matrices, result_matrices)


@pytest.mark.parametrize(
    "data_string,split_vh,result_string",
    [
        ("NNN,NRR,NRR", (1, 1), "NNN,NTT,NTT"),
        (
            "NNN,RRR,RRR,RRR,RRR",
            (2, 2),
            "NNN,TTF,TTF,FFF,FFF;"
            "NNN,FFT,FFT,FFF,FFF;"
            "NNN,FFF,FFF,TTF,TTF;"
            "NNN,FFF,FFF,FFT,FFT",
        ),
        (
            "NNNNNN,NNNNNN,NRRRRR,RNRRRR,NNNNNN",
            (2, 2),
            "NNNNNN,NNNNNN,NTTFFF,FNFFFF,NNNNNN;"
            "NNNNNN,NNNNNN,NFFTTT,FNFFFF,NNNNNN;"
            "NNNNNN,NNNNNN,NFFFFF,TNTFFF,NNNNNN;"
            "NNNNNN,NNNNNN,NFFFFF,FNFTTT,NNNNNN",
        ),
    ],
)
def test_no_split_pixels_nans(
    data_string: str, split_vh: tuple[int, int], result_string: str
):
    """
    Test `find_grid()` without row/column splitting, with NaN values.

    Args:
        data_string (str): represents the input data, to be converted using
            `matrix_from_string()`.
        split_vh (tuple[int, int]): `v_split` and `h_split`.
        result_string (str): represents the expected result, to be converted
            using `matrices_from_string()`.
    """

    data = matrix_from_string(data_string)
    numeric_values = np.ones(data.shape)
    numeric_values[np.isnan(data)] = np.nan
    result = matrices_from_string(result_string)

    v_split, h_split = split_vh
    gs = GridSelection(v_split, h_split, split_rows=False, split_columns=False)

    matrices = gs.find_grid(data)
    # every numeric cell should be represented exactly once over all regions
    num_appearances = np.sum(np.stack(matrices, axis=-1), axis=-1)

    assert len(matrices) == h_split * v_split
    assert np.array_equal(num_appearances, numeric_values, equal_nan=True)
    assert np.array_equal(matrices, result, equal_nan=True)


@pytest.mark.parametrize(
    "shape,split_vh,result",
    [
        (
            (2, 3),
            (1, 2),
            [[[1.0, 0.5, 0], [1.0, 0.5, 0]], [[0, 0.5, 1.0], [0, 0.5, 1.0]]],
        ),
        (
            (3, 2),
            (2, 1),
            [[[1.0, 1.0], [0.5, 0.5], [0, 0]], [[0, 0], [0.5, 0.5], [1, 1]]],
        ),
        (
            (1, 4),
            (1, 3),
            [
                [[1.0, 1 / 3, 0.0, 0.0]],
                [[0.0, 2 / 3, 2 / 3, 0.0]],
                [[0.0, 0.0, 1 / 3, 1.0]],
            ],
        ),
        (
            (3, 3),
            (2, 2),
            [
                [[1, 0.5, 0], [0.5, 0.25, 0], [0, 0, 0]],
                [[0, 0.5, 1], [0, 0.25, 0.5], [0, 0, 0]],
                [[0, 0, 0], [0.5, 0.25, 0], [1, 0.5, 0]],
                [[0, 0, 0], [0, 0.25, 0.5], [0, 0.5, 1]],
            ],
        ),
    ],
)
def test_split_pixels_no_nans(
    shape: tuple[int, int], split_vh: tuple[int, int], result: list[list[list[float]]]
):
    """
    Test `find_grid()` with split rows/columns and no NaN values.

    Args:
        shape (tuple[int, int]): shape of the input data to be generated.
        split_vh (tuple[int, int]): `v_split` and `h_split`.
        result (list[list[list[float]]]): list of lists to be converted to
            matrices, representing the expected result.
    """
    data = np.random.default_rng().integers(1, 100, shape)
    expected_result = [np.array(r) for r in result]

    gs = GridSelection(*split_vh, split_rows=True, split_columns=True)
    actual_result = gs.find_grid(data)

    # every cell should be represented exactly once over all regions
    num_appearances = np.sum(np.stack(actual_result, axis=-1), axis=-1)

    assert len(actual_result) == np.prod(split_vh)
    assert np.array_equal(num_appearances, (~np.isnan(data) * 1))

    # Ideally, we'd use np.array_equal() here, but due to floating point arithmetic, the values
    # are off by an insignificant amount.
    assert np.allclose(actual_result, expected_result)


@pytest.mark.parametrize(
    "data_string,split_vh,result",
    (
        (
            "NRRR,NRRR",
            (2, 2),
            [
                [[N, 1, 0.5, 0], [N, 0, 0, 0]],
                [[N, 0, 0.5, 1], [N, 0, 0, 0]],
                [[N, 0, 0, 0], [N, 1, 0.5, 0]],
                [[N, 0, 0, 0], [N, 0, 0.5, 1]],
            ],
        ),
        (
            "RNRR,RNRR,RNRR,NNNN",
            (2, 2),
            [
                [[1, N, 0, 0], [0.5, N, 0, 0], [0, N, 0, 0], [N, N, N, N]],
                [[0, N, 1, 1], [0, N, 0.5, 0.5], [0, N, 0, 0], [N, N, N, N]],
                [[0, N, 0, 0], [0.5, N, 0, 0], [1, N, 0, 0], [N, N, N, N]],
                [[0, N, 0, 0], [0, N, 0.5, 0.5], [0, N, 1, 1], [N, N, N, N]],
            ],
        ),
    ),
)
def test_split_pixels_nans(
    data_string: str, split_vh: tuple[int, int], result: list[list[list[float]]]
):
    """
    Test `find_grid()` with row/column splitting, with NaN values.

    Args:
        data_string (str): represents the input data, to be converted using
            `matrix_from_string()`.
        split_vh (tuple[int, int]): `v_split` and `h_split`.
        result (list[list[list[float]]]): list of list representation of
            matrices, representing the expected result.
    """

    data = matrix_from_string(data_string)
    expected_result = [np.array(r) for r in result]
    numeric_values = np.ones(data.shape)
    numeric_values[np.isnan(data)] = np.nan

    gs = GridSelection(*split_vh, split_rows=True, split_columns=True)
    actual_result = gs.find_grid(data)

    # every numeric cell should be represented exactly once over all regions
    num_appearances = np.sum(np.stack(actual_result, axis=-1), axis=-1)

    assert len(actual_result) == np.prod(split_vh)
    assert len(actual_result) == len(expected_result)
    assert np.array_equal(num_appearances, numeric_values, equal_nan=True)

    # compare the actual values too; allclose because of floating point arithmetic
    assert np.allclose(actual_result, expected_result, equal_nan=True)


@pytest.mark.parametrize(
    "split_vh,result",
    [
        ((1, 1), [[0]]),
        ((1, 2), [[0, 1]]),
        ((2, 1), [[0], [1]]),
        ((2, 2), [[0, 1], [2, 3]]),
        ((3, 4), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
    ],
)
def test_matrix_layout(split_vh: tuple[int, int], result: list[list[int]]):
    """
    Test `matrix_layout()` method.

    Args:
        split_vh (tuple[int, int]): `v_split` and `h_split`.
        result (list[list[int]]): list representation of a matrix, representing
            the expected result.
    """

    gs = GridSelection(*split_vh)
    layout = gs.matrix_layout()

    assert np.array_equal(layout, np.array(result))