Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion simpeg_drivers/components/factories/entity_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import TYPE_CHECKING

import numpy as np
from geoapps_utils.utils.importing import GeoAppsError
from geoh5py.objects import (
CurrentElectrode,
Curve,
Expand Down Expand Up @@ -128,7 +129,7 @@ def _validate_large_loop_cells(
Validate that the transmitter loops are counter-clockwise sorted and closed.
"""
if transmitter.receivers.tx_id_property is None:
raise ValueError(
raise GeoAppsError(
"Transmitter ID property required for LargeLoopGroundTEMReceivers"
)

Expand Down
6 changes: 3 additions & 3 deletions simpeg_drivers/components/factories/survey_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import numpy as np
import simpeg.electromagnetics.time_domain as tdem
from geoh5py.objects.surveys.electromagnetics.airborne_fem import AirborneFEMReceivers
from geoapps_utils.utils.importing import GeoAppsError
from geoh5py.objects.surveys.electromagnetics.ground_tem import (
LargeLoopGroundTEMTransmitters,
)
Expand Down Expand Up @@ -207,14 +207,14 @@ def _tdem_arguments(self, data=None):
if receivers.channels[-1] > (
receivers.waveform[:, 0].max() - receivers.timing_mark
):
raise ValueError(
raise GeoAppsError(
f"The latest time channel {receivers.channels[-1]} exceeds "
f"the waveform discretization. Revise waveform."
)

if isinstance(transmitters, LargeLoopGroundTEMTransmitters):
if receivers.tx_id_property is None:
raise ValueError(
raise GeoAppsError(
"Transmitter ID property required for LargeLoopGroundTEMReceivers"
)

Expand Down
32 changes: 24 additions & 8 deletions simpeg_drivers/components/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@
from typing import TYPE_CHECKING

import numpy as np
from geoapps_utils.base import Driver
from geoapps_utils.utils.importing import GeoAppsError
from geoapps_utils.utils.numerical import weighted_average
from geoapps_utils.utils.transformations import rotate_xyz
from geoh5py.data import Data, FloatData, NumericData
from geoh5py.data.data_type import GeometricDataValueMapType
from geoh5py.objects import ObjectBase
from simpeg.utils.mat_utils import (
cartesian2amplitude_dip_azimuth,
dip_azimuth2cartesian,
mkvc,
)
Expand Down Expand Up @@ -76,15 +74,25 @@ def __init__(self, driver: InversionDriver):
self.is_sigma = self.driver.params.physical_property == "conductivity"
self.is_vector = self.driver.params.inversion_type == "magnetic vector"

self._starting_model = InversionModel(driver, "starting_model")
self._starting_model = InversionModel(
driver, "starting_model", is_sigma=self.is_sigma
)
self._starting_inclination = InversionModel(driver, "starting_inclination")
self._starting_declination = InversionModel(driver, "starting_declination")
self._reference_model = InversionModel(driver, "reference_model")
self._reference_model = InversionModel(
driver, "reference_model", is_sigma=self.is_sigma
)
self._reference_inclination = InversionModel(driver, "reference_inclination")
self._reference_declination = InversionModel(driver, "reference_declination")
self._lower_bound = InversionModel(driver, "lower_bound")
self._upper_bound = InversionModel(driver, "upper_bound")
self._conductivity_model = InversionModel(driver, "conductivity_model")
self._lower_bound = InversionModel(
driver, "lower_bound", is_sigma=self.is_sigma
)
self._upper_bound = InversionModel(
driver, "upper_bound", is_sigma=self.is_sigma
)
self._conductivity_model = InversionModel(
driver, "conductivity_model", is_sigma=self.is_sigma
)
self._alpha_s = InversionModel(driver, "alpha_s")
self._length_scale_x = InversionModel(driver, "length_scale_x")
self._length_scale_y = InversionModel(driver, "length_scale_y")
Expand Down Expand Up @@ -498,6 +506,7 @@ def __init__(
driver: InversionDriver,
model_type: str,
trim_active_cells: bool = True,
is_sigma: bool = False,
):
"""
:param driver: InversionDriver object.
Expand All @@ -507,6 +516,7 @@ def __init__(
"""
self.driver = driver
self.model_type = model_type
self.is_sigma = is_sigma
self.model: np.ndarray | None = None
self.trim_active_cells = trim_active_cells
self._initialize()
Expand All @@ -522,6 +532,12 @@ def _initialize(self):
model = self._get(self.model_type)

if model is not None:
if self.is_sigma and np.any(model <= 0):
raise GeoAppsError(
f"All values in {self.model_type} must be positive when "
"inversion is in log-conductivity space."
)

self.model = mkvc(model)

if isinstance(self._fetch_reference(self.model_type), Data):
Expand Down
22 changes: 15 additions & 7 deletions simpeg_drivers/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@
mlogger.setLevel(logging.WARNING)


logger = logging.getLogger("simpeg-drivers")


class InversionDriver(Driver):
_options_class = BaseForwardOptions | BaseInversionOptions
_inversion_type: str | None = None
Expand Down Expand Up @@ -646,14 +649,19 @@ def start(
else:
ifile = InputFile.read_ui_json(filepath, **kwargs)

if driver_class is None:
driver = cls.from_input_file(ifile)
else:
with ifile.data["geoh5"].open(mode="r+"):
params = driver_class._options_class.build(ifile)
driver = driver_class(params)
try:
if driver_class is None:
driver = cls.from_input_file(ifile)
else:
with ifile.data["geoh5"].open(mode="r+"):
params = driver_class._options_class.build(ifile)
driver = driver_class(params)

driver.run()

driver.run()
except GeoAppsError as error:
logger.warning("\n\nApplicationError: %s\n\n", error)
sys.exit(1)

return driver

Expand Down
3 changes: 2 additions & 1 deletion simpeg_drivers/electromagnetics/frequency_domain/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import ClassVar, TypeAlias

from geoapps_utils.utils.importing import GeoAppsError
from geoh5py.groups import PropertyGroup
from geoh5py.objects import (
AirborneFEMReceivers,
Expand Down Expand Up @@ -56,7 +57,7 @@ def tx_offsets(self):

except KeyError as exception:
msg = "Metadata must contain 'Frequency configurations' dictionary with 'Offset' data."
raise KeyError(msg) from exception
raise GeoAppsError(msg) from exception

return tx_offsets

Expand Down
3 changes: 2 additions & 1 deletion simpeg_drivers/joint/joint_surveys/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pathlib import Path
from typing import ClassVar

from geoapps_utils.utils.importing import GeoAppsError
from pydantic import model_validator

from simpeg_drivers import assets_path
Expand All @@ -38,7 +39,7 @@ class JointSurveysOptions(BaseJointOptions):
def all_groups_same_physical_property(self):
physical_properties = [k.options["physical_property"] for k in self.groups]
if len(list(set(physical_properties))) > 1:
raise ValueError(
raise GeoAppsError(
"All physical properties must be the same. "
f"Provided SimPEG groups for {physical_properties}."
)
Expand Down
3 changes: 2 additions & 1 deletion simpeg_drivers/line_sweep/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
from geoapps_utils.param_sweeps.driver import SweepDriver, SweepParams
from geoapps_utils.param_sweeps.generate import generate
from geoapps_utils.utils.importing import GeoAppsError
from geoh5py.data import FilenameData
from geoh5py.groups import SimPEGGroup
from geoh5py.objects import DrapeModel, PotentialElectrode
Expand Down Expand Up @@ -147,7 +148,7 @@ def collect_results(self):
)

if not line_data:
raise ValueError(f"Line {line} not found in {survey.name}")
raise GeoAppsError(f"Line {line} not found in {survey.name}")

line_indices = line_ids == line
data = self.collect_line_data(survey, line_indices, data)
Expand Down
12 changes: 6 additions & 6 deletions simpeg_drivers/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
from geoapps_utils.base import Options
from geoapps_utils.utils.importing import GeoAppsError
from geoh5py.data import (
BooleanData,
DataAssociationEnum,
Expand All @@ -28,9 +29,8 @@
from geoh5py.groups import PropertyGroup, SimPEGGroup, UIJsonGroup
from geoh5py.objects import DrapeModel, Grid2D, Octree, Points
from geoh5py.objects.surveys.electromagnetics.base import BaseEMSurvey
from geoh5py.shared.utils import fetch_active_workspace
from geoh5py.ui_json import InputFile
from geoh5py.ui_json.templates import data_parameter
from geoh5py.ui_json.utils import fetch_active_workspace
from pydantic import (
AliasChoices,
BaseModel,
Expand Down Expand Up @@ -89,7 +89,7 @@ class ActiveCellsOptions(BaseModel):
@classmethod
def at_least_one(cls, data):
if all(v is None for v in data.values()):
raise ValueError("Must provide either topography or active model.")
raise GeoAppsError("Must provide either topography or active model.")
return data

@model_serializer(mode="wrap")
Expand Down Expand Up @@ -208,7 +208,7 @@ def _component_name(self, component: str) -> str:
@classmethod
def mesh_cannot_be_rotated(cls, value: Octree):
if isinstance(value, Octree) and value.rotation not in [0.0, None]:
raise ValueError(
raise GeoAppsError(
"Rotated meshes are not supported. Please use a mesh with an angle of 0.0."
)
return value
Expand Down Expand Up @@ -520,13 +520,13 @@ class LineSelectionOptions(BaseModel):
@classmethod
def validate_cell_association(cls, value):
if value.association is not DataAssociationEnum.CELL:
raise ValueError("Line identifier must be associated with cells.")
raise GeoAppsError("Line identifier must be associated with cells.")
return value

@model_validator(mode="after")
def line_id_referenced(self):
if self.line_id not in self.line_object.values:
raise ValueError("Line id isn't referenced in the line object.")
raise GeoAppsError("Line id isn't referenced in the line object.")
return self


Expand Down
43 changes: 41 additions & 2 deletions tests/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,21 @@
from pathlib import Path

import numpy as np
from geoh5py import Workspace
import pytest
from geoapps_utils.utils.importing import GeoAppsError
from geoh5py.objects import Points

from simpeg_drivers.components import (
InversionMesh,
InversionModel,
InversionModelCollection,
)
from simpeg_drivers.electricals.direct_current.three_dimensions.driver import (
DC3DForwardDriver,
)
from simpeg_drivers.electricals.direct_current.three_dimensions.options import (
DC3DForwardOptions,
)
from simpeg_drivers.options import ActiveCellsOptions
from simpeg_drivers.potential_fields import MVIInversionOptions
from simpeg_drivers.potential_fields.magnetic_vector.driver import (
Expand All @@ -43,7 +50,7 @@ def get_mvi_params(tmp_path: Path) -> MVIInversionOptions:
mesh=MeshOptions(refinement=(2,)),
model=ModelOptions(anomaly=0.05),
)
with get_workspace(tmp_path / "inversion_test.ui.geoh5") as geoh5:
with get_workspace(tmp_path / f"{__name__}.ui.geoh5") as geoh5:
components = SyntheticsComponents(geoh5, options=opts)
mesh = components.model.parent
ref_inducing = mesh.add_data(
Expand Down Expand Up @@ -81,6 +88,38 @@ def get_mvi_params(tmp_path: Path) -> MVIInversionOptions:
return params


def get_dc_params(tmp_path: Path) -> MVIInversionOptions:
opts = SyntheticsComponentsOptions(
method="direct_current",
survey=SurveyOptions(n_stations=4, n_lines=2),
mesh=MeshOptions(refinement=(2,)),
model=ModelOptions(anomaly=0.05),
)
with get_workspace(tmp_path / f"{__name__}.ui.geoh5") as geoh5:
components = SyntheticsComponents(geoh5, options=opts)
mesh = components.model.parent
params = DC3DForwardOptions.build(
geoh5=geoh5,
data_object=components.survey,
tmi_channel_bool=True,
mesh=mesh,
active_cells=ActiveCellsOptions(topography_object=components.topography),
starting_model=-1e-04,
)

return params


def test_negative_reference_model(tmp_path: Path):
params = get_dc_params(tmp_path)
geoh5 = params.geoh5
with geoh5.open():
driver = DC3DForwardDriver(params)

with pytest.raises(GeoAppsError, match="must be positive when"):
_ = driver.models.starting_model


def test_zero_reference_model(tmp_path: Path):
params = get_mvi_params(tmp_path)
geoh5 = params.geoh5
Expand Down
4 changes: 2 additions & 2 deletions tests/run_tests/driver_airborne_tem_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pathlib import Path

import numpy as np
from geoh5py.groups import SimPEGGroup
from geoapps_utils.utils.importing import GeoAppsError
from geoh5py.workspace import Workspace
from pytest import raises

Expand Down Expand Up @@ -71,7 +71,7 @@ def test_bad_waveform(tmp_path: Path):

params.data_object.channels[-1] = 1000.0

with raises(ValueError, match="The latest time"):
with raises(GeoAppsError, match="The latest time"):
_ = fwr_driver.inversion_data.survey


Expand Down
2 changes: 1 addition & 1 deletion tests/run_tests/driver_joint_surveys_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_joint_surveys_conductivity_run(
method="direct-current",
survey=SurveyOptions(n_stations=4, n_lines=4, name="survey A"),
mesh=MeshOptions(refinement=(2, 2, 2), name="mesh A"),
model=ModelOptions(anomaly=0.1, name="model A"),
model=ModelOptions(anomaly=0.1, background=0.01, name="model A"),
active=SyntheticsActiveCellsOptions(name="active A"),
)

Expand Down
Loading