From b497c0fd90277f3fdab94c3dc7f54016706ac922 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 9 Oct 2025 12:30:11 -0700 Subject: [PATCH 1/6] Start refactoring logic for disk --- simpeg/dask/potential_fields/base.py | 101 ++++++++++++++------------- 1 file changed, 54 insertions(+), 47 deletions(-) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 2229980763..ac26701050 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -47,12 +47,49 @@ def block_compute(sim, rows, components): return np.vstack(block) +def storage_formatter( + rows: list[np.ndarray], + device: str, + chunk_format="rows", + sens_name: str = "./sensitivities.zarr", + max_chunk_size: float = 256, +): + + if device == "forward_only": + return array.concatenate(rows) + elif device == "disk": + stack = array.vstack(rows) + # Chunking options + if chunk_format == "row": + config.set({"array.chunk-size": f"{max_chunk_size}MiB"}) + # Autochunking by rows is faster and more memory efficient for + # very large problems sensitivty and forward calculations + stack = stack.rechunk({0: "auto", 1: -1}) + elif chunk_format == "equal": + # Manual chunks for equal number of blocks along rows and columns. + # Optimal for Jvec and Jtvec operations + row_chunk, col_chunk = compute_chunk_sizes(*stack.shape, max_chunk_size) + stack = stack.rechunk((row_chunk, col_chunk)) + else: + # Auto chunking by columns is faster for Inversions + config.set({"array.chunk-size": f"{max_chunk_size}MiB"}) + stack = stack.rechunk({0: -1, 1: "auto"}) + + return array.to_zarr(stack, sens_name, return_stored=True, overwrite=True) + else: + return np.vstack(rows) + + def linear_operator(self): forward_only = self.store_sensitivities == "forward_only" n_cells = self.nC if getattr(self, "model_type", None) == "vector": n_cells *= 3 + if self.store_sensitivities == "disk" and os.path.exists(self.sensitivity_path): + kernel = array.from_zarr(self.sensitivity_path) + return kernel + n_components = len(self.survey.components) n_blocks = np.ceil( (n_cells * n_components * self.survey.receiver_locations.shape[0] * 8.0 * 1e-6) @@ -94,56 +131,26 @@ def linear_operator(self): ) if client: - if forward_only: - return np.hstack(client.gather(rows)) - return np.vstack(client.gather(rows)) - - if forward_only: - stack = array.concatenate(rows) + future = client.submit( + storage_formatter, + rows, + device=self.store_sensitivities, + chunk_format=self.chunk_format, + sens_name=self.sensitivity_path, + max_chunk_size=self.max_chunk_size, + workers=worker, + ) + kernel = client.gather(future) else: - stack = array.vstack(rows) - # Chunking options - if self.chunk_format == "row": - config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) - # Autochunking by rows is faster and more memory efficient for - # very large problems sensitivty and forward calculations - stack = stack.rechunk({0: "auto", 1: -1}) - elif self.chunk_format == "equal": - # Manual chunks for equal number of blocks along rows and columns. 
- # Optimal for Jvec and Jtvec operations - row_chunk, col_chunk = compute_chunk_sizes( - *stack.shape, self.max_chunk_size - ) - stack = stack.rechunk((row_chunk, col_chunk)) - else: - # Auto chunking by columns is faster for Inversions - config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) - stack = stack.rechunk({0: -1, 1: "auto"}) - - if self.store_sensitivities == "disk": - sens_name = os.path.join(self.sensitivity_path, "sensitivity.zarr") - if os.path.exists(sens_name): - kernel = array.from_zarr(sens_name) - if np.all( - np.r_[ - np.any(np.r_[kernel.chunks[0]] == stack.chunks[0]), - np.any(np.r_[kernel.chunks[1]] == stack.chunks[1]), - np.r_[kernel.shape] == np.r_[stack.shape], - ] - ): - # Check that loaded kernel matches supplied data and mesh - print("Zarr file detected with same shape and chunksize ... re-loading") - return kernel - - print("Writing Zarr file to disk") with ProgressBar(): - print("Saving kernel to zarr: " + sens_name) - kernel = array.to_zarr( - stack, sens_name, compute=True, return_stored=True, overwrite=True - ) + kernel = storage_formatter( + rows, + device=self.store_sensitivities, + chunk_format=self.chunk_format, + sens_name=self.sensitivity_path, + max_chunk_size=self.max_chunk_size, + ).compute() - with ProgressBar(): - kernel = stack.compute() return kernel From 0e25c01019e9425deae0e2a5541007c9629954dd Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 9 Oct 2025 16:29:13 -0700 Subject: [PATCH 2/6] Re-working writting to disk by workers concurrently. Need clean up --- simpeg/dask/potential_fields/base.py | 85 +++++++++++++++++++++++----- 1 file changed, 71 insertions(+), 14 deletions(-) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index ac26701050..ae088f62d7 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -4,8 +4,10 @@ import os from dask import delayed, array, config + from dask.diagnostics import ProgressBar from ..utils import compute_chunk_sizes +import zarr _chunk_format = "row" @@ -27,9 +29,12 @@ def chunk_format(self, other): def dpred(self, m=None, f=None): if m is not None: self.model = m - if f is not None: - return f - return self.fields(self.model) + if f is None: + f = self.fields(self.model) + + if isinstance(f, array.Array): + return np.asarray(f) + return f def residual(self, m, dobs, f=None): @@ -53,9 +58,23 @@ def storage_formatter( chunk_format="rows", sens_name: str = "./sensitivities.zarr", max_chunk_size: float = 256, + compute: bool = True, ): + """ + Format the storage of the sensitivity matrix. + + :param rows: List of dask arrays representing blocks of the sensitivity matrix. + :param device: Storage option, either "forward_only", "disk", or "memory". + :param chunk_format: Chunking format for disk storage, either "row", "equal + or "auto". + :param sens_name: File path to store the sensitivity matrix if device is "disk". + :param max_chunk_size: Maximum chunk size in MiB for disk storage. + :param compute: If True, compute the dask array before returning. 
+ """ if device == "forward_only": + if compute: + return np.hstack(rows) return array.concatenate(rows) elif device == "disk": stack = array.vstack(rows) @@ -75,9 +94,20 @@ def storage_formatter( config.set({"array.chunk-size": f"{max_chunk_size}MiB"}) stack = stack.rechunk({0: -1, 1: "auto"}) - return array.to_zarr(stack, sens_name, return_stored=True, overwrite=True) + return array.to_zarr( + stack, sens_name, compute=False, return_stored=False, overwrite=True + ) else: - return np.vstack(rows) + if compute: + return np.vstack(rows) + + return array.vstack(rows) + + +def set_orthogonal_selection(row, Jmatrix, count): + n_rows = row.shape[0] + Jmatrix[count : count + n_rows, :] = row + return None def linear_operator(self): @@ -131,17 +161,40 @@ def linear_operator(self): ) if client: - future = client.submit( - storage_formatter, - rows, - device=self.store_sensitivities, - chunk_format=self.chunk_format, - sens_name=self.sensitivity_path, - max_chunk_size=self.max_chunk_size, - workers=worker, - ) + + if self.store_sensitivities == "disk": + Jmatrix = zarr.open( + self.sensitivity_path, + mode="w", + shape=(self.survey.nD, n_cells), + chunks=(self.max_chunk_size, n_cells), + ) + + count = 0 + future = [] + for row, block in zip(rows, block_split): + row_chunks = block.shape[0] + future.append( + client.submit(set_orthogonal_selection, row, Jmatrix, count) + ) + + count += row_chunks + + else: + future = client.submit( + storage_formatter, + rows, + device=self.store_sensitivities, + chunk_format=self.chunk_format, + sens_name=self.sensitivity_path, + max_chunk_size=self.max_chunk_size, + workers=worker, + ) + kernel = client.gather(future) + else: + with ProgressBar(): kernel = storage_formatter( rows, @@ -149,8 +202,12 @@ def linear_operator(self): chunk_format=self.chunk_format, sens_name=self.sensitivity_path, max_chunk_size=self.max_chunk_size, + compute=False, ).compute() + if self.store_sensitivities == "disk" and os.path.exists(self.sensitivity_path): + kernel = array.from_zarr(self.sensitivity_path) + return kernel From 323af877f7d66944861f5dc0c2d87092d692480c Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 10 Oct 2025 11:38:32 -0700 Subject: [PATCH 3/6] Update disk storage for dc, fem and potential fields --- .../frequency_domain/simulation.py | 9 +- .../static/resistivity/simulation.py | 23 ++- simpeg/dask/potential_fields/base.py | 165 ++++++------------ 3 files changed, 73 insertions(+), 124 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 32deed7f7c..82458640b3 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -270,7 +270,6 @@ def compute_J(self, m, f=None): addresses_chunks, client, worker, - store_sensitivities=self.store_sensitivities, ) for A in Ainv.values(): @@ -295,7 +294,6 @@ def parallel_block_compute( addresses, client, worker=None, - store_sensitivities="disk", ): m_size = m.size block_stack = sp.hstack(blocks_receiver_derivs).toarray() @@ -354,11 +352,8 @@ def parallel_block_compute( else: block = compute(array.vstack(block_delayed))[0] - if store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (indices, slice(None)), - block, - ) + if isinstance(Jmatrix, zarr.Array): + Jmatrix.set_orthogonal_selection((indices, slice(None)), block) else: # Dask process to compute row and store Jmatrix[indices, :] = block diff --git 
a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 1a49889a1b..90ae36601e 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -4,6 +4,7 @@ from .....utils import Zero +import os import dask.array as da import numpy as np from scipy import sparse as sp @@ -42,23 +43,29 @@ def compute_J(self, m, f=None): f, Ainv = self.fields(m=m, return_Ainv=True) - m_size = m.size + n_cells = m.size row_chunks = int( np.ceil( float(self.survey.nD) - / np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size) + / np.ceil( + float(n_cells) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size + ) ) ) if self.store_sensitivities == "disk": + + if os.path.exists(self.sensitivity_path): + return da.from_zarr(self.sensitivity_path) + Jmatrix = zarr.open( - self.sensitivity_path + "J.zarr", + self.sensitivity_path, mode="w", - shape=(self.survey.nD, m_size), - chunks=(row_chunks, m_size), + shape=(self.survey.nD, n_cells), + chunks=(self.max_chunk_size, n_cells), ) else: - Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) + Jmatrix = np.zeros((self.survey.nD, n_cells), dtype=np.float32) blocks = [] count = 0 @@ -92,7 +99,7 @@ def compute_J(self, m, f=None): du_dmT += df_dmT # - du_dmT = du_dmT.T.reshape((-1, m_size)) + du_dmT = du_dmT.T.reshape((-1, n_cells)) if len(blocks) == 0: blocks = du_dmT @@ -130,7 +137,7 @@ def compute_J(self, m, f=None): if self.store_sensitivities == "disk": del Jmatrix - self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") + self._Jmatrix = da.from_zarr(self.sensitivity_path) else: self._Jmatrix = Jmatrix diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index ae088f62d7..9783df930a 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -3,10 +3,10 @@ from ...potential_fields.base import BasePFSimulation as Sim import os -from dask import delayed, array, config +from dask import delayed, array, compute from dask.diagnostics import ProgressBar -from ..utils import compute_chunk_sizes + import zarr @@ -41,7 +41,7 @@ def residual(self, m, dobs, f=None): return self.dpred(m, f=f) - dobs -def block_compute(sim, rows, components): +def block_compute(sim, rows, components, j_matrix, count): block = [] for row in rows: block.append(sim.evaluate_integral(row, components)) @@ -49,16 +49,14 @@ def block_compute(sim, rows, components): if sim.store_sensitivities == "forward_only": return np.hstack(block) - return np.vstack(block) + values = np.vstack(block) + return storage_formatter(values, count, j_matrix) def storage_formatter( - rows: list[np.ndarray], - device: str, - chunk_format="rows", - sens_name: str = "./sensitivities.zarr", - max_chunk_size: float = 256, - compute: bool = True, + rows: np.ndarray, + count: int, + j_matrix: zarr.Array | None = None, ): """ Format the storage of the sensitivity matrix. @@ -72,42 +70,14 @@ def storage_formatter( :param compute: If True, compute the dask array before returning. 
""" - if device == "forward_only": - if compute: - return np.hstack(rows) - return array.concatenate(rows) - elif device == "disk": - stack = array.vstack(rows) - # Chunking options - if chunk_format == "row": - config.set({"array.chunk-size": f"{max_chunk_size}MiB"}) - # Autochunking by rows is faster and more memory efficient for - # very large problems sensitivty and forward calculations - stack = stack.rechunk({0: "auto", 1: -1}) - elif chunk_format == "equal": - # Manual chunks for equal number of blocks along rows and columns. - # Optimal for Jvec and Jtvec operations - row_chunk, col_chunk = compute_chunk_sizes(*stack.shape, max_chunk_size) - stack = stack.rechunk((row_chunk, col_chunk)) - else: - # Auto chunking by columns is faster for Inversions - config.set({"array.chunk-size": f"{max_chunk_size}MiB"}) - stack = stack.rechunk({0: -1, 1: "auto"}) - - return array.to_zarr( - stack, sens_name, compute=False, return_stored=False, overwrite=True + if isinstance(j_matrix, zarr.Array): + j_matrix.set_orthogonal_selection( + (np.arange(count, count + rows.shape[0]), slice(None)), + rows.astype(np.float32), ) - else: - if compute: - return np.vstack(rows) + return None - return array.vstack(rows) - - -def set_orthogonal_selection(row, Jmatrix, count): - n_rows = row.shape[0] - Jmatrix[count : count + n_rows, :] = row - return None + return rows def linear_operator(self): @@ -116,9 +86,19 @@ def linear_operator(self): if getattr(self, "model_type", None) == "vector": n_cells *= 3 - if self.store_sensitivities == "disk" and os.path.exists(self.sensitivity_path): - kernel = array.from_zarr(self.sensitivity_path) - return kernel + if self.store_sensitivities == "disk": + + if os.path.exists(self.sensitivity_path): + return array.from_zarr(self.sensitivity_path) + + Jmatrix = zarr.open( + self.sensitivity_path, + mode="w", + shape=(self.survey.nD, n_cells), + chunks=(self.max_chunk_size, n_cells), + ) + else: + Jmatrix = None n_components = len(self.survey.components) n_blocks = np.ceil( @@ -135,80 +115,46 @@ def linear_operator(self): delayed_compute = delayed(block_compute) rows = [] + count = 0 for block in block_split: if client: - rows.append( - client.submit( - block_compute, - sim, - block, - self.survey.components, - workers=worker, - ) - ) - else: - chunk = delayed_compute(self, block, self.survey.components) - rows.append( - array.from_delayed( - chunk, - dtype=self.sensitivity_dtype, - shape=( - (len(block) * n_components,) - if forward_only - else (len(block) * n_components, n_cells) - ), - ) - ) - - if client: - - if self.store_sensitivities == "disk": - Jmatrix = zarr.open( - self.sensitivity_path, - mode="w", - shape=(self.survey.nD, n_cells), - chunks=(self.max_chunk_size, n_cells), + row = client.submit( + block_compute, + sim, + block, + self.survey.components, + Jmatrix, + count, + workers=worker, ) - count = 0 - future = [] - for row, block in zip(rows, block_split): - row_chunks = block.shape[0] - future.append( - client.submit(set_orthogonal_selection, row, Jmatrix, count) - ) - - count += row_chunks - else: - future = client.submit( - storage_formatter, - rows, - device=self.store_sensitivities, - chunk_format=self.chunk_format, - sens_name=self.sensitivity_path, - max_chunk_size=self.max_chunk_size, - workers=worker, + chunk = delayed_compute(self, block, self.survey.components, Jmatrix, count) + row = array.from_delayed( + chunk, + dtype=self.sensitivity_dtype, + shape=( + (len(block) * n_components,) + if forward_only + else (len(block) * n_components, n_cells) + ), 
) + count += block.shape[0] + rows.append(row) - kernel = client.gather(future) - + if client: + kernel = client.gather(rows) else: - with ProgressBar(): - kernel = storage_formatter( - rows, - device=self.store_sensitivities, - chunk_format=self.chunk_format, - sens_name=self.sensitivity_path, - max_chunk_size=self.max_chunk_size, - compute=False, - ).compute() + kernel = compute(rows)[0] if self.store_sensitivities == "disk" and os.path.exists(self.sensitivity_path): - kernel = array.from_zarr(self.sensitivity_path) + return array.from_zarr(self.sensitivity_path) + + if forward_only: + return np.hstack(kernel) - return kernel + return np.vstack(kernel) def compute_J(self, _, f=None): @@ -219,6 +165,7 @@ def compute_J(self, _, f=None): def Jmatrix(self): if getattr(self, "_Jmatrix", None) is None: self._Jmatrix = self.compute_J(self.model) + return self._Jmatrix From 52807fbcb08f042f6f70dba31459d3d37d4ffb0e Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 10 Oct 2025 14:00:35 -0700 Subject: [PATCH 4/6] Clean out array on file --- .../frequency_domain/simulation.py | 5 + .../static/resistivity/simulation.py | 4 +- .../time_domain/simulation.py | 119 +++++++++--------- 3 files changed, 64 insertions(+), 64 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 82458640b3..ffe1c38a35 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,4 +1,6 @@ import gc +import os +import shutil from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim from ....utils import Zero @@ -199,6 +201,9 @@ def compute_J(self, m, f=None): m_size = m.size if self.store_sensitivities == "disk": + if os.path.exists(self.sensitivity_path): + shutil.rmtree(self.sensitivity_path) + Jmatrix = zarr.open( self.sensitivity_path, mode="w", diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 90ae36601e..cefd3d5f2b 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -3,7 +3,7 @@ from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix from .....utils import Zero - +import shutil import os import dask.array as da import numpy as np @@ -56,7 +56,7 @@ def compute_J(self, m, f=None): if self.store_sensitivities == "disk": if os.path.exists(self.sensitivity_path): - return da.from_zarr(self.sensitivity_path) + shutil.rmtree(self.sensitivity_path) Jmatrix = zarr.open( self.sensitivity_path, diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index e1ff8d1e7c..6e8ff2b2a0 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -1,6 +1,5 @@ -import dask -import dask.array import os +import shutil from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim from ....utils import Zero @@ -8,10 +7,9 @@ import numpy as np import scipy.sparse as sp -from dask import array, delayed +from dask import array, delayed, compute +import zarr - -from time import time from simpeg.dask.utils import get_parallel_blocks from simpeg.utils import mkvc @@ -92,22 +90,24 @@ def compute_J(self, m, f=None): client, worker = self._get_client_worker() ftype = self._fieldType + "Solution" 
- sens_name = self.sensitivity_path[:-5] + n_cells = m.size + if self.store_sensitivities == "disk": - rows = array.zeros( - (self.survey.nD, m.size), - chunks=(self.max_chunk_size, m.size), - dtype=np.float32, - ) - Jmatrix = array.to_zarr( - rows, - os.path.join(sens_name + "_1.zarr"), - compute=True, - return_stored=True, - overwrite=True, + + if os.path.exists(self.sensitivity_path): + shutil.rmtree(self.sensitivity_path) + + Jmatrix = zarr.open( + self.sensitivity_path, + mode="w", + shape=(self.survey.nD, n_cells), + chunks=(self.max_chunk_size, n_cells), ) else: - Jmatrix = np.zeros((self.survey.nD, m.size), dtype=np.float64) + Jmatrix = np.zeros((self.survey.nD, n_cells), dtype=np.float64) + + if client: + Jmatrix = client.scatter(Jmatrix, workers=worker) simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 data_times = self.survey.source_list[0].receiver_list[0].times @@ -122,7 +122,7 @@ def compute_J(self, m, f=None): if len(self.survey.source_list) == 1: fields_array = fields_array[:, np.newaxis, :] - times_field_derivs, Jmatrix = compute_field_derivs( + times_field_derivs = compute_field_derivs( self, f, blocks, Jmatrix, fields_array.shape ) @@ -178,6 +178,7 @@ def compute_J(self, m, f=None): field_derivatives, fields_array, time_mask, + Jmatrix, workers=worker, ) ) @@ -192,6 +193,7 @@ def compute_J(self, m, f=None): field_derivatives, fields_array, time_mask, + Jmatrix, ), dtype=np.float32, shape=( @@ -202,29 +204,21 @@ def compute_J(self, m, f=None): ) if client: - j_row_updates = np.vstack(client.gather(future_updates)) - else: - j_row_updates = array.vstack(future_updates).compute() - - if self.store_sensitivities == "disk": - sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" - array.to_zarr( - Jmatrix + j_row_updates, - sens_name, - compute=True, - overwrite=True, - ) - Jmatrix = array.from_zarr(sens_name) + client.gather(future_updates) else: - Jmatrix += j_row_updates + compute(future_updates) for A in Ainv.values(): A.clean() - if self.store_sensitivities == "ram": - self._Jmatrix = np.asarray(Jmatrix) + if self.store_sensitivities == "disk": + del Jmatrix + self._Jmatrix = array.from_zarr(self.sensitivity_path) + else: + if client: + Jmatrix = client.gather(Jmatrix) - self._Jmatrix = Jmatrix + self._Jmatrix = Jmatrix return self._Jmatrix @@ -299,6 +293,7 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape): time_mesh, fields, self.model.size, + Jmatrix, workers=worker, ) ) @@ -313,36 +308,23 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape): self.time_mesh, fields, self.model.size, + Jmatrix, ) ) if client: result = client.gather(delayed_chunks) else: - result = dask.compute(delayed_chunks)[0] + result = compute(delayed_chunks)[0] - df_duT = [ - [[[] for _ in block] for block in blocks if len(block) > 0] - for _ in range(self.nT + 1) - ] - j_updates = [] + df_duT = [[[[] for _ in block] for block in blocks] for _ in range(self.nT + 1)] for bb, block in enumerate(result): - j_updates += block[1] - for cc, chunk in enumerate(block[0]): + for cc, chunk in enumerate(block): for ind, time_block in enumerate(chunk): df_duT[ind][bb][cc] = time_block - j_updates = sp.vstack(j_updates) - - if len(j_updates.data) > 0: - Jmatrix += j_updates - if self.store_sensitivities == "disk": - sens_name = self.sensitivity_path[:-5] + f"_{time() % 2}.zarr" - array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) - Jmatrix = array.from_zarr(sens_name) - - return df_duT, Jmatrix + return df_duT def 
get_field_deriv_block( @@ -398,12 +380,10 @@ def get_field_deriv_block( def block_deriv( - n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape + n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape, Jmatrix ): """Compute derivatives for sources and receivers in a block""" df_duT = [] - j_updates = [] - for indices, arrays in chunks: j_update = 0.0 source = source_list[indices[0]] @@ -438,10 +418,19 @@ def block_deriv( else: j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) - j_updates.append(j_update) + if isinstance(Jmatrix, zarr.Array): + j_slice = Jmatrix.get_orthogonal_selection((arrays[1], slice(None))) + + Jmatrix.set_orthogonal_selection( + (arrays[1], slice(None)), + j_slice + j_update, + ) + else: + Jmatrix[arrays[1], :] + j_update + df_duT.append(time_derivs) - return df_duT, j_updates + return df_duT def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs): @@ -465,6 +454,7 @@ def compute_rows( field_derivs, fields, time_mask, + Jmatrix, ): """ Compute the rows of the sensitivity matrix for a given source and receiver. @@ -505,9 +495,14 @@ def compute_rows( row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype( np.float32 ) - rows.append(row_block) - return np.vstack(rows) + if isinstance(Jmatrix, zarr.Array): + j_slice = Jmatrix.get_orthogonal_selection((ind_array[1], slice(None))) + Jmatrix.set_orthogonal_selection( + (ind_array[1], slice(None)), j_slice + row_block + ) + else: + Jmatrix[ind_array[1], :] += row_block def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields): @@ -583,7 +578,7 @@ def dpred(self, m=None, f=None): if client: result = client.gather(delayed_chunks) else: - result = dask.compute(delayed_chunks)[0] + result = compute(delayed_chunks)[0] return np.hstack(result) From 6095e2449dc9ba68375c638938c68349ad7f8cd5 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sat, 11 Oct 2025 09:22:43 -0700 Subject: [PATCH 5/6] Chunk the zarr by block size to avoid cross-talks --- .../frequency_domain/simulation.py | 59 +++++++++++-------- .../static/resistivity/simulation.py | 2 +- .../time_domain/simulation.py | 32 +++++----- 3 files changed, 51 insertions(+), 42 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index ffe1c38a35..023b1ec605 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -52,7 +52,9 @@ def receiver_derivs(survey, mesh, fields, blocks): return field_derivatives -def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address): +def compute_rows( + simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address, Jmatrix +): """ Evaluate the sensitivities for the block or data """ @@ -94,7 +96,14 @@ def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address if not isinstance(deriv_m, Zero): du_dmT += deriv_m - return np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T + values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T + + if isinstance(Jmatrix, zarr.Array): + Jmatrix.set_orthogonal_selection((address[1][1], slice(None)), values) + else: + Jmatrix[address[1][1], :] = values + + return None def getSourceTerm(self, freq, source=None): @@ -197,10 +206,21 @@ def compute_J(self, m, f=None): "Consider creating one misfit per frequency." 
) + client, worker = self._get_client_worker() + A_i = list(Ainv.values())[0] m_size = m.size + compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) + blocks = get_parallel_blocks( + self.survey.source_list, compute_row_size, optimize=True + ) if self.store_sensitivities == "disk": + + chunk_size = np.median( + [np.sum([len(chunk[1][1]) for chunk in block]) for block in blocks] + ).astype(int) + if os.path.exists(self.sensitivity_path): shutil.rmtree(self.sensitivity_path) @@ -208,20 +228,17 @@ def compute_J(self, m, f=None): self.sensitivity_path, mode="w", shape=(self.survey.nD, m_size), - chunks=(self.max_chunk_size, m_size), + chunks=(chunk_size, m_size), ) else: Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) - blocks = get_parallel_blocks( - self.survey.source_list, compute_row_size, optimize=False - ) + if client: + Jmatrix = client.scatter(Jmatrix, workers=worker) + fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] - client, worker = self._get_client_worker() - if client: fields_array = client.scatter(f[:, self._solutionType], workers=worker) fields = client.scatter(f, workers=worker) @@ -309,10 +326,9 @@ def parallel_block_compute( ATinvdf_duT = client.scatter(ATinvdf_duT, workers=worker) else: ATinvdf_duT = delayed(ATinvdf_duT) + count = 0 - rows = [] block_delayed = [] - for address, dfduT in zip(addresses, blocks_receiver_derivs): n_cols = dfduT.shape[1] n_rows = address[1][2] @@ -320,18 +336,19 @@ def parallel_block_compute( if client: block_delayed.append( client.submit( - eval_block, + compute_rows, simulation, ATinvdf_duT, np.arange(count, count + n_cols), Zero(), fields_array, address, + Jmatrix, workers=worker, ) ) else: - delayed_eval = delayed(eval_block) + delayed_eval = delayed(compute_rows) block_delayed.append( array.from_delayed( delayed_eval( @@ -341,32 +358,22 @@ def parallel_block_compute( Zero(), fields_array, address, + Jmatrix, ), dtype=np.float32, shape=(n_rows, m_size), ) ) count += n_cols - rows += address[1][1].tolist() - - indices = np.hstack(rows) if client: - block_delayed = client.gather(block_delayed) - block = np.vstack(block_delayed) - else: - block = compute(array.vstack(block_delayed))[0] - - if isinstance(Jmatrix, zarr.Array): - Jmatrix.set_orthogonal_selection((indices, slice(None)), block) + client.gather(block_delayed) else: - # Dask process to compute row and store - Jmatrix[indices, :] = block + compute(block_delayed) return Jmatrix -Sim.parallel_block_compute = parallel_block_compute Sim.compute_J = compute_J Sim.getJtJdiag = getJtJdiag Sim.Jvec = Jvec diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index cefd3d5f2b..e3f6e5e29e 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -62,7 +62,7 @@ def compute_J(self, m, f=None): self.sensitivity_path, mode="w", shape=(self.survey.nD, n_cells), - chunks=(self.max_chunk_size, n_cells), + chunks=(row_chunks, n_cells), ) else: Jmatrix = np.zeros((self.survey.nD, n_cells), dtype=np.float32) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 6e8ff2b2a0..5986f57ffa 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ 
-92,8 +92,23 @@ def compute_J(self, m, f=None): ftype = self._fieldType + "Solution" n_cells = m.size - if self.store_sensitivities == "disk": + simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 + data_times = self.survey.source_list[0].receiver_list[0].times + compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) + blocks = get_parallel_blocks( + self.survey.source_list, + compute_row_size, + thread_count=self.n_threads(client=client, worker=worker), + ) + fields_array = f[:, ftype, :] + if len(self.survey.source_list) == 1: + fields_array = fields_array[:, np.newaxis, :] + + if self.store_sensitivities == "disk": + chunk_size = np.median( + [np.sum([len(chunk[1][1]) for chunk in block]) for block in blocks] + ).astype(int) if os.path.exists(self.sensitivity_path): shutil.rmtree(self.sensitivity_path) @@ -101,7 +116,7 @@ def compute_J(self, m, f=None): self.sensitivity_path, mode="w", shape=(self.survey.nD, n_cells), - chunks=(self.max_chunk_size, n_cells), + chunks=(chunk_size, n_cells), ) else: Jmatrix = np.zeros((self.survey.nD, n_cells), dtype=np.float64) @@ -109,19 +124,6 @@ def compute_J(self, m, f=None): if client: Jmatrix = client.scatter(Jmatrix, workers=worker) - simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 - data_times = self.survey.source_list[0].receiver_list[0].times - compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) - blocks = get_parallel_blocks( - self.survey.source_list, - compute_row_size, - thread_count=self.n_threads(client=client, worker=worker), - ) - fields_array = f[:, ftype, :] - - if len(self.survey.source_list) == 1: - fields_array = fields_array[:, np.newaxis, :] - times_field_derivs = compute_field_derivs( self, f, blocks, Jmatrix, fields_array.shape ) From 523df6ad8f2cf115fae9d66d818401a76dbb69f6 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 14 Oct 2025 12:28:24 -0700 Subject: [PATCH 6/6] Update docstrings --- simpeg/dask/potential_fields/base.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 9783df930a..67dd07f188 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -62,12 +62,11 @@ def storage_formatter( Format the storage of the sensitivity matrix. :param rows: List of dask arrays representing blocks of the sensitivity matrix. - :param device: Storage option, either "forward_only", "disk", or "memory". - :param chunk_format: Chunking format for disk storage, either "row", "equal - or "auto". - :param sens_name: File path to store the sensitivity matrix if device is "disk". - :param max_chunk_size: Maximum chunk size in MiB for disk storage. - :param compute: If True, compute the dask array before returning. + :param count: Current row count offset. + :param j_matrix: Zarr array to store the sensitivity matrix on disk, if applicable + + :return: If j_matrix is provided, returns None after storing the rows; otherwise, + returns the stacked rows as a NumPy array. """ if isinstance(j_matrix, zarr.Array):
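
Editor's note: the sketch below is not part of the patches above. It is a minimal, self-contained illustration of the storage pattern this series converges on (patches 2 through 5): each worker computes a block of sensitivity rows and writes it directly into a shared zarr store with set_orthogonal_selection, with the zarr chunks sized to the row blocks so concurrent writers never touch the same chunk. The helper names (build_block, write_block), the sizes, and the "./sensitivities.zarr" path are hypothetical stand-ins for SimPEG's evaluate_integral / block_compute machinery, not the project's actual API.

    # Hedged sketch: concurrent block writes to a zarr-backed sensitivity matrix.
    import numpy as np
    import zarr
    import dask.array as da
    from dask.distributed import Client


    def build_block(start, n_rows, n_cells):
        # Stand-in for a block of sensitivity rows (e.g. sim.evaluate_integral).
        rng = np.random.default_rng(start)
        return rng.standard_normal((n_rows, n_cells)).astype(np.float32)


    def write_block(j_matrix, start, n_rows, n_cells):
        # Compute one block and write it into the shared zarr array in place.
        block = build_block(start, n_rows, n_cells)
        j_matrix.set_orthogonal_selection(
            (np.arange(start, start + n_rows), slice(None)), block
        )
        return None


    if __name__ == "__main__":
        N_DATA, N_CELLS, BLOCK_ROWS = 1_000, 250, 100
        path = "./sensitivities.zarr"

        # Chunk the store by the same row blocks the workers write, so every
        # task owns whole zarr chunks and concurrent writes cannot overlap.
        j_matrix = zarr.open(
            path,
            mode="w",
            shape=(N_DATA, N_CELLS),
            chunks=(BLOCK_ROWS, N_CELLS),
            dtype=np.float32,
        )

        client = Client(processes=False)
        futures = [
            client.submit(write_block, j_matrix, start, BLOCK_ROWS, N_CELLS)
            for start in range(0, N_DATA, BLOCK_ROWS)
        ]
        client.gather(futures)
        client.close()

        # Re-open lazily as a dask array for subsequent J @ v style products.
        J = da.from_zarr(path)
        print(J.shape, J.chunks)

Matching the zarr chunk height to the block height is the point of patch 5 ("chunk the zarr by block size to avoid cross-talks"): when every submitted task writes a disjoint set of whole chunks, no locking is needed and the workers can write concurrently. Reloading the result with da.from_zarr mirrors the store_sensitivities == "disk" branch in the patched simulations.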