diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py
index 32deed7f7c..023b1ec605 100644
--- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py
+++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py
@@ -1,4 +1,6 @@
 import gc
+import os
+import shutil
 
 from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim
 from ....utils import Zero
@@ -50,7 +52,9 @@ def receiver_derivs(survey, mesh, fields, blocks):
     return field_derivatives
 
 
-def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address):
+def compute_rows(
+    simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address, Jmatrix
+):
     """
     Evaluate the sensitivities for the block or data
     """
@@ -92,7 +96,14 @@ def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address):
     if not isinstance(deriv_m, Zero):
         du_dmT += deriv_m
 
-    return np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T
+    values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T
+
+    if isinstance(Jmatrix, zarr.Array):
+        Jmatrix.set_orthogonal_selection((address[1][1], slice(None)), values)
+    else:
+        Jmatrix[address[1][1], :] = values
+
+    return None
 
 
 def getSourceTerm(self, freq, source=None):
@@ -195,28 +206,39 @@ def compute_J(self, m, f=None):
             "Consider creating one misfit per frequency."
         )
 
+    client, worker = self._get_client_worker()
+
     A_i = list(Ainv.values())[0]
     m_size = m.size
+    compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6))
+    blocks = get_parallel_blocks(
+        self.survey.source_list, compute_row_size, optimize=True
+    )
 
     if self.store_sensitivities == "disk":
+
+        chunk_size = np.median(
+            [np.sum([len(chunk[1][1]) for chunk in block]) for block in blocks]
+        ).astype(int)
+
+        if os.path.exists(self.sensitivity_path):
+            shutil.rmtree(self.sensitivity_path)
+
         Jmatrix = zarr.open(
             self.sensitivity_path,
             mode="w",
             shape=(self.survey.nD, m_size),
-            chunks=(self.max_chunk_size, m_size),
+            chunks=(chunk_size, m_size),
         )
     else:
         Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32)
 
-    compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6))
-    blocks = get_parallel_blocks(
-        self.survey.source_list, compute_row_size, optimize=False
-    )
+    if client:
+        Jmatrix = client.scatter(Jmatrix, workers=worker)
+
     fields_array = f[:, self._solutionType]
     blocks_receiver_derivs = []
 
-    client, worker = self._get_client_worker()
-
     if client:
         fields_array = client.scatter(f[:, self._solutionType], workers=worker)
         fields = client.scatter(f, workers=worker)
@@ -270,7 +292,6 @@ def compute_J(self, m, f=None):
            addresses_chunks,
            client,
            worker,
-           store_sensitivities=self.store_sensitivities,
        )
 
     for A in Ainv.values():
@@ -295,7 +316,6 @@ def parallel_block_compute(
     addresses,
     client,
     worker=None,
-    store_sensitivities="disk",
 ):
     m_size = m.size
     block_stack = sp.hstack(blocks_receiver_derivs).toarray()
