From 5592a526df03c4935ada8d70fe34aad059583088 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20von=20L=C3=BChmann?= Date: Thu, 15 May 2025 14:38:56 +0200 Subject: [PATCH 1/5] one version for chunked xarray matrix multiplication --- src/cedalion/imagereco/forward_model.py | 13 +- src/cedalion/xrutils.py | 183 ++++++++++++++++++++++++ 2 files changed, 195 insertions(+), 1 deletion(-) diff --git a/src/cedalion/imagereco/forward_model.py b/src/cedalion/imagereco/forward_model.py index 80fc952d..c1f80b4c 100644 --- a/src/cedalion/imagereco/forward_model.py +++ b/src/cedalion/imagereco/forward_model.py @@ -1142,7 +1142,18 @@ def apply_inv_sensitivity( od_stacked = od.stack({"flat_channel": ["wavelength", "channel"]}) od_stacked = od_stacked.pint.dequantify() - delta_conc = inv_sens @ od_stacked + + # if od_stacked has more than 1000 time points, chunk it + if od_stacked.sizes["time"] > 1000: + delta_conc = xrutils.chunked_eff_xr_matmult( + od_stacked, + inv_sens, + contract_dim="flat_channel", + sample_dim = "time", + chunksize=1000) + else: + delta_conc = inv_sens @ od_stacked + print("nothing to chunk") # Construct a multiindex for dimension flat_vertex from chromo and vertex. # Afterwards use this multiindex to unstack flat_vertex. 
The resulting array diff --git a/src/cedalion/xrutils.py b/src/cedalion/xrutils.py index 8b06183b..1b573619 100644 --- a/src/cedalion/xrutils.py +++ b/src/cedalion/xrutils.py @@ -5,6 +5,9 @@ import numpy as np import pint import xarray as xr +import os +import tempfile +import shutil def pinv(array: xr.DataArray) -> xr.DataArray: @@ -231,3 +234,183 @@ def unit_stripping_is_error(is_error : bool = True): if f[0] =="error" and f[2] == pint.errors.UnitStrippedWarning: del warnings.filters[i] break + + +def chunked_eff_xr_matmult( + A: xr.DataArray, + B: xr.DataArray, + contract_dim: str, + sample_dim: str, + chunksize: int = 5000, + tmpdir: str | None = None +) -> xr.DataArray: + """Performs a large matrix multiplication of A and B, chunking A along `sample_dim` to avoid memory issues, streams each chunk to disk, and then rebuilds a full DataArray. + + Args: + A: DataArray to multiply (dims include `contract_dim` and `sample_dim` among others) + B: DataArray defining the mat-mul (dims include `contract_dim` and others) + contract_dim: name of the axis to contract (e.g. "flat_channel") + sample_dim: name of the axis along which to chunk (default "time") + chunksize: max size of each chunk along `sample_dim` + tmpdir: optional path to temp directory (auto‐created and removed if None) + + Returns: + A new DataArray of shape B_other_dims + A_other_dims, containing the result + of the matmul over `contract_dim`, with coords, dims, and attrs preserved. 
+ """ + # Total samples & number of chunks + N = A.sizes[sample_dim] + n_chunks = int(np.ceil(N / chunksize)) + + # 1) Build a “shell” result for metadata by doing the dot on the first sample + A0 = A.isel({sample_dim: slice(0, 1)}) # keep sample_dim as length 1 + Xres = xr.dot(B, A0, dims=[contract_dim]) + + # 2) Prepare for raw numpy multiply + dims_B_not = [d for d in B.dims if d != contract_dim] + dims_A_not = [d for d in A.dims if d != contract_dim] + B_mat = B.transpose(*dims_B_not, contract_dim).values + A2 = A.transpose(contract_dim, *dims_A_not) + + # 3) Temp directory + cleanup = False + if tmpdir is None: + tmpdir = tempfile.mkdtemp() + cleanup = True + else: + os.makedirs(tmpdir, exist_ok=True) + + print(f"Large Matrix Multiplication: Processing {n_chunks} chunks...") + + # 4) Stream‐compute each chunk + file_paths = [] + for i in range(n_chunks): + start = i * chunksize + stop = min((i + 1) * chunksize, N) + A_chunk = A2.isel({sample_dim: slice(start, stop)}) + C_chunk = B_mat.dot(A_chunk.values) # raw (out_dim, chunk_len, ...) 
+ fn = os.path.join(tmpdir, f"chunk_{i:04d}.npy") + np.save(fn, C_chunk) + file_paths.append(fn) + del A_chunk, C_chunk + print(f"Chunk {i+1}/{n_chunks} done.") + + # 5) Read back & concatenate along the sample axis + arrs = [np.load(fp) for fp in sorted(file_paths)] + axis = Xres.get_axis_num(sample_dim) + full_arr = np.concatenate(arrs, axis=axis) + + if cleanup: + shutil.rmtree(tmpdir) + + # create new coordinates + coords = { + name: coord + for name, coord in Xres.coords.items() + if sample_dim not in coord.dims + } + sample_coords = { + name: coord + for name, coord in A.coords.items() + if sample_dim in coord.dims + } + coords.update(sample_coords) + + + # 8) rebuild the DataArray using shell’s metadata + result = xr.DataArray( + data = full_arr, + dims = Xres.dims, + coords = coords, + attrs = Xres.attrs + ) + # add time coords + result.assign_coords() + return result + + """ # Determine how many samples and chunks + N = A.sizes[sample_dim] + n_chunks = int(np.ceil(N / chunksize)) + + # Prepare output dims, coords, attrs + dims_A_not = [d for d in A.dims if d != contract_dim] + dims_B_not = [d for d in B.dims if d != contract_dim] + out_dims = dims_B_not + dims_A_not + + # collect coords from B then A (excluding contract_dim) + coords = { + d: B.coords[d] + for d in B.coords if d != contract_dim + } + coords.update({ + d: A.coords[d] + for d in A.coords if d != contract_dim and d not in coords + }) + + attrs = B.attrs.copy() + + # Set up the temporary directory + cleanup = False + if tmpdir is None: + tmpdir = tempfile.mkdtemp() + cleanup = True + else: + os.makedirs(tmpdir, exist_ok=True) + + file_paths = [] + + # Prepare raw numpy matrices + # Ensure B: (out_dim, contract_dim) + dims_B = dims_B_not + [contract_dim] + B2 = B.transpose(*dims_B) + B_mat = B2.values + + # Ensure A: (contract_dim, sample_dim) up front + dims_A = [contract_dim] + dims_A_not + A2 = A.transpose(*dims_A) + + # Loop over chunks + print(f"Large Matrix Multiplication: Processing 
{n_chunks} chunks...") + for i in range(n_chunks): + start = i * chunksize + stop = min((i + 1) * chunksize, N) + + # slice A2 and get raw array + A_chunk = A2.isel({sample_dim: slice(start, stop)}) + A_mat = A_chunk.values + + # raw numpy mat-mul → (out_dim, chunk_len) + C_chunk = B_mat.dot(A_mat) + + # write to disk + fn = os.path.join(tmpdir, f"chunk_{i:04d}.npy") + np.save(fn, C_chunk) + file_paths.append(fn) + + # free memory + del A_chunk, A_mat, C_chunk + + print(f"Chunk {i+1}/{n_chunks} done.") + + # Read back & concatenate along sample axis (axis=1) + chunk_arrays = [np.load(fp) for fp in sorted(file_paths)] + full_arr = np.concatenate(chunk_arrays, axis=1) + # throw away any coords that reference the contracted dim + out_dims = dims_B_not + dims_A_not # e.g. ["flat_vertex","time"] + coords = { + k: v for (k, v) in coords.items() + if set(v.dims).issubset(out_dims) + } + + # cleanup if needed + if cleanup: + shutil.rmtree(tmpdir) + + # Wrap back into xarray.DataArray + result = xr.DataArray( + full_arr, + dims = out_dims, + coords = coords, + attrs = attrs + ) + return result """ From 16e0efd94df3ff4e814a6fc4480ecccf99e182af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20von=20L=C3=BChmann?= Date: Thu, 15 May 2025 15:53:07 +0200 Subject: [PATCH 2/5] clean up of code for pull request --- src/cedalion/imagereco/forward_model.py | 12 ++- src/cedalion/xrutils.py | 111 +++--------------------- 2 files changed, 20 insertions(+), 103 deletions(-) diff --git a/src/cedalion/imagereco/forward_model.py b/src/cedalion/imagereco/forward_model.py index c1f80b4c..5f7a5ce9 100644 --- a/src/cedalion/imagereco/forward_model.py +++ b/src/cedalion/imagereco/forward_model.py @@ -1124,13 +1124,16 @@ def compute_stacked_sensitivity(sensitivity: xr.DataArray): def apply_inv_sensitivity( - od: cdt.NDTimeSeries, inv_sens: xr.DataArray + od: cdt.NDTimeSeries, inv_sens: xr.DataArray, chunk: bool = True, ) -> tuple[xr.DataArray, xr.DataArray]: """Apply the inverted sensitivity 
matrix to optical density data. Args: od: time series of optical density data inv_sens: the inverted sensitivity matrix + chunk: optional piecewise matrix multiplication. + default True, gets active if more than 1000 time samples are to be converted. + False force-skips chunking. Returns: Two DataArrays for the brain and scalp with the reconcstructed time series per @@ -1142,18 +1145,19 @@ def apply_inv_sensitivity( od_stacked = od.stack({"flat_channel": ["wavelength", "channel"]}) od_stacked = od_stacked.pint.dequantify() + # for image recon we have time-series data either with "time" or "reltime" dimension + sample_dim = next((d for d in ["time","reltime"] if d in od_stacked.dims), None) # if od_stacked has more than 1000 time points, chunk it - if od_stacked.sizes["time"] > 1000: + if (od_stacked.sizes["time"] > 1000) and chunk: delta_conc = xrutils.chunked_eff_xr_matmult( od_stacked, inv_sens, contract_dim="flat_channel", - sample_dim = "time", + sample_dim = sample_dim, chunksize=1000) else: delta_conc = inv_sens @ od_stacked - print("nothing to chunk") # Construct a multiindex for dimension flat_vertex from chromo and vertex. # Afterwards use this multiindex to unstack flat_vertex. The resulting array diff --git a/src/cedalion/xrutils.py b/src/cedalion/xrutils.py index 1b573619..bb63467f 100644 --- a/src/cedalion/xrutils.py +++ b/src/cedalion/xrutils.py @@ -244,7 +244,7 @@ def chunked_eff_xr_matmult( chunksize: int = 5000, tmpdir: str | None = None ) -> xr.DataArray: - """Performs a large matrix multiplication of A and B, chunking A along `sample_dim` to avoid memory issues, streams each chunk to disk, and then rebuilds a full DataArray. + """Performs a large matrix multiplication of A and B, chunking A along `sample_dim`; to avoid memory issues, streams each chunk to disk, and then rebuilds a full DataArray. 
Args: A: DataArray to multiply (dims include `contract_dim` and `sample_dim` among others) @@ -255,24 +255,25 @@ def chunked_eff_xr_matmult( tmpdir: optional path to temp directory (auto‐created and removed if None) Returns: - A new DataArray of shape B_other_dims + A_other_dims, containing the result - of the matmul over `contract_dim`, with coords, dims, and attrs preserved. + A new DataArray of containing the result of the matrix multiplication over `contract_dim`, + with coords, dims, and attrs preserved. Should yield the same result as `xr.dot(A, B, dims=[contract_dim])` + but at increased speed and with a much lower memory footprint. """ # Total samples & number of chunks N = A.sizes[sample_dim] n_chunks = int(np.ceil(N / chunksize)) - # 1) Build a “shell” result for metadata by doing the dot on the first sample - A0 = A.isel({sample_dim: slice(0, 1)}) # keep sample_dim as length 1 + # Build a “shell” result for metadata by doing the dot on the first sample + A0 = A.isel({sample_dim: slice(0, 1)}) Xres = xr.dot(B, A0, dims=[contract_dim]) - # 2) Prepare for raw numpy multiply + # Prepare for raw numpy multiply dims_B_not = [d for d in B.dims if d != contract_dim] dims_A_not = [d for d in A.dims if d != contract_dim] B_mat = B.transpose(*dims_B_not, contract_dim).values A2 = A.transpose(contract_dim, *dims_A_not) - # 3) Temp directory + # Create Temp directory cleanup = False if tmpdir is None: tmpdir = tempfile.mkdtemp() @@ -282,7 +283,7 @@ def chunked_eff_xr_matmult( print(f"Large Matrix Multiplication: Processing {n_chunks} chunks...") - # 4) Stream‐compute each chunk + # Stream‐compute each chunk file_paths = [] for i in range(n_chunks): start = i * chunksize @@ -295,7 +296,7 @@ def chunked_eff_xr_matmult( del A_chunk, C_chunk print(f"Chunk {i+1}/{n_chunks} done.") - # 5) Read back & concatenate along the sample axis + # Read back & concatenate along the sample axis arrs = [np.load(fp) for fp in sorted(file_paths)] axis = Xres.get_axis_num(sample_dim) 
full_arr = np.concatenate(arrs, axis=axis) @@ -303,7 +304,7 @@ def chunked_eff_xr_matmult( if cleanup: shutil.rmtree(tmpdir) - # create new coordinates + # create set of coordinates coords = { name: coord for name, coord in Xres.coords.items() @@ -316,8 +317,7 @@ def chunked_eff_xr_matmult( } coords.update(sample_coords) - - # 8) rebuild the DataArray using shell’s metadata + # rebuild the DataArray using the Xres metadata result = xr.DataArray( data = full_arr, dims = Xres.dims, @@ -327,90 +327,3 @@ def chunked_eff_xr_matmult( # add time coords result.assign_coords() return result - - """ # Determine how many samples and chunks - N = A.sizes[sample_dim] - n_chunks = int(np.ceil(N / chunksize)) - - # Prepare output dims, coords, attrs - dims_A_not = [d for d in A.dims if d != contract_dim] - dims_B_not = [d for d in B.dims if d != contract_dim] - out_dims = dims_B_not + dims_A_not - - # collect coords from B then A (excluding contract_dim) - coords = { - d: B.coords[d] - for d in B.coords if d != contract_dim - } - coords.update({ - d: A.coords[d] - for d in A.coords if d != contract_dim and d not in coords - }) - - attrs = B.attrs.copy() - - # Set up the temporary directory - cleanup = False - if tmpdir is None: - tmpdir = tempfile.mkdtemp() - cleanup = True - else: - os.makedirs(tmpdir, exist_ok=True) - - file_paths = [] - - # Prepare raw numpy matrices - # Ensure B: (out_dim, contract_dim) - dims_B = dims_B_not + [contract_dim] - B2 = B.transpose(*dims_B) - B_mat = B2.values - - # Ensure A: (contract_dim, sample_dim) up front - dims_A = [contract_dim] + dims_A_not - A2 = A.transpose(*dims_A) - - # Loop over chunks - print(f"Large Matrix Multiplication: Processing {n_chunks} chunks...") - for i in range(n_chunks): - start = i * chunksize - stop = min((i + 1) * chunksize, N) - - # slice A2 and get raw array - A_chunk = A2.isel({sample_dim: slice(start, stop)}) - A_mat = A_chunk.values - - # raw numpy mat-mul → (out_dim, chunk_len) - C_chunk = B_mat.dot(A_mat) - - 
# write to disk - fn = os.path.join(tmpdir, f"chunk_{i:04d}.npy") - np.save(fn, C_chunk) - file_paths.append(fn) - - # free memory - del A_chunk, A_mat, C_chunk - - print(f"Chunk {i+1}/{n_chunks} done.") - - # Read back & concatenate along sample axis (axis=1) - chunk_arrays = [np.load(fp) for fp in sorted(file_paths)] - full_arr = np.concatenate(chunk_arrays, axis=1) - # throw away any coords that reference the contracted dim - out_dims = dims_B_not + dims_A_not # e.g. ["flat_vertex","time"] - coords = { - k: v for (k, v) in coords.items() - if set(v.dims).issubset(out_dims) - } - - # cleanup if needed - if cleanup: - shutil.rmtree(tmpdir) - - # Wrap back into xarray.DataArray - result = xr.DataArray( - full_arr, - dims = out_dims, - coords = coords, - attrs = attrs - ) - return result """ From d2aade2b35fc1d373b068bc1b0ccc04c6b774bbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20von=20L=C3=BChmann?= Date: Thu, 15 May 2025 15:59:11 +0200 Subject: [PATCH 3/5] Update xrutils.py --- src/cedalion/xrutils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/cedalion/xrutils.py b/src/cedalion/xrutils.py index bb63467f..0e1c52db 100644 --- a/src/cedalion/xrutils.py +++ b/src/cedalion/xrutils.py @@ -258,6 +258,9 @@ def chunked_eff_xr_matmult( A new DataArray of containing the result of the matrix multiplication over `contract_dim`, with coords, dims, and attrs preserved. Should yield the same result as `xr.dot(A, B, dims=[contract_dim])` but at increased speed and with a much lower memory footprint. 
+ + Initial Contirbutors: + - Alexander von Lühmann | vonluehmann@tu-berlin.de | 2025 """ # Total samples & number of chunks N = A.sizes[sample_dim] From 7a39320ac9a0d2d079ab6d4b77907287eb0dedcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20von=20L=C3=BChmann?= Date: Thu, 15 May 2025 15:59:27 +0200 Subject: [PATCH 4/5] Update xrutils.py --- src/cedalion/xrutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cedalion/xrutils.py b/src/cedalion/xrutils.py index 0e1c52db..46c192c1 100644 --- a/src/cedalion/xrutils.py +++ b/src/cedalion/xrutils.py @@ -259,7 +259,7 @@ def chunked_eff_xr_matmult( with coords, dims, and attrs preserved. Should yield the same result as `xr.dot(A, B, dims=[contract_dim])` but at increased speed and with a much lower memory footprint. - Initial Contirbutors: + Initial Contributors: - Alexander von Lühmann | vonluehmann@tu-berlin.de | 2025 """ # Total samples & number of chunks From baeca25fde6b07455d35cb882b7a6fe17620baaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20von=20L=C3=BChmann?= Date: Thu, 15 May 2025 16:00:25 +0200 Subject: [PATCH 5/5] Update xrutils.py --- src/cedalion/xrutils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cedalion/xrutils.py b/src/cedalion/xrutils.py index 46c192c1..8e37ab66 100644 --- a/src/cedalion/xrutils.py +++ b/src/cedalion/xrutils.py @@ -249,9 +249,9 @@ def chunked_eff_xr_matmult( Args: A: DataArray to multiply (dims include `contract_dim` and `sample_dim` among others) B: DataArray defining the mat-mul (dims include `contract_dim` and others) - contract_dim: name of the axis to contract (e.g. "flat_channel") - sample_dim: name of the axis along which to chunk (default "time") - chunksize: max size of each chunk along `sample_dim` + contract_dim: name of the dimension to contract (e.g. "flat_channel") + sample_dim: name of the dimension along which to chunk (e.g. 
"time") + chunksize: max size of each chunk along dimension `sample_dim` tmpdir: optional path to temp directory (auto‐created and removed if None) Returns: