Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions simpeg_drivers/joint/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def initialize(self):
self.models.active_cells = global_actives
for driver, wire in zip(self.drivers, self.wires, strict=True):
logger.info("Initializing driver %s", driver.params.name)
# Create a projection from global mesh to driver specific mesh
projection = TileMap(
self.inversion_mesh.mesh,
global_actives,
Expand All @@ -140,6 +141,8 @@ def initialize(self):
tile_map = projection * wire
driver.params.active_model = None
driver.models.active_cells = projection.local_active

# Keep a copy on the top combo/future for saving directives and model creation
driver.data_misfit.model_map = tile_map

multipliers = []
Expand Down
28 changes: 20 additions & 8 deletions simpeg_drivers/joint/joint_surveys/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,25 @@ def __init__(self, params: JointSurveysOptions):

def validate_create_models(self):
"""Check if all models were provided, otherwise use the first driver models."""
# Create projection for first driver to global mesh
mapping = maps.TileMap(
self.inversion_mesh.mesh,
self.models.active_cells,
self.drivers[0].inversion_mesh.mesh,
enforce_active=False,
)
projection = mapping.deriv(np.ones(self.models.n_active)).T
norm = np.array(np.sum(projection, axis=1)).flatten()

for model_type in self.models.model_types:
model = getattr(self.models, model_type)
if model is not None or getattr(self.drivers[0].models, model_type) is None:
continue

model_local_values = getattr(self.drivers[0].models, model_type)
projection = (
self.drivers[0]
.data_misfit.model_map.deriv(np.ones(self.models.n_active))
.T
)
norm = np.array(np.sum(projection, axis=1)).flatten()
model = (projection * model_local_values) / (norm + 1e-8)
model = (
projection * model_local_values[: self.drivers[0].models.n_active]
) / (norm + 1e-8)

if self.drivers[0].models.is_sigma and model_type in [
"starting_model",
Expand All @@ -70,11 +76,17 @@ def validate_create_models(self):

getattr(self.models, f"_{model_type}").model = model

# For MVI, set is_vector from first driver
self.models.is_vector = self.drivers[0].models.is_vector

@property
def wires(self):
"""Model projections"""
if self._wires is None:
wires = [maps.IdentityMap(nP=self.models.n_active) for _ in self.drivers]
wires = [
maps.IdentityMap(nP=self.models.n_active * driver.n_blocks)
for driver in self.drivers
]
self._wires = wires

return self._wires
Expand Down
13 changes: 0 additions & 13 deletions simpeg_drivers/joint/joint_surveys/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,6 @@ class JointSurveysOptions(BaseJointOptions):

models: JointSurveysModelOptions

@field_validator("group_a", "group_b", "group_c")
@classmethod
def no_mvi_groups(cls, val):
if val is None:
return val

if "magnetic vector" in val.options.get("inversion_type", ""):
raise ValueError(
f"Joint inversion doesn't currently support MVI data as passed in "
f"the group: {val.name}."
)
return val

@model_validator(mode="after")
def all_groups_same_physical_property(self):
physical_properties = [k.options["physical_property"] for k in self.groups]
Expand Down
8 changes: 4 additions & 4 deletions simpeg_drivers/utils/synthetics/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
from geoh5py import Workspace
from geoh5py.data import FloatData
from geoh5py.data import BooleanData, FloatData
from geoh5py.objects import DrapeModel, ObjectBase, Octree, Surface

from simpeg_drivers.utils.synthetics.meshes.factory import get_mesh
Expand Down Expand Up @@ -86,16 +86,16 @@ def mesh(self):

@property
def active(self):
entity = self.geoh5.get_entity(self.options.active.name)[0]
assert isinstance(entity, FloatData | type(None))
entity = self.mesh.get_entity(self.options.active.name)[0]
assert isinstance(entity, BooleanData | type(None))
if entity is None:
entity = get_active(self.mesh, self.topography)
self._active = entity
return self._active

@property
def model(self):
entity = self.geoh5.get_entity(self.options.model.name)[0]
entity = self.mesh.get_entity(self.options.model.name)[0]
assert isinstance(entity, FloatData | type(None))
if entity is None:
assert self.options is not None
Expand Down
87 changes: 86 additions & 1 deletion tests/run_tests/driver_joint_surveys_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
from geoh5py.objects import Octree
from geoh5py.workspace import Workspace
from simpeg.directives import SavePropertyGroup
from simpeg.directives import SaveModelGeoH5, SavePropertyGroup

from simpeg_drivers.electricals.direct_current.three_dimensions.driver import (
DC3DInversionDriver,
Expand All @@ -31,6 +31,8 @@
GravityInversionOptions,
)
from simpeg_drivers.potential_fields.gravity.driver import GravityInversionDriver
from simpeg_drivers.potential_fields.magnetic_vector.driver import MVIInversionDriver
from simpeg_drivers.potential_fields.magnetic_vector.options import MVIInversionOptions
from simpeg_drivers.utils.synthetics.driver import (
SyntheticsComponents,
)
Expand Down Expand Up @@ -195,6 +197,89 @@ def test_joint_surveys_inv_run(
check_target(output, target_run)


def test_joint_surveys_mvi_run(tmp_path, anomaly=0.05):
drivers = []

with Workspace.create(tmp_path / f"{__name__}.geoh5") as geoh5:
for ii in range(1, 3):
opts = SyntheticsComponentsOptions(
method="magnetic_vector",
survey=SurveyOptions(
n_stations=3**ii,
n_lines=3**ii,
drape=5.0,
name=f"Survey Driver[{ii}]",
),
mesh=MeshOptions(refinement=(2**ii, 2, 2), name=f"Mesh Driver[{ii}]"),
model=ModelOptions(anomaly=anomaly),
)
components = SyntheticsComponents(geoh5, options=opts)
survey = components.survey
obs, uncrt = survey.add_data(
{
"TMI": {"values": np.random.randn(survey.n_vertices)},
"Uncertainty": {"values": np.ones(survey.n_vertices) * 1e-3},
}
)

# Add an inclination model on the first driver only to test handling of
# models from the main driver
if ii == 1:
model = components.model.values
model[model > 0] = 45.0
model[model <= 0] = 90.0
inc_mod = components.mesh.add_data(
{"Inclination Model": {"values": model}}
)
else:
inc_mod = None

params = MVIInversionOptions.build(
geoh5=geoh5,
mesh=components.mesh,
topography_object=components.topography,
tmi_channel=obs,
tmi_uncertainty=uncrt,
inducing_field_strength=45000,
inducing_field_inclination=90.0,
inducing_field_declination=0.0,
data_object=survey,
starting_model=components.model,
starting_inclination=inc_mod,
reference_model=0.0,
)
drivers.append(MVIInversionDriver(params))

# Run the inverse
joint_params = JointSurveysOptions.build(
geoh5=geoh5,
active_cells=ActiveCellsOptions(topography_object=components.topography),
group_a=drivers[0].out_group,
group_b=drivers[1].out_group,
starting_model=0.01,
# Default to Conductivity (S/m)
)

driver = JointSurveyDriver(joint_params)
assert np.isclose(driver.models.reference_model[0], 0) # Took it from driver_A
assert driver.models.starting_model.shape == (driver.models.n_active * 3,)
assert np.isclose(
driver.models.starting_model.max(), 0.01 * np.cos(np.deg2rad(45.0))
)

# Test saving the starting models on each mesh (open file to validate)
assert (
len(
[
directive.write(0, driver.models.starting_model)
for directive in driver.directives.directive_list
if isinstance(directive, SaveModelGeoH5)
]
)
== 3
)


def test_joint_surveys_conductivity_run(
tmp_path,
):
Expand Down