diff --git a/simpeg_drivers/options.py b/simpeg_drivers/options.py index 93554e9e..6c56f093 100644 --- a/simpeg_drivers/options.py +++ b/simpeg_drivers/options.py @@ -17,7 +17,6 @@ import numpy as np from geoapps_utils.base import Options -from geoapps_utils.utils.importing import GeoAppsError from geoh5py.data import ( BooleanData, DataAssociationEnum, @@ -92,6 +91,16 @@ def at_least_one(cls, data): raise ValueError("Must provide either topography or active model.") return data + @model_validator(mode="before") + @classmethod + def topo_grid_must_have_elevation_channel(cls, data): + if isinstance(data.get("topography_object", None), Grid2D): + if data.get("topography", None) is None: + raise ValueError( + "Grid2D topography must be accompanied by a valid elevation channel." + ) + return data + @model_serializer(mode="wrap") def serialize_model(self, handler) -> dict[str, Any]: result = handler(self) diff --git a/tests/validations_test.py b/tests/validations_test.py index c01beeae..a97ba3ef 100644 --- a/tests/validations_test.py +++ b/tests/validations_test.py @@ -11,6 +11,7 @@ import pytest from geoapps_utils.utils.importing import GeoAppsError from geoh5py import Workspace +from geoh5py.objects import Grid2D from simpeg_drivers.options import CoreOptions @@ -23,3 +24,25 @@ def test_topo_or_active_validation(tmp_path): } with pytest.raises(GeoAppsError, match="active_cells: Value error, Must"): CoreOptions.build(data) + + +def test_topo_grid_missing_elevation(tmp_path): + with Workspace(tmp_path / "test.geoh5") as workspace: + grid = Grid2D.create( + workspace, + name="grid", + u_cell_size=10, + v_cell_size=10, + u_count=10, + v_count=10, + origin=[0, 0, 0], + ) + + data = { + "geoh5": workspace, + "inversion_type": "mvi", + "topography_object": grid, + "topography": None, + } + with pytest.raises(GeoAppsError, match="active_cells: Value error, Grid2D"): + CoreOptions.build(data)