@@ -306,10 +326,9 @@ def parallel_block_compute(
         ATinvdf_duT = client.scatter(ATinvdf_duT, workers=worker)
     else:
         ATinvdf_duT = delayed(ATinvdf_duT)
+    count = 0
 
-    rows = []
     block_delayed = []
-
     for address, dfduT in zip(addresses, blocks_receiver_derivs):
         n_cols = dfduT.shape[1]
         n_rows = address[1][2]
@@ -317,18 +336,19 @@ def parallel_block_compute(
         if client:
             block_delayed.append(
                 client.submit(
-                    eval_block,
+                    compute_rows,
                     simulation,
                     ATinvdf_duT,
                     np.arange(count, count + n_cols),
                     Zero(),
                     fields_array,
                     address,
+                    Jmatrix,
                     workers=worker,
                 )
             )
         else:
-            delayed_eval = delayed(eval_block)
+            delayed_eval = delayed(compute_rows)
             block_delayed.append(
                 array.from_delayed(
                     delayed_eval(
@@ -338,35 +358,22 @@ def parallel_block_compute(
                         Zero(),
                         fields_array,
                         address,
+                        Jmatrix,
                     ),
                     dtype=np.float32,
                     shape=(n_rows, m_size),
                 )
             )
         count += n_cols
-        rows += address[1][1].tolist()
-
-    indices = np.hstack(rows)
 
     if client:
-        block_delayed = client.gather(block_delayed)
-        block = np.vstack(block_delayed)
-    else:
-        block = compute(array.vstack(block_delayed))[0]
-
-    if store_sensitivities == "disk":
-        Jmatrix.set_orthogonal_selection(
-            (indices, slice(None)),
-            block,
-        )
+        client.gather(block_delayed)
     else:
-        # Dask process to compute row and store
-        Jmatrix[indices, :] = block
+        compute(block_delayed)
 
     return Jmatrix
 
 
-Sim.parallel_block_compute = parallel_block_compute
 Sim.compute_J = compute_J
 Sim.getJtJdiag = getJtJdiag
 Sim.Jvec = Jvec
diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py
index 1a49889a1b..e3f6e5e29e 100644
--- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py
+++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py
@@ -3,7 +3,8 @@
 from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix
 from .....utils import Zero
 
-
+import shutil
+import os
 import dask.array as da
 import numpy as np
 from scipy import sparse as sp
@@ -42,23 +43,29 @@ def compute_J(self, m, f=None):
     f, Ainv = self.fields(m=m, return_Ainv=True)
 
-    m_size = m.size
+    n_cells = m.size
 
     row_chunks = int(
         np.ceil(
             float(self.survey.nD)
-            / np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size)
+            / np.ceil(
+                float(n_cells) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size
+            )
         )
     )
 
     if self.store_sensitivities == "disk":
+
+        if os.path.exists(self.sensitivity_path):
+            shutil.rmtree(self.sensitivity_path)
+
         Jmatrix = zarr.open(
-            self.sensitivity_path + "J.zarr",
+            self.sensitivity_path,
             mode="w",
-            shape=(self.survey.nD, m_size),
-            chunks=(row_chunks, m_size),
+            shape=(self.survey.nD, n_cells),
+            chunks=(row_chunks, n_cells),
         )
     else:
-        Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32)
+        Jmatrix = np.zeros((self.survey.nD, n_cells), dtype=np.float32)
 
     blocks = []
     count = 0
@@ -92,7 +99,7 @@ def compute_J(self, m, f=None):
                 du_dmT += df_dmT
 
         #
-        du_dmT = du_dmT.T.reshape((-1, m_size))
+        du_dmT = du_dmT.T.reshape((-1, n_cells))
 
         if len(blocks) == 0:
             blocks = du_dmT
@@ -130,7 +137,7 @@ def compute_J(self, m, f=None):
 
     if self.store_sensitivities == "disk":
         del Jmatrix
-        self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr")
+        self._Jmatrix = da.from_zarr(self.sensitivity_path)
     else:
         self._Jmatrix = Jmatrix
 
diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py
index e1ff8d1e7c..5986f57ffa 100644
--- a/simpeg/dask/electromagnetics/time_domain/simulation.py
+++ b/simpeg/dask/electromagnetics/time_domain/simulation.py
@@ -1,6 +1,5 @@
-import dask
-import dask.array
 import os
+import shutil
 
 from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim
 from ....utils import Zero
@@ -8,10 +7,9 @@
 import numpy as np
 import scipy.sparse as sp
 
-from dask import array, delayed
+from dask import array, delayed, compute
+import zarr
 
-
-from time import time
 
 from simpeg.dask.utils import get_parallel_blocks
 from simpeg.utils import mkvc
@@ -92,22 +90,7 @@ def compute_J(self, m, f=None):
     client, worker = self._get_client_worker()
     ftype = self._fieldType + "Solution"
 
-    sens_name = self.sensitivity_path[:-5]
-    if self.store_sensitivities == "disk":
-        rows = array.zeros(
-            (self.survey.nD, m.size),
-            chunks=(self.max_chunk_size, m.size),
-            dtype=np.float32,
-        )
-        Jmatrix = array.to_zarr(
-            rows,
-            os.path.join(sens_name + "_1.zarr"),
-            compute=True,
-            return_stored=True,
-            overwrite=True,
-        )
-    else:
-        Jmatrix = np.zeros((self.survey.nD, m.size), dtype=np.float64)
+    n_cells = m.size
 
     simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0
     data_times = self.survey.source_list[0].receiver_list[0].times
@@ -122,7 +105,26 @@
     if len(self.survey.source_list) == 1:
         fields_array = fields_array[:, np.newaxis, :]
 
-    times_field_derivs, Jmatrix = compute_field_derivs(
+    if self.store_sensitivities == "disk":
+        chunk_size = np.median(
+            [np.sum([len(chunk[1][1]) for chunk in block]) for block in blocks]
+        ).astype(int)
+        if os.path.exists(self.sensitivity_path):
+            shutil.rmtree(self.sensitivity_path)
+
+        Jmatrix = zarr.open(
+            self.sensitivity_path,
+            mode="w",
+            shape=(self.survey.nD, n_cells),
+            chunks=(chunk_size, n_cells),
+        )
+    else:
+        Jmatrix = np.zeros((self.survey.nD, n_cells), dtype=np.float64)
+
+    if client:
+        Jmatrix = client.scatter(Jmatrix, workers=worker)
+
+    times_field_derivs = compute_field_derivs(
         self, f, blocks, Jmatrix, fields_array.shape
     )
 
@@ -178,6 +180,7 @@
                     field_derivatives,
                     fields_array,
                     time_mask,
+                    Jmatrix,
                     workers=worker,
                 )
             )
