From 06d0a14b803dab3125aaf957e321d6946b68a2f6 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 24 Sep 2025 16:01:39 -0400 Subject: [PATCH 1/3] REwork fetching client and worker on simulation --- .../frequency_domain/simulation.py | 27 ++-- .../static/resistivity/simulation.py | 21 +-- .../time_domain/simulation.py | 51 ++++--- simpeg/dask/potential_fields/base.py | 11 +- simpeg/dask/simulation.py | 135 +++++++++++------- 5 files changed, 127 insertions(+), 118 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 09a50e61a2..32deed7f7c 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -7,7 +7,6 @@ import scipy.sparse as sp from dask import array, compute, delayed -from dask.distributed import get_client from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr @@ -104,20 +103,17 @@ def getSourceTerm(self, freq, source=None): if source is None: - try: - client = get_client() - sim = client.scatter(self, workers=self.worker) - except ValueError: - client = None - sim = self - + client, worker = self._get_client_worker() source_list = self.survey.get_sources_by_frequency(freq) source_blocks = np.array_split( - np.arange(len(source_list)), self.n_threads(client=client) + np.arange(len(source_list)), self.n_threads(client=client, worker=worker) ) if client: - source_list = client.scatter(source_list, workers=self.worker) + sim = client.scatter(self, workers=self.worker) + source_list = client.scatter(source_list, workers=worker) + else: + sim = self block_compute = [] @@ -127,9 +123,7 @@ def getSourceTerm(self, freq, source=None): if client: block_compute.append( - client.submit( - source_eval, sim, source_list, block, workers=self.worker - ) + client.submit(source_eval, sim, source_list, block, workers=worker) ) else: block_compute.append(source_eval(sim, source_list, block)) @@ -221,12 +215,7 @@ def compute_J(self, m, f=None): fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] - try: - client = get_client() - worker = self.worker - except ValueError: - client = None - worker = None + client, worker = self._get_client_worker() if client: fields_array = client.scatter(f[:, self._solutionType], workers=worker) diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 3605218046..1a49889a1b 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -3,7 +3,7 @@ from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix from .....utils import Zero -from dask.distributed import get_client + import dask.array as da import numpy as np from scipy import sparse as sp @@ -163,22 +163,23 @@ def getSourceTerm(self): source_list = self.survey.source_list indices = np.arange(len(source_list)) - try: - client = get_client() - sim = client.scatter(self, workers=self.worker) - future_list = client.scatter(source_list, workers=self.worker) - indices = np.array_split(indices, self.n_threads(client=client)) + client, worker = self._get_client_worker() + + if client: + sim = client.scatter(self, workers=worker) + future_list = client.scatter(source_list, workers=worker) + indices = np.array_split( + indices, self.n_threads(client=client, worker=worker) + ) blocks = [] for ind in indices: blocks.append( - client.submit( - source_eval, sim, future_list, ind, workers=self.worker - ) + client.submit(source_eval, sim, future_list, ind, workers=worker) ) blocks = sp.hstack(client.gather(blocks)) - except ValueError: + else: blocks = source_eval(self, source_list, indices) self._q = blocks diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 19315ff3b0..e1ff8d1e7c 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -9,7 +9,7 @@ import numpy as np import scipy.sparse as sp from dask import array, delayed -from dask.distributed import get_client + from time import time from simpeg.dask.utils import get_parallel_blocks @@ -89,10 +89,7 @@ def compute_J(self, m, f=None): if f is None: f, Ainv = self.fields(m=m, return_Ainv=True) - try: - client = get_client() - except ValueError: - client = None + client, worker = self._get_client_worker() ftype = self._fieldType + "Solution" sens_name = self.sensitivity_path[:-5] @@ -118,7 +115,7 @@ def compute_J(self, m, f=None): blocks = get_parallel_blocks( self.survey.source_list, compute_row_size, - thread_count=self.n_threads(client=client), + thread_count=self.n_threads(client=client, worker=worker), ) fields_array = f[:, ftype, :] @@ -126,14 +123,14 @@ def compute_J(self, m, f=None): fields_array = fields_array[:, np.newaxis, :] times_field_derivs, Jmatrix = compute_field_derivs( - self, f, blocks, Jmatrix, fields_array.shape, client + self, f, blocks, Jmatrix, fields_array.shape ) ATinv_df_duT_v = [[] for _ in blocks] if client: - fields_array = client.scatter(fields_array, workers=self.worker) - sim = client.scatter(self, workers=self.worker) + fields_array = client.scatter(fields_array, workers=worker) + sim = client.scatter(self, workers=worker) else: delayed_compute_rows = delayed(compute_rows) sim = self @@ -161,7 +158,7 @@ def compute_J(self, m, f=None): ) if client: - field_derivatives = client.scatter(ATinv_df_duT_v, workers=self.worker) + field_derivatives = client.scatter(ATinv_df_duT_v, workers=worker) else: field_derivatives = ATinv_df_duT_v @@ -181,7 +178,7 @@ def compute_J(self, m, f=None): field_derivatives, fields_array, time_mask, - workers=self.worker, + workers=worker, ) ) else: @@ -267,17 +264,19 @@ def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): return np.hstack(data) -def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client): +def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape): """ Compute the derivative of the fields """ delayed_chunks = [] + client, worker = self._get_client_worker() + if client: - mesh = client.scatter(self.mesh, workers=self.worker) - time_mesh = client.scatter(self.time_mesh, workers=self.worker) - fields = client.scatter(fields, workers=self.worker) - source_list = client.scatter(self.survey.source_list, workers=self.worker) + mesh = client.scatter(self.mesh, workers=worker) + time_mesh = client.scatter(self.time_mesh, workers=worker) + fields = client.scatter(fields, workers=worker) + source_list = client.scatter(self.survey.source_list, workers=worker) else: mesh = self.mesh time_mesh = self.time_mesh @@ -300,7 +299,7 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client): time_mesh, fields, self.model.size, - workers=self.worker, + workers=worker, ) ) else: @@ -537,10 +536,7 @@ def dpred(self, m=None, f=None): "simulation.survey = survey" ) - try: - client = get_client() - except ValueError: - client = None + client, worker = self._get_client_worker() if f is None: f = self.fields(m) @@ -548,13 +544,14 @@ def dpred(self, m=None, f=None): delayed_chunks = [] source_block = np.array_split( - np.arange(len(self.survey.source_list)), self.n_threads(client=client) + np.arange(len(self.survey.source_list)), + self.n_threads(client=client, worker=worker), ) if client: - mesh = client.scatter(self.mesh, workers=self.worker) - time_mesh = client.scatter(self.time_mesh, workers=self.worker) - fields = client.scatter(f, workers=self.worker) - source_list = client.scatter(self.survey.source_list, workers=self.worker) + mesh = client.scatter(self.mesh, workers=worker) + time_mesh = client.scatter(self.time_mesh, workers=worker) + fields = client.scatter(f, workers=worker) + source_list = client.scatter(self.survey.source_list, workers=worker) else: mesh = self.mesh time_mesh = self.time_mesh @@ -575,7 +572,7 @@ def dpred(self, m=None, f=None): mesh, time_mesh, fields, - workers=self.worker, + workers=worker, ) ) else: diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 1a9f2592bc..2229980763 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -1,7 +1,7 @@ import numpy as np from ...potential_fields.base import BasePFSimulation as Sim -from dask.distributed import get_client + import os from dask import delayed, array, config from dask.diagnostics import ProgressBar @@ -60,13 +60,10 @@ def linear_operator(self): ) block_split = np.array_split(self.survey.receiver_locations, n_blocks) - try: - client = get_client() - except ValueError: - client = None + client, worker = self._get_client_worker() if client: - sim = client.scatter(self, workers=self.worker) + sim = client.scatter(self, workers=worker) else: delayed_compute = delayed(block_compute) @@ -79,7 +76,7 @@ def linear_operator(self): sim, block, self.survey.components, - workers=self.worker, + workers=worker, ) ) else: diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 85ccae00f4..3ad3aaf759 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -1,6 +1,7 @@ from ..simulation import BaseSimulation as Sim from dask import array +from dask.distributed import get_client import numpy as np from multiprocessing import cpu_count @@ -63,28 +64,28 @@ def getJtJdiag(self, m, W=None, f=None): Sim.getJtJdiag = getJtJdiag -def __init__( - self, - survey=None, - sensitivity_path="./sensitivity/", - counter=None, - verbose=False, - chunk_format="row", - max_ram=16, - max_chunk_size=128, - **kwargs, -): - _old_init( - self, - survey=survey, - sensitivity_path=sensitivity_path, - counter=counter, - verbose=verbose, - **kwargs, - ) - self.chunk_format = chunk_format - self.max_ram = max_ram - self.max_chunk_size = max_chunk_size +# def __init__( +# self, +# survey=None, +# sensitivity_path="./sensitivity/", +# counter=None, +# verbose=False, +# chunk_format="row", +# max_ram=16, +# max_chunk_size=128, +# **kwargs, +# ): +# _old_init( +# self, +# survey=survey, +# sensitivity_path=sensitivity_path, +# counter=counter, +# verbose=verbose, +# **kwargs, +# ) +# self.chunk_format = chunk_format +# self.max_ram = max_ram +# self.max_chunk_size = max_chunk_size def Jvec(self, m, v, **_): @@ -132,13 +133,17 @@ def Jmatrix(self): Sim.Jmatrix = Jmatrix -def n_threads(self, client=None): +def n_threads(self, client=None, worker=None): """ Number of threads used by Dask """ if getattr(self, "_n_threads", None) is None: if client: - self._n_threads = client.nthreads()[self.worker[0]] + n_threads = client.nthreads() + if not worker: + self._n_threads = list(n_threads.values())[0] + else: + self._n_threads = n_threads[self.worker[0]] else: self._n_threads = cpu_count() @@ -148,38 +153,58 @@ def n_threads(self, client=None): Sim.n_threads = n_threads -# TODO: Make dpred parallel -def dpred(self, m=None, f=None): - r"""Predicted data for the model provided. - - Parameters - ---------- - m : (n_param,) numpy.ndarray - The model parameters. - f : simpeg.fields.Fields, optional - If provided, will be used to compute the predicted data - without recalculating the fields. - - Returns - ------- - (n_data, ) numpy.ndarray - The predicted data vector. +def _get_client_worker(self) -> tuple: """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) + Get the Dask client and worker if they exist. + """ + try: + client = get_client() + except ValueError: + client = None + + try: + worker = self.worker + except AttributeError: + worker = None - if f is None: - if m is None: - m = self.model + return client, worker - f = self.fields(m) - data = Data(self.survey) - for src in self.survey.source_list: - for rx in src.receiver_list: - data[src, rx] = rx.eval(src, self.mesh, f) - return mkvc(data) +Sim._get_client_worker = _get_client_worker + + +# TODO: Make dpred parallel +# def dpred(self, m=None, f=None): +# r"""Predicted data for the model provided. +# +# Parameters +# ---------- +# m : (n_param,) numpy.ndarray +# The model parameters. +# f : simpeg.fields.Fields, optional +# If provided, will be used to compute the predicted data +# without recalculating the fields. +# +# Returns +# ------- +# (n_data, ) numpy.ndarray +# The predicted data vector. +# """ +# if self.survey is None: +# raise AttributeError( +# "The survey has not yet been set and is required to compute " +# "data. Please set the survey for the simulation: " +# "simulation.survey = survey" +# ) +# +# if f is None: +# if m is None: +# m = self.model +# +# f = self.fields(m) +# +# data = Data(self.survey) +# for src in self.survey.source_list: +# for rx in src.receiver_list: +# data[src, rx] = rx.eval(src, self.mesh, f) +# return mkvc(data) From e7731b8e82e1db7a0fb441df0525596652e4e4ed Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 25 Sep 2025 11:25:19 -0400 Subject: [PATCH 2/3] Improve checks on workers. --- simpeg/dask/objective_function.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index f2a6497a51..a2898e36d7 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -7,7 +7,7 @@ from typing import Callable import numpy as np -from dask.distributed import Client, Future +from dask.distributed import Client, Future, get_client from ..data_misfit import L2DataMisfit from simpeg.utils import validate_list_of_types @@ -145,9 +145,7 @@ def _validate_type_or_future_of_type( ): if workers is None: - workers = [ - (worker.worker_address,) for worker in client.cluster.workers.values() - ] + workers = [(worker,) for worker in client.nthreads()] objects = validate_list_of_types( property_name, objects, obj_type, ensure_unique=True @@ -262,6 +260,9 @@ def client(self): @client.setter def client(self, client): + if client is None: + client = get_client() + if not isinstance(client, Client): raise TypeError("client must be a dask.distributed.Client") @@ -279,6 +280,23 @@ def workers(self, workers): if not isinstance(workers, list | type(None)): raise TypeError("workers must be a list of strings") + available_workers = [(worker,) for worker in self.client.nthreads()] + + if workers is None: + workers = available_workers + + if not isinstance(workers, list) or not all( + isinstance(w, tuple) for w in workers + ): + raise TypeError("Workers must be a list of tuple[str].") + + invalid_workers = [w for w in workers if w not in available_workers] + if invalid_workers: + raise ValueError( + f"The following workers are not available: {invalid_workers}. " + f"Available workers are: {available_workers}." + ) + self._workers = workers def deriv(self, m, f=None): From 4397967d3ca28af84b180557103a43ab55ef5e7b Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 2 Oct 2025 15:23:32 -0400 Subject: [PATCH 3/3] Clean up --- simpeg/dask/simulation.py | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 3ad3aaf759..60bab836c0 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -171,40 +171,3 @@ def _get_client_worker(self) -> tuple: Sim._get_client_worker = _get_client_worker - - -# TODO: Make dpred parallel -# def dpred(self, m=None, f=None): -# r"""Predicted data for the model provided. -# -# Parameters -# ---------- -# m : (n_param,) numpy.ndarray -# The model parameters. -# f : simpeg.fields.Fields, optional -# If provided, will be used to compute the predicted data -# without recalculating the fields. -# -# Returns -# ------- -# (n_data, ) numpy.ndarray -# The predicted data vector. -# """ -# if self.survey is None: -# raise AttributeError( -# "The survey has not yet been set and is required to compute " -# "data. Please set the survey for the simulation: " -# "simulation.survey = survey" -# ) -# -# if f is None: -# if m is None: -# m = self.model -# -# f = self.fields(m) -# -# data = Data(self.survey) -# for src in self.survey.source_list: -# for rx in src.receiver_list: -# data[src, rx] = rx.eval(src, self.mesh, f) -# return mkvc(data)