Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 8 additions & 19 deletions simpeg/dask/electromagnetics/frequency_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import scipy.sparse as sp

from dask import array, compute, delayed
from dask.distributed import get_client
from simpeg.dask.utils import get_parallel_blocks
from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary
import zarr
Expand Down Expand Up @@ -104,20 +103,17 @@ def getSourceTerm(self, freq, source=None):

if source is None:

try:
client = get_client()
sim = client.scatter(self, workers=self.worker)
except ValueError:
client = None
sim = self

client, worker = self._get_client_worker()
source_list = self.survey.get_sources_by_frequency(freq)
source_blocks = np.array_split(
np.arange(len(source_list)), self.n_threads(client=client)
np.arange(len(source_list)), self.n_threads(client=client, worker=worker)
)

if client:
source_list = client.scatter(source_list, workers=self.worker)
sim = client.scatter(self, workers=self.worker)
source_list = client.scatter(source_list, workers=worker)
else:
sim = self

block_compute = []

Expand All @@ -127,9 +123,7 @@ def getSourceTerm(self, freq, source=None):

if client:
block_compute.append(
client.submit(
source_eval, sim, source_list, block, workers=self.worker
)
client.submit(source_eval, sim, source_list, block, workers=worker)
)
else:
block_compute.append(source_eval(sim, source_list, block))
Expand Down Expand Up @@ -221,12 +215,7 @@ def compute_J(self, m, f=None):
fields_array = f[:, self._solutionType]
blocks_receiver_derivs = []

try:
client = get_client()
worker = self.worker
except ValueError:
client = None
worker = None
client, worker = self._get_client_worker()

if client:
fields_array = client.scatter(f[:, self._solutionType], workers=worker)
Expand Down
21 changes: 11 additions & 10 deletions simpeg/dask/electromagnetics/static/resistivity/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix

from .....utils import Zero
from dask.distributed import get_client

import dask.array as da
import numpy as np
from scipy import sparse as sp
Expand Down Expand Up @@ -163,22 +163,23 @@ def getSourceTerm(self):
source_list = self.survey.source_list

indices = np.arange(len(source_list))
try:

client = get_client()
sim = client.scatter(self, workers=self.worker)
future_list = client.scatter(source_list, workers=self.worker)
indices = np.array_split(indices, self.n_threads(client=client))
client, worker = self._get_client_worker()

if client:
sim = client.scatter(self, workers=worker)
future_list = client.scatter(source_list, workers=worker)
indices = np.array_split(
indices, self.n_threads(client=client, worker=worker)
)
blocks = []
for ind in indices:
blocks.append(
client.submit(
source_eval, sim, future_list, ind, workers=self.worker
)
client.submit(source_eval, sim, future_list, ind, workers=worker)
)

blocks = sp.hstack(client.gather(blocks))
except ValueError:
else:
blocks = source_eval(self, source_list, indices)

self._q = blocks
Expand Down
51 changes: 24 additions & 27 deletions simpeg/dask/electromagnetics/time_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import scipy.sparse as sp
from dask import array, delayed
from dask.distributed import get_client


from time import time
from simpeg.dask.utils import get_parallel_blocks
Expand Down Expand Up @@ -89,10 +89,7 @@ def compute_J(self, m, f=None):
if f is None:
f, Ainv = self.fields(m=m, return_Ainv=True)

try:
client = get_client()
except ValueError:
client = None
client, worker = self._get_client_worker()

ftype = self._fieldType + "Solution"
sens_name = self.sensitivity_path[:-5]
Expand All @@ -118,22 +115,22 @@ def compute_J(self, m, f=None):
blocks = get_parallel_blocks(
self.survey.source_list,
compute_row_size,
thread_count=self.n_threads(client=client),
thread_count=self.n_threads(client=client, worker=worker),
)
fields_array = f[:, ftype, :]

if len(self.survey.source_list) == 1:
fields_array = fields_array[:, np.newaxis, :]

times_field_derivs, Jmatrix = compute_field_derivs(
self, f, blocks, Jmatrix, fields_array.shape, client
self, f, blocks, Jmatrix, fields_array.shape
)

ATinv_df_duT_v = [[] for _ in blocks]

if client:
fields_array = client.scatter(fields_array, workers=self.worker)
sim = client.scatter(self, workers=self.worker)
fields_array = client.scatter(fields_array, workers=worker)
sim = client.scatter(self, workers=worker)
else:
delayed_compute_rows = delayed(compute_rows)
sim = self
Expand Down Expand Up @@ -161,7 +158,7 @@ def compute_J(self, m, f=None):
)

if client:
field_derivatives = client.scatter(ATinv_df_duT_v, workers=self.worker)
field_derivatives = client.scatter(ATinv_df_duT_v, workers=worker)
else:
field_derivatives = ATinv_df_duT_v

