From aac8322fe8bb626d60ac00212986e709ba72fd0e Mon Sep 17 00:00:00 2001 From: faberno Date: Thu, 26 Dec 2024 01:40:36 +0100 Subject: [PATCH] refactor get_array_binary_mask and combine_sensor_data --- kwave/utils/kwave_array.py | 375 +++++++++++++++++++------------------ 1 file changed, 193 insertions(+), 182 deletions(-) diff --git a/kwave/utils/kwave_array.py b/kwave/utils/kwave_array.py index 3d666dbb1..11608db34 100644 --- a/kwave/utils/kwave_array.py +++ b/kwave/utils/kwave_array.py @@ -1,8 +1,9 @@ import logging import time +from collections import Counter from dataclasses import dataclass from math import ceil -from typing import Optional +from typing import Optional, Sequence import numpy as np from numpy import arcsin, pi, cos, size, array @@ -400,39 +401,140 @@ def add_line_element(self, start_point, end_point): ) ) - def get_element_grid_weights(self, kgrid, element_num): - return self.get_off_grid_points(kgrid, element_num, False) + def get_element_grid_weights(self, kgrid, element_num, sparse=False): + return self.get_off_grid_points(kgrid, [element_num], False, sparse=sparse) - def get_element_binary_mask(self, kgrid, element_num): - return self.get_off_grid_points(kgrid, element_num, True) + def get_element_binary_mask(self, kgrid, element_num, sparse=False): + return self.get_off_grid_points(kgrid, [element_num], True, sparse=sparse) def get_array_grid_weights(self, kgrid): self.check_for_elements() - if kgrid.Nz >= 1: - grid_weights = np.zeros((kgrid.Nx, kgrid.Ny, kgrid.Nz)) - else: - grid_weights = np.zeros((kgrid.Nx, kgrid.Ny)) - - assert len(grid_weights.shape) == self.dim, "Grid weights shape must match kArray dimensions." - - for ind in range(self.number_elements): - grid_weights += self.get_off_grid_points(kgrid, ind, False) - + grid_weights = self.get_off_grid_points(kgrid, element_indices=range(self.number_elements), mask_only=False) assert len(grid_weights.shape) == self.dim, "Grid weights shape must match kArray dimensions." - return grid_weights def get_array_binary_mask(self, kgrid): self.check_for_elements() - mask = np.squeeze(np.zeros((kgrid.Nx, kgrid.Ny, max(kgrid.Nz, 1)), dtype=bool)) + mask = self.get_off_grid_points(kgrid, element_indices=range(self.number_elements), mask_only=True) + return mask - for ind in range(self.number_elements): - grid_weights = self.get_off_grid_points(kgrid, ind, True) - mask = np.bitwise_or(np.squeeze(mask), grid_weights) + def get_integration_points(self, kgrid: kWaveGrid, element_indices: Sequence): + # check the array has elements + self.check_for_elements() - return mask + # check inputs + assert isinstance(kgrid, kWaveGrid) + assert len(element_indices) <= len(self.elements) + assert all([idx < len(self.elements) for idx in element_indices]) + + all_integration_points = [] + all_scales = [] + + for element_idx in element_indices: + element = self.elements[element_idx] + + # compute measure (length/area/volume) in grid squares (assuming dx = dy = dz) + m_grid = element.measure / (kgrid.dx) ** (element.dim) + + # get number of integration points + if element.type == "custom": + # assign number of integration points directly + m_integration = element.integration_points.shape[1] + else: + # compute the number of integration points using the upsampling rate + m_integration = ceil(m_grid * self.upsampling_rate) + + # compute integration points covering element + if element.type == "annulus": + # compute points using make_cart_spherical_segment + integration_points = make_cart_spherical_segment( + self.affine(element.position), + element.radius_of_curvature, + element.inner_diameter, + element.outer_diameter, + self.affine(element.focus_position), + m_integration, + ) + + elif element.type == "arc": + # compute points using make_cart_arc + integration_points = make_cart_arc( + self.affine(element.position), + element.radius_of_curvature, + element.diameter, + self.affine(element.focus_position), + m_integration, + ) + + elif element.type == "bowl": + # compute points using make_cart_bowl + integration_points = make_cart_bowl( + self.affine(element.position), + element.radius_of_curvature, + element.diameter, + self.affine(element.focus_position), + m_integration, + ) + + elif element.type == "custom": + # directly assign integration points + integration_points = element.integration_points + + elif element.type == "disc": + # compute points using make_cart_disc + integration_points = make_cart_disc( + self.affine(element.position), + element.diameter / 2, + self.affine(element.focus_position), + m_integration, + False, + self.use_spiral_disc_points, + ) + + elif element.type == "rect": + # compute points using make_cart_rect + integration_points = make_cart_rect( + self.affine(element.position), + element.length, + element.width, + element.orientation, + m_integration, + ) + + elif element.type == "line": + # get distance between points in each dimension + d = (element.end_point - element.start_point) / m_integration + + # compute a set of uniformly spaced Cartesian points + # covering the line using linspace, where the end + # points are offset by half the point spacing + if self.dim == 1: + integration_points = np.linspace(element.start_point + d[0] / 2, element.end_point - d[0] / 2, m_integration) + elif self.dim == 2: + px = np.linspace(element.start_point[0] + d[0] / 2, element.end_point[0] - d[0] / 2, m_integration) + py = np.linspace(element.start_point[1] + d[1] / 2, element.end_point[1] - d[1] / 2, m_integration) + integration_points = np.array([px, py]) + elif self.dim == 3: + px = np.linspace(element.start_point[0] + d[0] / 2, element.end_point[0] - d[0] / 2, m_integration) + py = np.linspace(element.start_point[1] + d[1] / 2, element.end_point[1] - d[1] / 2, m_integration) + pz = np.linspace(element.start_point[2] + d[2] / 2, element.end_point[2] - d[2] / 2, m_integration) + integration_points = np.array([px, py, pz]) + + else: + raise ValueError(f"{element.type} is not a valid array element type.") + + # recompute actual number of points + m_integration = integration_points.shape[1] + + # compute scaling factor + scale = m_grid / m_integration + + all_integration_points.append(integration_points) + all_scales.append(np.ones(m_integration) * scale) + + return np.hstack(all_integration_points), np.hstack(all_scales) def check_for_elements(self): if self.number_elements == 0: @@ -451,158 +553,43 @@ def affine(self, vec): vec = np.matmul(self.array_transformation, vec) return vec[:-1] - def get_off_grid_points(self, kgrid, element_num, mask_only): + def get_off_grid_points(self, kgrid: kWaveGrid, element_indices: Sequence, mask_only: bool, sparse: bool = False): # check the array has elements self.check_for_elements() # check inputs assert isinstance(kgrid, kWaveGrid) - assert 0 <= element_num <= self.number_elements - 1 - - # compute measure (length/area/volume) in grid squares (assuming dx = dy = dz) - m_grid = self.elements[element_num].measure / (kgrid.dx) ** (self.elements[element_num].dim) + assert len(element_indices) <= len(self.elements) + assert all([idx < len(self.elements) for idx in element_indices]) - # get number of integration points - if self.elements[element_num].type == "custom": - # assign number of integration points directly - m_integration = self.elements[element_num].integration_points.shape[1] - else: - # compute the number of integration points using the upsampling rate - m_integration = ceil(m_grid * self.upsampling_rate) - - # compute integration points covering element - if self.elements[element_num].type == "annulus": - # compute points using make_cart_spherical_segment - integration_points = make_cart_spherical_segment( - self.affine(self.elements[element_num].position), - self.elements[element_num].radius_of_curvature, - self.elements[element_num].inner_diameter, - self.elements[element_num].outer_diameter, - self.affine(self.elements[element_num].focus_position), - m_integration, - ) - - elif self.elements[element_num].type == "arc": - # compute points using make_cart_arc - integration_points = make_cart_arc( - self.affine(self.elements[element_num].position), - self.elements[element_num].radius_of_curvature, - self.elements[element_num].diameter, - self.affine(self.elements[element_num].focus_position), - m_integration, - ) - - elif self.elements[element_num].type == "bowl": - # compute points using make_cart_bowl - integration_points = make_cart_bowl( - self.affine(self.elements[element_num].position), - self.elements[element_num].radius_of_curvature, - self.elements[element_num].diameter, - self.affine(self.elements[element_num].focus_position), - m_integration, - ) - - elif self.elements[element_num].type == "custom": - # directly assign integration points - integration_points = self.elements[element_num].integration_points - - elif self.elements[element_num].type == "disc": - # compute points using make_cart_disc - integration_points = make_cart_disc( - self.affine(self.elements[element_num].position), - self.elements[element_num].diameter / 2, - self.affine(self.elements[element_num].focus_position), - m_integration, - False, - self.use_spiral_disc_points, - ) - - elif self.elements[element_num].type == "rect": - # compute points using make_cart_rect - integration_points = make_cart_rect( - self.affine(self.elements[element_num].position), - self.elements[element_num].length, - self.elements[element_num].width, - self.elements[element_num].orientation, - m_integration, - ) - - elif self.elements[element_num].type == "line": - # get distance between points in each dimension - d = (self.elements[element_num].end_point - self.elements[element_num].start_point) / m_integration - - # compute a set of uniformly spaced Cartesian points - # covering the line using linspace, where the end - # points are offset by half the point spacing - if self.dim == 1: - integration_points = np.linspace( - self.elements[element_num].start_point + d[0] / 2, self.elements[element_num].end_point - d[0] / 2, m_integration - ) - elif self.dim == 2: - px = np.linspace( - self.elements[element_num].start_point[0] + d[0] / 2, self.elements[element_num].end_point[0] - d[0] / 2, m_integration - ) - py = np.linspace( - self.elements[element_num].start_point[1] + d[1] / 2, self.elements[element_num].end_point[1] - d[1] / 2, m_integration - ) - integration_points = np.array([px, py]) - elif self.dim == 3: - px = np.linspace( - self.elements[element_num].start_point[0] + d[0] / 2, self.elements[element_num].end_point[0] - d[0] / 2, m_integration - ) - py = np.linspace( - self.elements[element_num].start_point[1] + d[1] / 2, self.elements[element_num].end_point[1] - d[1] / 2, m_integration - ) - pz = np.linspace( - self.elements[element_num].start_point[2] + d[2] / 2, self.elements[element_num].end_point[2] - d[2] / 2, m_integration - ) - integration_points = np.array([px, py, pz]) - - else: - raise ValueError(f"{self.elements[element_num].type} is not a valid array element type.") - - # recompute actual number of points - m_integration = integration_points.shape[1] - - # compute scaling factor - scale = m_grid / m_integration + integration_points, scales = self.get_integration_points(kgrid, element_indices) if self.axisymmetric: # create new expanded grid - kgrid_expanded = kWaveGrid(Vector([kgrid.Nx, 2 * kgrid.Ny]), Vector([kgrid.dx, kgrid.dy])) - - # remove integration points which are outside grid - integration_points = trim_cart_points(kgrid_expanded, integration_points) - - # calculate grid weights from BLIs centered on the integration points - grid_weights = off_grid_points( - kgrid_expanded, - integration_points, - scale, - bli_tolerance=self.bli_tolerance, - bli_type=self.bli_type, - mask_only=mask_only, - single_precision=self.single_precision, - ) + kgrid = kWaveGrid(Vector([kgrid.Nx, 2 * kgrid.Ny]), Vector([kgrid.dx, kgrid.dy])) + + # remove integration points which are outside grid + integration_points = trim_cart_points(kgrid, integration_points) + + # calculate grid weights from BLIs centered on the integration points + grid_weights = off_grid_points( + kgrid, + integration_points, + scales, + bli_tolerance=self.bli_tolerance, + bli_type=self.bli_type, + mask_only=mask_only, + single_precision=self.single_precision, + sparse=sparse, + ) + + if self.axisymmetric: + if sparse: + raise NotImplementedError("Sparse output not yet supported for axisymmetric arrays.") # TODO: sparse + axisymmetric # keep points in the positive y domain grid_weights = grid_weights[:, kgrid.Ny :] - else: - # remove integration points which are outside grid - integration_points = trim_cart_points(kgrid, integration_points) - - # calculate grid weights from BLIs centered on the integration points - grid_weights = off_grid_points( - kgrid, - integration_points, - scale, - bli_tolerance=self.bli_tolerance, - bli_type=self.bli_type, - mask_only=mask_only, - single_precision=self.single_precision, - ) - return grid_weights def get_distributed_source_signal(self, kgrid, source_signal): @@ -650,10 +637,15 @@ def get_distributed_source_signal(self, kgrid, source_signal): return distributed_source_signal - def combine_sensor_data(self, kgrid, sensor_data): + def combine_sensor_data(self, kgrid, sensor_data, mask: Optional[np.ndarray] = None): self.check_for_elements() - mask = self.get_array_binary_mask(kgrid) + if mask is None: + mask = self.get_array_binary_mask(kgrid) + + assert mask.dtype == bool + assert np.all(mask.shape == kgrid.N) + mask_ind = matlab_find(mask).squeeze(axis=-1) Nt = np.shape(sensor_data)[1] @@ -663,19 +655,14 @@ def combine_sensor_data(self, kgrid, sensor_data): combined_sensor_data = np.zeros((self.number_elements, Nt)) for element_num in range(self.number_elements): - source_weights = self.get_element_grid_weights(kgrid, element_num) - - element_mask_ind = matlab_find(np.array(source_weights), val=0, mode="neq").squeeze(axis=-1) - - local_ind = np.isin(mask_ind, element_mask_ind) + source_weights = self.get_element_grid_weights(kgrid, element_num, sparse=True) - combined_sensor_data[element_num, :] = np.sum( - sensor_data[local_ind] * matlab_mask(source_weights, element_mask_ind - 1), axis=0 - ) + local_ind = np.isin(mask_ind, list(source_weights.keys())) + weights = [source_weights[i] for i in mask_ind[local_ind]] m_grid = self.elements[element_num].measure / (kgrid.dx) ** (self.elements[element_num].dim) - combined_sensor_data[element_num, :] = combined_sensor_data[element_num, :] / m_grid + combined_sensor_data[element_num, :] = np.sum(sensor_data[local_ind] * np.asarray(weights)[:, None], axis=0) / m_grid return combined_sensor_data @@ -708,7 +695,16 @@ def get_element_positions(self): def off_grid_points( - kgrid, points, scale=1, bli_tolerance=0.1, bli_type="sinc", mask_only=False, single_precision=False, debug=False, display_wait_bar=False + kgrid, + points, + scale=1, + bli_tolerance=0.1, + bli_type="sinc", + mask_only=False, + single_precision=False, + sparse=False, + debug=False, + display_wait_bar=False, ): wait_bar_update_freq = 100 @@ -772,12 +768,21 @@ def off_grid_points( else: mask_type = np.float64 - if kgrid.dim == 1: - mask = np.zeros((kgrid.Nx, 1), dtype=mask_type) - elif kgrid.dim == 2: - mask = np.zeros((kgrid.Nx, kgrid.Ny), dtype=mask_type) - elif kgrid.dim == 3: - mask = np.zeros((kgrid.Nx, kgrid.Ny, kgrid.Nz), dtype=mask_type) + if bli_tolerance == 0 and sparse: + raise ValueError("Sparse off-grid calculation is not available when bli_tolerance=0") + + if not sparse: + if kgrid.dim == 1: + mask = np.zeros((kgrid.Nx, 1), dtype=mask_type) + elif kgrid.dim == 2: + mask = np.zeros((kgrid.Nx, kgrid.Ny), dtype=mask_type) + elif kgrid.dim == 3: + mask = np.zeros((kgrid.Nx, kgrid.Ny, kgrid.Nz), dtype=mask_type) + else: + if mask_only: + mask = set() + else: + mask = Counter({}) # display wait bar if display_wait_bar: @@ -850,7 +855,11 @@ def off_grid_points( if mask_only: # add current points to the mask - mask = matlab_assign(mask, ind - 1, True) + if not sparse: + mask = matlab_assign(mask, ind - 1, True) + else: + mask.update(ind) + else: # evaluate a BLI centered on point at grid nodes XYZ if scalar_dxyz: @@ -869,10 +878,12 @@ def off_grid_points( if kgrid.nonuniform: current_mask_t = mask_t * BLIscale - updated_mask_value = matlab_mask(mask, ind - 1).squeeze(axis=-1) + scale[point_ind] * current_mask_t # add this contribution to the overall source mask - mask = matlab_assign(mask, ind - 1, updated_mask_value) - + if not sparse: + updated_mask_value = matlab_mask(mask, ind - 1).squeeze(axis=-1) + scale[point_ind] * current_mask_t + mask = matlab_assign(mask, ind - 1, updated_mask_value) + else: + mask.update(dict(zip(ind, scale[point_ind] * current_mask_t))) # update the waitbar if display_wait_bar and (point_ind % wait_bar_update_freq == 0): tqdm.update(wait_bar_update_freq)