@@ -192,6 +195,7 @@
                         field_derivatives,
                         fields_array,
                         time_mask,
+                        Jmatrix,
                     ),
                     dtype=np.float32,
                     shape=(
@@ -202,29 +206,21 @@
             )
 
     if client:
-        j_row_updates = np.vstack(client.gather(future_updates))
-    else:
-        j_row_updates = array.vstack(future_updates).compute()
-
-    if self.store_sensitivities == "disk":
-        sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr"
-        array.to_zarr(
-            Jmatrix + j_row_updates,
-            sens_name,
-            compute=True,
-            overwrite=True,
-        )
-        Jmatrix = array.from_zarr(sens_name)
+        client.gather(future_updates)
     else:
-        Jmatrix += j_row_updates
+        compute(future_updates)
 
     for A in Ainv.values():
         A.clean()
 
-    if self.store_sensitivities == "ram":
-        self._Jmatrix = np.asarray(Jmatrix)
+    if self.store_sensitivities == "disk":
+        del Jmatrix
+        self._Jmatrix = array.from_zarr(self.sensitivity_path)
+    else:
+        if client:
+            Jmatrix = client.gather(Jmatrix)
 
-    self._Jmatrix = Jmatrix
+        self._Jmatrix = Jmatrix
 
     return self._Jmatrix
 
@@ -299,6 +295,7 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape):
                     time_mesh,
                     fields,
                     self.model.size,
+                    Jmatrix,
                     workers=worker,
                 )
             )
@@ -313,36 +310,23 @@
                     self.time_mesh,
                     fields,
                     self.model.size,
+                    Jmatrix,
                 )
             )
 
     if client:
         result = client.gather(delayed_chunks)
     else:
-        result = dask.compute(delayed_chunks)[0]
+        result = compute(delayed_chunks)[0]
 
-    df_duT = [
-        [[[] for _ in block] for block in blocks if len(block) > 0]
-        for _ in range(self.nT + 1)
-    ]
-    j_updates = []
+    df_duT = [[[[] for _ in block] for block in blocks] for _ in range(self.nT + 1)]
 
     for bb, block in enumerate(result):
-        j_updates += block[1]
-        for cc, chunk in enumerate(block[0]):
+        for cc, chunk in enumerate(block):
             for ind, time_block in enumerate(chunk):
                 df_duT[ind][bb][cc] = time_block
 