Expand All @@ -181,7 +178,7 @@ def compute_J(self, m, f=None):
field_derivatives,
fields_array,
time_mask,
workers=self.worker,
workers=worker,
)
)
else:
Expand Down Expand Up @@ -267,17 +264,19 @@ def evaluate_receivers(block, mesh, time_mesh, fields, fields_array):
return np.hstack(data)


def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client):
def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape):
"""
Compute the derivative of the fields
"""
delayed_chunks = []

client, worker = self._get_client_worker()

if client:
mesh = client.scatter(self.mesh, workers=self.worker)
time_mesh = client.scatter(self.time_mesh, workers=self.worker)
fields = client.scatter(fields, workers=self.worker)
source_list = client.scatter(self.survey.source_list, workers=self.worker)
mesh = client.scatter(self.mesh, workers=worker)
time_mesh = client.scatter(self.time_mesh, workers=worker)
fields = client.scatter(fields, workers=worker)
source_list = client.scatter(self.survey.source_list, workers=worker)
else:
mesh = self.mesh
time_mesh = self.time_mesh
Expand All @@ -300,7 +299,7 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client):
time_mesh,
fields,
self.model.size,
workers=self.worker,
workers=worker,
)
)
else:
Expand Down Expand Up @@ -537,24 +536,22 @@ def dpred(self, m=None, f=None):
"simulation.survey = survey"
)

try:
client = get_client()
except ValueError:
client = None
client, worker = self._get_client_worker()

if f is None:
f = self.fields(m)

delayed_chunks = []

source_block = np.array_split(
np.arange(len(self.survey.source_list)), self.n_threads(client=client)
np.arange(len(self.survey.source_list)),
self.n_threads(client=client, worker=worker),
)
if client:
mesh = client.scatter(self.mesh, workers=self.worker)
time_mesh = client.scatter(self.time_mesh, workers=self.worker)
fields = client.scatter(f, workers=self.worker)
source_list = client.scatter(self.survey.source_list, workers=self.worker)
mesh = client.scatter(self.mesh, workers=worker)
time_mesh = client.scatter(self.time_mesh, workers=worker)
fields = client.scatter(f, workers=worker)
source_list = client.scatter(self.survey.source_list, workers=worker)
else:
mesh = self.mesh
time_mesh = self.time_mesh
Expand All @@ -575,7 +572,7 @@ def dpred(self, m=None, f=None):
mesh,
time_mesh,
fields,
workers=self.worker,
workers=worker,
)
)
else:
Expand Down
26 changes: 22 additions & 4 deletions simpeg/dask/objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Callable
import numpy as np

from dask.distributed import Client, Future
from dask.distributed import Client, Future, get_client
from ..data_misfit import L2DataMisfit

from simpeg.utils import validate_list_of_types
Expand Down Expand Up @@ -145,9 +145,7 @@ def _validate_type_or_future_of_type(
):

if workers is None:
workers = [
(worker.worker_address,) for worker in client.cluster.workers.values()
]
workers = [(worker,) for worker in client.nthreads()]

objects = validate_list_of_types(
property_name, objects, obj_type, ensure_unique=True
Expand Down Expand Up @@ -262,6 +260,9 @@ def client(self):

@client.setter
def client(self, client):
if client is None:
client = get_client()

if not isinstance(client, Client):
raise TypeError("client must be a dask.distributed.Client")

Expand All @@ -279,6 +280,23 @@ def workers(self, workers):
if not isinstance(workers, list | type(None)):
raise TypeError("workers must be a list of strings")

available_workers = [(worker,) for worker in self.client.nthreads()]

if workers is None:
workers = available_workers

if not isinstance(workers, list) or not all(
isinstance(w, tuple) for w in workers
):
raise TypeError("Workers must be a list of tuple[str].")

invalid_workers = [w for w in workers if w not in available_workers]
if invalid_workers:
raise ValueError(
f"The following workers are not available: {invalid_workers}. "
f"Available workers are: {available_workers}."
)

self._workers = workers

def deriv(self, m, f=None):
Expand Down
11 changes: 4 additions & 7 deletions simpeg/dask/potential_fields/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from ...potential_fields.base import BasePFSimulation as Sim
from dask.distributed import get_client

import os
from dask import delayed, array, config
from dask.diagnostics import ProgressBar
Expand Down Expand Up @@ -60,13 +60,10 @@ def linear_operator(self):
)
block_split = np.array_split(self.survey.receiver_locations, n_blocks)

try:
client = get_client()
except ValueError:
client = None
client, worker = self._get_client_worker()

if client:
sim = client.scatter(self, workers=self.worker)
sim = client.scatter(self, workers=worker)
else:
delayed_compute = delayed(block_compute)

Expand All @@ -79,7 +76,7 @@ def linear_operator(self):
sim,
block,
self.survey.components,
workers=self.worker,
workers=worker,
)
)
else:
Expand Down
Loading