diff --git a/simpeg_drivers/components/meshes.py b/simpeg_drivers/components/meshes.py index 91df9b91..613292ef 100644 --- a/simpeg_drivers/components/meshes.py +++ b/simpeg_drivers/components/meshes.py @@ -18,7 +18,7 @@ from discretize import TensorMesh, TreeMesh from geoh5py import Workspace from geoh5py.groups import UIJsonGroup -from geoh5py.objects import DrapeModel, Octree +from geoh5py.objects import DrapeModel, Grid2D, Octree, Points from grid_apps.octree_creation.driver import OctreeDriver from grid_apps.octree_creation.options import OctreeOptions from grid_apps.utils import octree_2_treemesh, treemesh_2_octree @@ -113,9 +113,19 @@ def get_entity(self) -> Octree | DrapeModel: def _auto_mesh(self): """Automate meshing based on data and topography objects.""" + topography = self.params.active_cells.topography_object + if isinstance(topography, Grid2D): + with Workspace() as ws: + vertices = topography.centroids.copy() + if self.params.active_cells.topography is not None: + vertices = np.column_stack( + [vertices[:, :2], self.params.active_cells.topography.values] + ) + topography = Points.create(ws, vertices=vertices) + params = auto_mesh_parameters( - self.params.data_object, - self.params.active_cells.topography_object, + survey=self.params.data_object, + topography=topography, inversion_type=self.params.inversion_type, ) driver = OctreeDriver(params) diff --git a/tests/meshes_test.py b/tests/meshes_test.py index 502ae3fa..969e96fe 100644 --- a/tests/meshes_test.py +++ b/tests/meshes_test.py @@ -16,12 +16,13 @@ import pytest from discretize import TreeMesh from geoh5py import Workspace -from geoh5py.objects import Octree +from geoh5py.objects import Grid2D, Octree from grid_apps.utils import treemesh_2_octree from simpeg_drivers.components import InversionMesh from simpeg_drivers.options import ActiveCellsOptions from simpeg_drivers.potential_fields import MVIInversionOptions +from simpeg_drivers.potential_fields.magnetic_vector.driver import MVIInversionDriver from simpeg_drivers.utils.synthetics.driver import SyntheticsComponents from simpeg_drivers.utils.synthetics.options import ( MeshOptions, @@ -32,7 +33,7 @@ from tests.utils.targets import get_workspace -def get_mvi_params(tmp_path: Path) -> MVIInversionOptions: +def get_mvi_params(tmp_path: Path, updates=None) -> MVIInversionOptions: opts = SyntheticsComponentsOptions( method="magnetic_vector", survey=SurveyOptions(n_stations=4, n_lines=4), @@ -52,20 +53,22 @@ def get_mvi_params(tmp_path: Path) -> MVIInversionOptions: {"elevation": {"values": components.topography.vertices[:, 2]}} ) - params = MVIInversionOptions.build( - geoh5=geoh5, - data_object=components.survey, - tmi_channel=tmi_channel, - tmi_uncertainty=0.01, - active_cells=ActiveCellsOptions( - topography_object=components.topography, topography=elevation - ), - inducing_field_strength=50000.0, - inducing_field_inclination=60.0, - inducing_field_declination=30.0, - mesh=mesh, - starting_model=components.model, - ) + kwargs = { + "geoh5": geoh5, + "data_object": components.survey, + "tmi_channel": tmi_channel, + "tmi_uncertainty": 0.01, + "topography_object": components.topography, + "topography": elevation, + "inducing_field_strength": 50000.0, + "inducing_field_inclination": 60.0, + "inducing_field_declination": 30.0, + "mesh": mesh, + "starting_model": components.model, + } + if updates is not None: + kwargs.update(updates) + params = MVIInversionOptions.build(**kwargs) return params @@ -214,3 +217,27 @@ def test_raise_on_rotated_negative_cell_size(tmp_path): msg = "Cannot convert negative cell sizes for rotated mesh." with pytest.raises(ValueError, match=msg): InversionMesh.ensure_cell_convention(mesh) + + +def test_handle_grid2d(tmp_path): + with Workspace(tmp_path / "test.geoh5") as ws: + topo = Grid2D.create( + ws, + name="topography", + u_cell_size=25.0, + v_cell_size=25.0, + u_count=16, + v_count=16, + origin=(-200.0, -200.0, 0.0), + ) + elev = topo.add_data( + { + "elevation": { + "values": np.zeros(len(topo.centroids)), + } + } + ) + + updates = {"mesh": None, "topography_object": topo, "topography": elev} + params = get_mvi_params(tmp_path, updates=updates) + MVIInversionDriver(params) # Doesn't crash