-    j_updates = sp.vstack(j_updates)
-
-    if len(j_updates.data) > 0:
-        Jmatrix += j_updates
-        if self.store_sensitivities == "disk":
-            sens_name = self.sensitivity_path[:-5] + f"_{time() % 2}.zarr"
-            array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True)
-            Jmatrix = array.from_zarr(sens_name)
-
-    return df_duT, Jmatrix
+    return df_duT
 
 
 def get_field_deriv_block(
@@ -398,12 +382,10 @@ def get_field_deriv_block(
 
 
 def block_deriv(
-    n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape
+    n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape, Jmatrix
 ):
     """Compute derivatives for sources and receivers in a block"""
     df_duT = []
-    j_updates = []
-
     for indices, arrays in chunks:
         j_update = 0.0
         source = source_list[indices[0]]
@@ -438,10 +420,19 @@ def block_deriv(
             else:
                 j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32)
 
-        j_updates.append(j_update)
+        if isinstance(Jmatrix, zarr.Array):
+            j_slice = Jmatrix.get_orthogonal_selection((arrays[1], slice(None)))
+
+            Jmatrix.set_orthogonal_selection(
+                (arrays[1], slice(None)),
+                j_slice + j_update,
+            )
+        else:
+            Jmatrix[arrays[1], :] += j_update
+
         df_duT.append(time_derivs)
 
-    return df_duT, j_updates
+    return df_duT
 
 
 def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs):
@@ -465,6 +456,7 @@ def compute_rows(
     field_derivs,
     fields,
     time_mask,
+    Jmatrix,
 ):
     """
     Compute the rows of the sensitivity matrix for a given source and receiver.
@@ -505,9 +497,14 @@ def compute_rows(
     row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(
         np.float32
     )
 
-    rows.append(row_block)
-    return np.vstack(rows)
+    if isinstance(Jmatrix, zarr.Array):
+        j_slice = Jmatrix.get_orthogonal_selection((ind_array[1], slice(None)))
+        Jmatrix.set_orthogonal_selection(
+            (ind_array[1], slice(None)), j_slice + row_block
+        )
+    else:
+        Jmatrix[ind_array[1], :] += row_block
 
 
 def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields):
@@ -583,7 +580,7 @@ def dpred(self, m=None, f=None):
     if client:
         result = client.gather(delayed_chunks)
     else:
-        result = dask.compute(delayed_chunks)[0]
+        result = compute(delayed_chunks)[0]
 
     return np.hstack(result)
 
diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py
index 2229980763..67dd07f188 100644
--- a/simpeg/dask/potential_fields/base.py
+++ b/simpeg/dask/potential_fields/base.py
@@ -3,9 +3,11 @@
 from ...potential_fields.base import BasePFSimulation as Sim
 
 import os
 
-from dask import delayed, array, config
+from dask import delayed, array, compute
+
 from dask.diagnostics import ProgressBar
-from ..utils import compute_chunk_sizes
+
+import zarr
 
 _chunk_format = "row"
@@ -27,16 +29,19 @@ def chunk_format(self, other):
 def dpred(self, m=None, f=None):
     if m is not None:
         self.model = m
-    if f is not None:
-        return f
-    return self.fields(self.model)
+    if f is None:
+        f = self.fields(self.model)
+
+    if isinstance(f, array.Array):
+        return np.asarray(f)
+    return f
 
 
 def residual(self, m, dobs, f=None):
     return self.dpred(m, f=f) - dobs
 
 
-def block_compute(sim, rows, components):
+def block_compute(sim, rows, components, j_matrix, count):
     block = []
     for row in rows:
         block.append(sim.evaluate_integral(row, components))
@@ -44,7 +49,34 @@ def block_compute(sim, rows, components):
     if sim.store_sensitivities == "forward_only":
         return np.hstack(block)
 
-    return np.vstack(block)
+    values = np.vstack(block)
+    return storage_formatter(values, count, j_matrix)
+
+
+def storage_formatter(
+    rows: np.ndarray,
+    count: int,
+    j_matrix: zarr.Array | None = None,
+):
+    """
+    Format the storage of the sensitivity matrix.
+
+    :param rows: Stacked block of sensitivity rows as a NumPy array.
+    :param count: Current row count offset.
+    :param j_matrix: Zarr array to store the sensitivity matrix on disk, if applicable.
+
+    :return: If j_matrix is provided, returns None after storing the rows; otherwise,
+        returns the stacked rows as a NumPy array.
+    """
+
+    if isinstance(j_matrix, zarr.Array):
+        j_matrix.set_orthogonal_selection(
+            (np.arange(count, count + rows.shape[0]), slice(None)),
+            rows.astype(np.float32),
+        )
+        return None
+
+    return rows
 
 
 def linear_operator(self):
@@ -53,6 +85,20 @@
     if getattr(self, "model_type", None) == "vector":
         n_cells *= 3
 
+    if self.store_sensitivities == "disk":
+
+        if os.path.exists(self.sensitivity_path):
+            return array.from_zarr(self.sensitivity_path)
+
+        Jmatrix = zarr.open(
+            self.sensitivity_path,
+            mode="w",
+            shape=(self.survey.nD, n_cells),
+            chunks=(self.max_chunk_size, n_cells),
+        )
+    else:
+        Jmatrix = None
+
     n_components = len(self.survey.components)
     n_blocks = np.ceil(
         (n_cells * n_components * self.survey.receiver_locations.shape[0] * 8.0 * 1e-6)
@@ -68,83 +114,46 @@
     delayed_compute = delayed(block_compute)
     rows = []
+    count = 0
 
     for block in block_split:
         if client:
-            rows.append(
-                client.submit(
-                    block_compute,
-                    sim,
-                    block,
-                    self.survey.components,
-                    workers=worker,
-                )
+            row = client.submit(
+                block_compute,
+                sim,
+                block,
+                self.survey.components,
+                Jmatrix,
+                count,
+                workers=worker,
             )
+
         else:
-            chunk = delayed_compute(self, block, self.survey.components)
-            rows.append(
-                array.from_delayed(
-                    chunk,
-                    dtype=self.sensitivity_dtype,
-                    shape=(
-                        (len(block) * n_components,)
-                        if forward_only
-                        else (len(block) * n_components, n_cells)
-                    ),
-                )
+            chunk = delayed_compute(self, block, self.survey.components, Jmatrix, count)
+            row = array.from_delayed(
+                chunk,
+                dtype=self.sensitivity_dtype,
+                shape=(
+                    (len(block) * n_components,)
+                    if forward_only
+                    else (len(block) * n_components, n_cells)
+                ),
            )
+        count += block.shape[0]
+        rows.append(row)
 
     if client:
-        if forward_only:
-            return np.hstack(client.gather(rows))
-        return np.vstack(client.gather(rows))
-
-    if forward_only:
-        stack = array.concatenate(rows)
+        kernel = client.gather(rows)
     else:
-        stack = array.vstack(rows)
-        # Chunking options
-        if self.chunk_format == "row":
-            config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"})
-            # Autochunking by rows is faster and more memory efficient for
-            # very large problems sensitivty and forward calculations
-            stack = stack.rechunk({0: "auto", 1: -1})
-        elif self.chunk_format == "equal":
-            # Manual chunks for equal number of blocks along rows and columns.
-            # Optimal for Jvec and Jtvec operations
-            row_chunk, col_chunk = compute_chunk_sizes(
-                *stack.shape, self.max_chunk_size
-            )
-            stack = stack.rechunk((row_chunk, col_chunk))
-        else:
-            # Auto chunking by columns is faster for Inversions
-            config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"})
-            stack = stack.rechunk({0: -1, 1: "auto"})
-
-    if self.store_sensitivities == "disk":
-        sens_name = os.path.join(self.sensitivity_path, "sensitivity.zarr")
-        if os.path.exists(sens_name):
-            kernel = array.from_zarr(sens_name)
-            if np.all(
-                np.r_[
-                    np.any(np.r_[kernel.chunks[0]] == stack.chunks[0]),
-                    np.any(np.r_[kernel.chunks[1]] == stack.chunks[1]),
-                    np.r_[kernel.shape] == np.r_[stack.shape],
-                ]
-            ):
-                # Check that loaded kernel matches supplied data and mesh
-                print("Zarr file detected with same shape and chunksize ... re-loading")
-                return kernel
-
-        print("Writing Zarr file to disk")
         with ProgressBar():
-            print("Saving kernel to zarr: " + sens_name)
-            kernel = array.to_zarr(
-                stack, sens_name, compute=True, return_stored=True, overwrite=True
-            )
+            kernel = compute(rows)[0]
+
+    if self.store_sensitivities == "disk" and os.path.exists(self.sensitivity_path):
+        return array.from_zarr(self.sensitivity_path)
 
-    with ProgressBar():
-        kernel = stack.compute()
-    return kernel
+    if forward_only:
+        return np.hstack(kernel)
+
+    return np.vstack(kernel)
 
 
 def compute_J(self, _, f=None):
@@ -155,6 +164,7 @@
 def Jmatrix(self):
     if getattr(self, "_Jmatrix", None) is None:
         self._Jmatrix = self.compute_J(self.model)
+
     return self._Jmatrix
 