diff --git a/simpeg_drivers-assets/uijson/direct_current_2d_inversion.ui.json b/simpeg_drivers-assets/uijson/direct_current_2d_inversion.ui.json index 6a1f5e6b..2d6357f4 100644 --- a/simpeg_drivers-assets/uijson/direct_current_2d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/direct_current_2d_inversion.ui.json @@ -277,6 +277,18 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "visible": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/direct_current_3d_inversion.ui.json b/simpeg_drivers-assets/uijson/direct_current_3d_inversion.ui.json index 6498ea8f..563028b3 100644 --- a/simpeg_drivers-assets/uijson/direct_current_3d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/direct_current_3d_inversion.ui.json @@ -232,6 +232,17 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/direct_current_batch2d_inversion.ui.json b/simpeg_drivers-assets/uijson/direct_current_batch2d_inversion.ui.json index a0d7a9f2..6337065a 100644 --- a/simpeg_drivers-assets/uijson/direct_current_batch2d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/direct_current_batch2d_inversion.ui.json @@ -259,6 +259,18 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "visible": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/fem_inversion.ui.json b/simpeg_drivers-assets/uijson/fem_inversion.ui.json index f85c9627..fc3c00cc 100644 --- a/simpeg_drivers-assets/uijson/fem_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/fem_inversion.ui.json @@ -268,6 +268,17 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/gravity_inversion.ui.json b/simpeg_drivers-assets/uijson/gravity_inversion.ui.json index 9f61329f..389ba540 100644 --- a/simpeg_drivers-assets/uijson/gravity_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/gravity_inversion.ui.json @@ -501,6 +501,17 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/induced_polarization_2d_inversion.ui.json b/simpeg_drivers-assets/uijson/induced_polarization_2d_inversion.ui.json index f7cdec99..940c72e0 100644 --- a/simpeg_drivers-assets/uijson/induced_polarization_2d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/induced_polarization_2d_inversion.ui.json @@ -286,6 +286,18 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "visible": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/induced_polarization_3d_inversion.ui.json b/simpeg_drivers-assets/uijson/induced_polarization_3d_inversion.ui.json index d32f6e30..829f197e 100644 --- a/simpeg_drivers-assets/uijson/induced_polarization_3d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/induced_polarization_3d_inversion.ui.json @@ -248,6 +248,17 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/induced_polarization_batch2d_inversion.ui.json b/simpeg_drivers-assets/uijson/induced_polarization_batch2d_inversion.ui.json index 5b74c35c..bec2dade 100644 --- a/simpeg_drivers-assets/uijson/induced_polarization_batch2d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/induced_polarization_batch2d_inversion.ui.json @@ -270,6 +270,18 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "visible": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/joint_surveys_inversion.ui.json b/simpeg_drivers-assets/uijson/joint_surveys_inversion.ui.json index 5b8b832c..3cf0c145 100644 --- a/simpeg_drivers-assets/uijson/joint_surveys_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/joint_surveys_inversion.ui.json @@ -237,6 +237,17 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/magnetic_scalar_inversion.ui.json b/simpeg_drivers-assets/uijson/magnetic_scalar_inversion.ui.json index 3e2630f6..20defc67 100644 --- a/simpeg_drivers-assets/uijson/magnetic_scalar_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/magnetic_scalar_inversion.ui.json @@ -533,6 +533,17 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/magnetic_vector_inversion.ui.json b/simpeg_drivers-assets/uijson/magnetic_vector_inversion.ui.json index 05c44729..d08a75f0 100644 --- a/simpeg_drivers-assets/uijson/magnetic_vector_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/magnetic_vector_inversion.ui.json @@ -597,6 +597,17 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/magnetotellurics_inversion.ui.json b/simpeg_drivers-assets/uijson/magnetotellurics_inversion.ui.json index 220c8152..0f65c8b4 100644 --- a/simpeg_drivers-assets/uijson/magnetotellurics_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/magnetotellurics_inversion.ui.json @@ -453,6 +453,17 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/tdem_inversion.ui.json b/simpeg_drivers-assets/uijson/tdem_inversion.ui.json index 67a75b24..b6e49561 100644 --- a/simpeg_drivers-assets/uijson/tdem_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/tdem_inversion.ui.json @@ -305,6 +305,15 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers-assets/uijson/tipper_inversion.ui.json b/simpeg_drivers-assets/uijson/tipper_inversion.ui.json index 335f7528..50d78389 100644 --- a/simpeg_drivers-assets/uijson/tipper_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/tipper_inversion.ui.json @@ -333,6 +333,17 @@ "property": "", "enabled": true }, + "gradient_rotation": { + "group": "Regularization", + "association": "Cell", + "dataType": "Float", + "dataGroupType": "Dip direction & dip", + "label": "Gradient rotation", + "optional": true, + "enabled": false, + "parent": "mesh", + "value": "" + }, "s_norm": { "association": "Cell", "dataType": "Float", diff --git a/simpeg_drivers/components/models.py b/simpeg_drivers/components/models.py index 1e393ecf..32c46188 100644 --- a/simpeg_drivers/components/models.py +++ b/simpeg_drivers/components/models.py @@ -29,6 +29,25 @@ from simpeg_drivers.driver import InversionDriver +MODEL_TYPES = [ + "starting", + "reference", + "lower_bound", + "upper_bound", + "conductivity", + "alpha_s", + "length_scale_x", + "length_scale_y", + "length_scale_z", + "gradient_dip", + "gradient_direction", + "s_norm", + "x_norm", + "y_norm", + "z_norm", +] + + class InversionModelCollection: """ Collection of inversion models. @@ -39,21 +58,7 @@ class InversionModelCollection: """ - model_types = [ - "starting", - "reference", - "lower_bound", - "upper_bound", - "conductivity", - "alpha_s", - "length_scale_x", - "length_scale_y", - "length_scale_z", - "s_norm", - "x_norm", - "y_norm", - "z_norm", - ] + model_types = MODEL_TYPES def __init__(self, driver: InversionDriver): """ @@ -62,25 +67,34 @@ def __init__(self, driver: InversionDriver): self._active_cells: np.ndarray | None = None self._driver = driver self.is_sigma = self.driver.params.physical_property == "conductivity" - self.is_vector = ( + is_vector = ( True if self.driver.params.inversion_type == "magnetic vector" else False ) - self.n_blocks = ( - 3 if self.driver.params.inversion_type == "magnetic vector" else 1 + self._starting = InversionModel(driver, "starting", is_vector=is_vector) + self._reference = InversionModel(driver, "reference", is_vector=is_vector) + self._lower_bound = InversionModel(driver, "lower_bound", is_vector=is_vector) + self._upper_bound = InversionModel(driver, "upper_bound", is_vector=is_vector) + self._conductivity = InversionModel(driver, "conductivity", is_vector=is_vector) + self._alpha_s = InversionModel(driver, "alpha_s", is_vector=is_vector) + self._length_scale_x = InversionModel( + driver, "length_scale_x", is_vector=is_vector + ) + self._length_scale_y = InversionModel( + driver, "length_scale_y", is_vector=is_vector + ) + self._length_scale_z = InversionModel( + driver, "length_scale_z", is_vector=is_vector + ) + self._gradient_dip = InversionModel( + driver, "gradient_dip", trim_active_cells=False ) - self._starting = InversionModel(driver, "starting") - self._reference = InversionModel(driver, "reference") - self._lower_bound = InversionModel(driver, "lower_bound") - self._upper_bound = InversionModel(driver, "upper_bound") - self._conductivity = InversionModel(driver, "conductivity") - self._alpha_s = InversionModel(driver, "alpha_s") - self._length_scale_x = InversionModel(driver, "length_scale_x") - self._length_scale_y = InversionModel(driver, "length_scale_y") - self._length_scale_z = InversionModel(driver, "length_scale_z") - self._s_norm = InversionModel(driver, "s_norm") - self._x_norm = InversionModel(driver, "x_norm") - self._y_norm = InversionModel(driver, "y_norm") - self._z_norm = InversionModel(driver, "z_norm") + self._gradient_direction = InversionModel( + driver, "gradient_direction", trim_active_cells=False + ) + self._s_norm = InversionModel(driver, "s_norm", is_vector=is_vector) + self._x_norm = InversionModel(driver, "x_norm", is_vector=is_vector) + self._y_norm = InversionModel(driver, "y_norm", is_vector=is_vector) + self._z_norm = InversionModel(driver, "z_norm", is_vector=is_vector) @property def n_active(self) -> int: @@ -256,6 +270,20 @@ def length_scale_z(self) -> np.ndarray | None: return self._length_scale_z.model.copy() + @property + def gradient_dip(self) -> np.ndarray | None: + if self._gradient_dip.model is None: + return None + + return self._gradient_dip.model.copy() + + @property + def gradient_direction(self) -> np.ndarray | None: + if self._gradient_direction.model is None: + return None + + return self._gradient_direction.model.copy() + @property def s_norm(self) -> np.ndarray | None: if self._s_norm.model is None: @@ -291,7 +319,7 @@ def z_norm(self) -> np.ndarray | None: def _model_method_wrapper(self, method, name=None, **kwargs): """wraps individual model's specific method and applies in loop over model types.""" returned_items = {} - for mtype in self.model_types: + for mtype in MODEL_TYPES: model = getattr(self, f"_{mtype}") if model.model is not None: f = getattr(model, method) @@ -335,41 +363,24 @@ class InversionModel: remove_air: Use active cells vector to remove air cells from model. """ - model_types = [ - "starting", - "reference", - "lower_bound", - "upper_bound", - "conductivity", - "alpha_s", - "length_scale_x", - "length_scale_y", - "length_scale_z", - "s_norm", - "x_norm", - "y_norm", - "z_norm", - ] - def __init__( self, driver: InversionDriver, model_type: str, + is_vector: bool = False, + trim_active_cells: bool = True, ): """ :param driver: InversionDriver object. - :param model_type: Type of inversion model, can be any of "starting", "reference", - "lower_bound", "upper_bound". + :param model_type: Type of inversion model, can be any of MODEL_TYPES. + :param is_vector: If True, model is a vector. + :param trim_active_cells: If True, remove air cells from model. """ self.driver = driver self.model_type = model_type self.model: np.ndarray | None = None - self.is_vector = ( - True if self.driver.params.inversion_type == "magnetic vector" else False - ) - self.n_blocks = ( - 3 if self.driver.params.inversion_type == "magnetic vector" else 1 - ) + self.is_vector = is_vector + self.trim_active_cells = trim_active_cells self._initialize() def _initialize(self): @@ -418,7 +429,7 @@ def _initialize(self): and self.is_vector and model.shape[0] == self.driver.inversion_mesh.n_cells ): - model = np.tile(model, self.n_blocks) + model = np.tile(model, 3 if self.is_vector else 1) if model is not None: self.model = mkvc(model) @@ -427,8 +438,8 @@ def _initialize(self): def remove_air(self, active_cells): """Use active cells vector to remove air cells from model""" - if self.model is not None: - self.model = self.model[np.tile(active_cells, self.n_blocks)] + if self.model is not None and self.trim_active_cells: + self.model = self.model[np.tile(active_cells, 3 if self.is_vector else 1)] def permute_2_octree(self) -> np.ndarray | None: """ @@ -559,7 +570,7 @@ def model_type(self): @model_type.setter def model_type(self, v): - if v not in self.model_types: - msg = f"Invalid model_type: {v}. Must be one of {(*self.model_types,)}." + if v not in MODEL_TYPES: + msg = f"Invalid model_type: {v}. Must be one of {(*MODEL_TYPES,)}." raise ValueError(msg) self._model_type = v diff --git a/simpeg_drivers/driver.py b/simpeg_drivers/driver.py index 4fe17c2e..1b1cb2a9 100644 --- a/simpeg_drivers/driver.py +++ b/simpeg_drivers/driver.py @@ -14,7 +14,7 @@ from __future__ import annotations import multiprocessing - +from copy import deepcopy import sys from datetime import datetime, timedelta import logging @@ -42,7 +42,13 @@ objective_function, optimization, ) -from simpeg.regularization import BaseRegularization, Sparse + +from simpeg.regularization import ( + BaseRegularization, + RegularizationMesh, + Sparse, + SparseSmoothness, +) from simpeg_drivers import DRIVER_MAP from simpeg_drivers.components import ( @@ -59,6 +65,7 @@ ) from simpeg_drivers.joint.params import BaseJointOptions from simpeg_drivers.utils.utils import tile_locations +from simpeg_drivers.utils.regularization import cell_neighbors, set_rotated_operators mlogger = logging.getLogger("distributed") mlogger.setLevel(logging.WARNING) @@ -440,49 +447,108 @@ def get_regularization(self): return BaseRegularization(mesh=self.inversion_mesh.mesh) reg_funcs = [] + is_rotated = self.params.gradient_rotation is not None + neighbors = None + backward_mesh = None + forward_mesh = None for mapping in self.mapping: - reg = Sparse( - self.inversion_mesh.mesh, - active_cells=self.models.active_cells, + reg_func = Sparse( + forward_mesh or self.inversion_mesh.mesh, + active_cells=self.models.active_cells if forward_mesh is None else None, mapping=mapping, reference_model=self.models.reference, ) + if is_rotated and neighbors is None: + backward_mesh = RegularizationMesh( + self.inversion_mesh.mesh, active_cells=self.models.active_cells + ) + neighbors = cell_neighbors(reg_func.regularization_mesh.mesh) + # Adjustment for 2D versus 3D problems - # TODO check this part - is_2d_reg = ( - "2d" in self.params.inversion_type or "1d" in self.params.inversion_type + components = ( + "sxz" + if ( + "2d" in self.params.inversion_type + or "1d" in self.params.inversion_type + ) + else "sxyz" ) - comps = "sxz" if is_2d_reg else "sxyz" - avg_comps = "sxy" if is_2d_reg else "sxyz" - weights = ["alpha_s"] + [f"length_scale_{k}" for k in comps[1:]] - for comp, avg_comp, objfct, weight in zip( - comps, avg_comps, reg.objfcts, weights + weight_names = ["alpha_s"] + [f"length_scale_{k}" for k in components[1:]] + functions = [] + for comp, weight_name, fun in zip( + components, weight_names, reg_func.objfcts ): - if getattr(self.models, weight) is None: - setattr(reg, weight, 0.0) + if getattr(self.models, weight_name) is None: + setattr(reg_func, weight_name, 0.0) + functions.append(fun) continue - weight = mapping * getattr(self.models, weight) + weight = mapping * getattr(self.models, weight_name) norm = mapping * getattr(self.models, f"{comp}_norm") - if comp in "xyz": - weight = ( - getattr(reg.regularization_mesh, f"aveCC2F{avg_comp}") * weight - ) - norm = getattr(reg.regularization_mesh, f"aveCC2F{avg_comp}") * norm - objfct.set_weights(**{comp: weight}) - objfct.norm = norm + if not isinstance(fun, SparseSmoothness): + fun.set_weights(**{comp: weight}) + fun.norm = norm + functions.append(fun) + continue - if getattr(self.params, "gradient_type") is not None: + if is_rotated: + if forward_mesh is None: + fun = set_rotated_operators( + fun, + neighbors, + comp, + self.models.gradient_dip, + self.models.gradient_direction, + ) + + average_op = getattr( + reg_func.regularization_mesh, + f"aveCC2F{fun.orientation}", + ) + fun.set_weights(**{comp: average_op @ weight}) + fun.norm = average_op @ norm + functions.append(fun) + + if is_rotated: + fun.gradient_type = "components" + backward_fun = deepcopy(fun) + setattr(backward_fun, "_regularization_mesh", backward_mesh) + + # Only do it once for MVI + if not forward_mesh: + backward_fun = set_rotated_operators( + backward_fun, + neighbors, + comp, + self.models.gradient_dip, + self.models.gradient_direction, + forward=False, + ) + average_op = getattr( + backward_fun.regularization_mesh, + f"aveCC2F{fun.orientation}", + ) + backward_fun.set_weights(**{comp: average_op @ weight}) + backward_fun.norm = average_op @ norm + functions.append(backward_fun) + + # Will avoid recomputing operators if the regularization mesh is the same + forward_mesh = reg_func.regularization_mesh + reg_func.objfcts = functions + reg_func.norms = [fun.norm for fun in functions] + reg_funcs.append(reg_func) + + # TODO - To be deprcated on GEOPY-2109 + if getattr(self.params, "gradient_type") is not None: + for reg in reg_funcs: setattr( reg, "gradient_type", getattr(self.params, "gradient_type"), ) - reg_funcs.append(reg) - return objective_function.ComboObjectiveFunction(objfcts=reg_funcs) def get_tiles(self): diff --git a/simpeg_drivers/joint/params.py b/simpeg_drivers/joint/params.py index b93f6953..139e304f 100644 --- a/simpeg_drivers/joint/params.py +++ b/simpeg_drivers/joint/params.py @@ -15,7 +15,7 @@ from geoapps_utils.driver.data import BaseData from geoh5py.data import FloatData -from geoh5py.groups import SimPEGGroup, UIJsonGroup +from geoh5py.groups import PropertyGroup, SimPEGGroup, UIJsonGroup from geoh5py.objects import DrapeModel, Octree from geoh5py.shared.utils import fetch_active_workspace from pydantic import ConfigDict, field_validator, model_validator @@ -69,6 +69,7 @@ class BaseJointOptions(BaseData): length_scale_x: float | FloatData = 1.0 length_scale_y: float | FloatData | None = 1.0 length_scale_z: float | FloatData = 1.0 + gradient_rotation: PropertyGroup | None = None s_norm: float | FloatData | None = 0.0 x_norm: float | FloatData = 2.0 y_norm: float | FloatData | None = 2.0 diff --git a/simpeg_drivers/params.py b/simpeg_drivers/params.py index 2abdc50b..db237ef9 100644 --- a/simpeg_drivers/params.py +++ b/simpeg_drivers/params.py @@ -274,6 +274,7 @@ class BaseInversionOptions(CoreOptions): :param length_scale_x: Length scale x. :param length_scale_y: Length scale y. :param length_scale_z: Length scale z. + :param gradient_rotation: Property group for gradient rotation angles. :param s_norm: S norm. :param x_norm: X norm. @@ -336,6 +337,7 @@ class BaseInversionOptions(CoreOptions): length_scale_x: float | FloatData = 1.0 length_scale_y: float | FloatData | None = 1.0 length_scale_z: float | FloatData = 1.0 + gradient_rotation: PropertyGroup | None = None s_norm: float | FloatData | None = 0.0 x_norm: float | FloatData = 2.0 @@ -368,6 +370,28 @@ class BaseInversionOptions(CoreOptions): percentile: float = 95.0 epsilon_cooling_factor: float = 1.2 + @property + def gradient_dip(self) -> np.ndarray | None: + """Gradient dip angle in clockwise radians from horizontal.""" + if self.gradient_rotation is not None: + dip_uid = self.gradient_rotation.properties[1] + dips = self.geoh5.get_entity(dip_uid)[0].values + return np.deg2rad(dips) + return None + + @property + def gradient_direction(self) -> np.ndarray | None: + """Gradient direction angle in clockwise radians from north""" + if self.gradient_rotation is not None: + from geoh5py.groups.property_group_type import GroupTypeEnum + + direction_uid = self.gradient_rotation.properties[0] + directions = self.geoh5.get_entity(direction_uid)[0].values + if self.gradient_rotation.property_group_type == GroupTypeEnum.STRIKEDIP: + directions += 90.0 + return np.deg2rad(directions) + return None + class EMDataMixin: """ diff --git a/simpeg_drivers/potential_fields/gravity/uijson.py b/simpeg_drivers/potential_fields/gravity/uijson.py index 9957f419..0d0196b9 100644 --- a/simpeg_drivers/potential_fields/gravity/uijson.py +++ b/simpeg_drivers/potential_fields/gravity/uijson.py @@ -108,6 +108,7 @@ class GravityInversionUIJson(SimPEGDriversUIJson): length_scale_x: DataForm length_scale_y: DataForm length_scale_z: DataForm + gradient_rotation: DataForm s_norm: DataForm x_norm: DataForm y_norm: DataForm diff --git a/simpeg_drivers/utils/regularization.py b/simpeg_drivers/utils/regularization.py new file mode 100644 index 00000000..943760cb --- /dev/null +++ b/simpeg_drivers/utils/regularization.py @@ -0,0 +1,434 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2025 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' + +import numpy as np +import scipy.sparse as ssp +from discretize import TreeMesh +from simpeg.regularization import SparseSmoothness +from simpeg.utils import mkvc, sdiag + + +def cell_neighbors_along_axis(mesh: TreeMesh, axis: str) -> np.ndarray: + """ + Get adjacent cells along provided axis for all cells in the mesh + + :param mesh: Input TreeMesh. + :param axis: Cartesian axis along which to find neighbors. Must be + 'x', 'y', or 'z'. + """ + + if axis not in "xyz": + raise ValueError("Argument 'axis' must be one of 'x', 'y', or 'z'.") + + if isinstance(mesh, TreeMesh): + stencil = getattr(mesh, f"stencil_cell_gradient_{axis}") + else: + stencil = getattr(mesh, f"cell_gradient_{axis}") + + ith_neighbor, jth_neighbor, _ = ssp.find(stencil) + n_stencils = int(ith_neighbor.shape[0] / 2) + stencil_indices = jth_neighbor[np.argsort(ith_neighbor)].reshape((n_stencils, 2)) + + return np.sort(stencil_indices, axis=1) + + +def collect_all_neighbors( + neighbors: list[np.ndarray], + neighbors_backwards: list[np.ndarray], + adjacent: list[np.ndarray], + adjacent_backwards: list[np.ndarray], +) -> np.ndarray: + """ + Collect all neighbors for cells in the mesh. + + :param neighbors: Direct neighbors in each principle axes. + :param neighbors_backwards: Direct neighbors in reverse order. + :param adjacent: Adjacent neighbors (corners). + :param adjacent_backwards: Adjacent neighbors in reverse order. + """ + all_neighbors = [] # Store + + all_neighbors += [neighbors[0]] + all_neighbors += [neighbors[1]] + + all_neighbors += [np.c_[neighbors[0][:, 0], adjacent[0][neighbors[0][:, 1]]]] + all_neighbors += [np.c_[neighbors[0][:, 1], adjacent[0][neighbors[0][:, 0]]]] + + all_neighbors += [np.c_[adjacent[1][neighbors[1][:, 0]], neighbors[1][:, 1]]] + all_neighbors += [np.c_[adjacent[1][neighbors[1][:, 1]], neighbors[1][:, 0]]] + + # Repeat backward for Treemesh + all_neighbors += [neighbors_backwards[0]] + all_neighbors += [neighbors_backwards[1]] + + all_neighbors += [ + np.c_[ + neighbors_backwards[0][:, 0], + adjacent_backwards[0][neighbors_backwards[0][:, 1]], + ] + ] + all_neighbors += [ + np.c_[ + neighbors_backwards[0][:, 1], + adjacent_backwards[0][neighbors_backwards[0][:, 0]], + ] + ] + + # Stack all and keep only unique pairs + all_neighbors = np.vstack(all_neighbors) + all_neighbors = np.unique(all_neighbors, axis=0) + + # Remove all the -1 for TreeMesh + all_neighbors = all_neighbors[ + (all_neighbors[:, 0] != -1) & (all_neighbors[:, 1] != -1), : + ] + + # Use all the neighbours on the xy plane to find neighbours in z + if len(neighbors) == 3: + all_neighbors_z = [] + + all_neighbors_z += [neighbors[2]] + all_neighbors_z += [neighbors_backwards[2]] + + all_neighbors_z += [ + np.c_[all_neighbors[:, 0], adjacent[2][all_neighbors[:, 1]]] + ] + all_neighbors_z += [ + np.c_[all_neighbors[:, 1], adjacent[2][all_neighbors[:, 0]]] + ] + + all_neighbors_z += [ + np.c_[all_neighbors[:, 0], adjacent_backwards[2][all_neighbors[:, 1]]] + ] + all_neighbors_z += [ + np.c_[all_neighbors[:, 1], adjacent_backwards[2][all_neighbors[:, 0]]] + ] + + # Stack all and keep only unique pairs + all_neighbors = np.vstack([all_neighbors, np.vstack(all_neighbors_z)]) + all_neighbors = np.unique(all_neighbors, axis=0) + + # Remove all the -1 for TreeMesh + all_neighbors = all_neighbors[ + (all_neighbors[:, 0] != -1) & (all_neighbors[:, 1] != -1), : + ] + + return all_neighbors + + +def cell_adjacent(neighbors: list[np.ndarray]) -> list[np.ndarray]: + """Find all adjacent cells (corners) from cell neighbor array.""" + + dim = len(neighbors) + max_index = np.max(np.vstack(neighbors)) + corners = -1 * np.ones((dim, max_index + 1), dtype="int") + + corners[0, neighbors[1][:, 0]] = neighbors[1][:, 1] + corners[1, neighbors[0][:, 1]] = neighbors[0][:, 0] + if dim == 3: + corners[2, neighbors[2][:, 0]] = neighbors[2][:, 1] + + return [np.array(k) for k in corners.tolist()] + + +def cell_neighbors(mesh: TreeMesh) -> np.ndarray: + """Find all cell neighbors in a TreeMesh.""" + + neighbors = [] + neighbors.append(cell_neighbors_along_axis(mesh, "x")) + neighbors.append(cell_neighbors_along_axis(mesh, "y")) + if mesh.dim == 3: + neighbors.append(cell_neighbors_along_axis(mesh, "z")) + + neighbors_backwards = [np.fliplr(k) for k in neighbors] + corners = cell_adjacent(neighbors) + corners_backwards = cell_adjacent(neighbors_backwards) + + return collect_all_neighbors( + neighbors, neighbors_backwards, corners, corners_backwards + ) + + +def rotate_xz_2d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix: + """ + Create a 2d ellipsoidal rotation matrix for the xz plane. + + :param mesh: TreeMesh used to adjust angle of rotation to + compensate for cell aspect ratio. + :param phi: Angle in radians for clockwise rotation about the + y-axis (xz plane). + """ + + if mesh.dim != 2: + raise ValueError("Must pass a 2 dimensional mesh.") + + n_cells = len(phi) + hx = mesh.h_gridded[:, 0] + hz = mesh.h_gridded[:, 1] + phi = -np.arctan2((np.sin(phi) / hz), (np.cos(phi) / hx)) + + rza = mkvc(np.c_[np.cos(phi), np.cos(phi)].T) + rzb = mkvc(np.c_[np.sin(phi), np.zeros(n_cells)].T) + rzc = mkvc(np.c_[-np.sin(phi), np.zeros(n_cells)].T) + Ry = ssp.diags([rzb[:-1], rza, rzc[:-1]], [-1, 0, 1]) + + return Ry + + +def rotate_yz_3d(mesh: TreeMesh, theta: np.ndarray) -> ssp.csr_matrix: + """ + Create a 3D ellipsoidal rotation matrix for the yz plane. + + :param mesh: TreeMesh used to adjust angle of rotation to + compensate for cell aspect ratio. + :param theta: Angle in radians for clockwise rotation about the + x-axis (yz plane). + """ + + n_cells = len(theta) + hy = mesh.h_gridded[:, 1] + hz = mesh.h_gridded[:, 2] + theta = -np.arctan2((np.sin(theta) / hz), (np.cos(theta) / hy)) + + rxa = mkvc(np.c_[np.ones(n_cells), np.cos(theta), np.cos(theta)].T) + rxb = mkvc(np.c_[np.zeros(n_cells), np.sin(theta), np.zeros(n_cells)].T) + rxc = mkvc(np.c_[np.zeros(n_cells), -np.sin(theta), np.zeros(n_cells)].T) + Rx = ssp.diags([rxb[:-1], rxa, rxc[:-1]], [-1, 0, 1]) + + return Rx + + +def rotate_xy_3d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix: + """ + Create a 3D ellipsoidal rotation matrix for the xy plane. + + :param mesh: TreeMesh used to adjust angle of rotation to + compensate for cell aspect ratio. + :param phi: Angle in radians for clockwise rotation about the + z-axis (xy plane). + """ + n_cells = len(phi) + hx = mesh.h_gridded[:, 0] + hy = mesh.h_gridded[:, 1] + phi = -np.arctan2((np.sin(phi) / hy), (np.cos(phi) / hx)) + + rza = mkvc(np.c_[np.cos(phi), np.cos(phi), np.ones(n_cells)].T) + rzb = mkvc(np.c_[np.sin(phi), np.zeros(n_cells), np.zeros(n_cells)].T) + rzc = mkvc(np.c_[-np.sin(phi), np.zeros(n_cells), np.zeros(n_cells)].T) + Rz = ssp.diags([rzb[:-1], rza, rzc[:-1]], [-1, 0, 1]) + + return Rz + + +def get_cell_normals(n_cells: int, axis: str, outward: bool) -> np.ndarray: + """ + Returns cell normals for given axis and all cells. + + :param n_cells: Number of cells in the mesh. + :param axis: Cartesian axis (one of 'x', 'y', or 'z' + :param outward: Direction of the normal. True for outward facing, + False for inward facing normals. + """ + + ind = 1 if outward else -1 + + if axis == "x": + normals = np.kron(np.ones(n_cells), np.c_[ind, 0, 0]) + elif axis == "y": + normals = np.kron(np.ones(n_cells), np.c_[0, ind, 0]) + elif axis == "z": + normals = np.kron(np.ones(n_cells), np.c_[0, 0, ind]) + else: + raise ValueError("Axis must be one of 'x', 'y', or 'z'.") + + return normals + + +def get_cell_corners( + mesh: TreeMesh, + neighbors: np.ndarray, + normals: np.ndarray, +) -> list[np.ndarray]: + """ + Return the bottom southwest and top northeast nodes of all cells. + + :param mesh: Input TreeMesh. + :param neighbors: Cell neighbors array. + :param normals: Cell normals array. + """ + + bottom_southwest = ( + mesh.gridCC[neighbors[:, 0], :] + - mesh.h_gridded[neighbors[:, 0], :] / 2 + + normals[neighbors[:, 0], :] * mesh.h_gridded[neighbors[:, 0], :] + ) + top_northeast = ( + mesh.gridCC[neighbors[:, 0], :] + + mesh.h_gridded[neighbors[:, 0], :] / 2 + + normals[neighbors[:, 0], :] * mesh.h_gridded[neighbors[:, 0], :] + ) + + return [bottom_southwest, top_northeast] + + +def get_neighbor_corners(mesh: TreeMesh, neighbors: np.ndarray): + """ + Return the bottom southwest and top northeast corners. + + :param mesh: Input TreeMesh. + :param neighbors: Cell neighbors array. + """ + + bottom_southwest = ( + mesh.gridCC[neighbors[:, 1], :] - mesh.h_gridded[neighbors[:, 1], :] / 2 + ) + top_northeast = ( + mesh.gridCC[neighbors[:, 1], :] + mesh.h_gridded[neighbors[:, 1], :] / 2 + ) + + corners = [bottom_southwest, top_northeast] + + return corners + + +def partial_volumes( + mesh: TreeMesh, neighbors: np.ndarray, normals: np.ndarray +) -> np.ndarray: + """ + Compute partial volumes created by intersecting rotated and unrotated cells. + + :param mesh: Input TreeMesh. + :param neighbors: Cell neighbors array. + :param normals: Cell normals array. + """ + cell_corners = get_cell_corners(mesh, neighbors, normals) + neighbor_corners = get_neighbor_corners(mesh, neighbors) + + volumes = np.ones(neighbors.shape[0]) + for i in range(mesh.dim): + volumes *= np.max( + [ + np.min([neighbor_corners[1][:, i], cell_corners[1][:, i]], axis=0) + - np.max([neighbor_corners[0][:, i], cell_corners[0][:, i]], axis=0), + np.zeros(neighbors.shape[0]), + ], + axis=0, + ) + + # Remove all rows of zero + ind = (volumes > 0) * (neighbors[:, 0] != neighbors[:, 1]) + neighbors = neighbors[ind, :] + volumes = volumes[ind] + + return volumes, neighbors + + +def gradient_operator( + neighbors: np.ndarray, volumes: np.ndarray, n_cells: int +) -> ssp.csr_matrix: + """ + Assemble the sparse gradient operator. + + :param neighbors: Cell neighbor array. + :param volumes: Partial volume array. + :param n_cells: Number of cells in mesh. + """ + Grad = ssp.csr_matrix( + (volumes, (neighbors[:, 0], neighbors[:, 1])), shape=(n_cells, n_cells) + ) + + # Normalize rows + Vol = mkvc(Grad.sum(axis=1)) + Vol[Vol > 0] = 1.0 / Vol[Vol > 0] + Grad = -sdiag(Vol) * Grad + + diag = np.ones(n_cells) + diag[Vol == 0] = 0 + Grad = sdiag(diag) + Grad + + return Grad + + +def rotated_gradient( + mesh: TreeMesh, + neighbors: np.ndarray, + axis: str, + dip: np.ndarray, + direction: np.ndarray, + forward: bool = True, +) -> ssp.csr_matrix: + """ + Calculated rotated gradient operator using partial volumes. + + :param mesh: Input TreeMesh. + :param neighbors: Cell neighbors array. + :param axis: Regularization axis. + :param dip: Angle in radians for rotation from the horizon. + :param direction: Angle in radians for rotation about the z-axis. + :param forward: Whether to use forward or backward difference for + derivative approximations. + """ + + n_cells = mesh.n_cells + if any(len(k) != n_cells for k in [dip, direction]): + raise ValueError( + "Input angle arrays are not the same size as the number of " + "cells in the mesh." + ) + + Rx = rotate_yz_3d(mesh, dip) + Rz = rotate_xy_3d(mesh, direction) + normals = get_cell_normals(n_cells, axis, forward) + rotated_normals = (Rz * (Rx * normals.T)).reshape(n_cells, mesh.dim) + volumes, neighbors = partial_volumes(mesh, neighbors, rotated_normals) + + unit_grad = gradient_operator(neighbors, volumes, n_cells) + return sdiag(1 / mesh.h_gridded[:, "xyz".find(axis)]) @ unit_grad + + +def set_rotated_operators( + function: SparseSmoothness, + neighbors: np.ndarray, + axis: str, + dip: np.ndarray, + direction: np.ndarray, + forward: bool = True, +) -> SparseSmoothness: + """ + Calculated rotated gradient operator using partial volumes. + + :param function: Smoothness regularization to change operator for. + :param neighbors: Cell neighbors array. + :param axis: Regularization axis. + :param dip: Angle in radians for rotation from the horizon. + :param direction: Angle in radians for rotation about the z-axis. + :param forward: Whether to use forward or backward difference for + derivative approximations. + """ + grad_op = rotated_gradient( + function.regularization_mesh.mesh, neighbors, axis, dip, direction, forward + ) + grad_op_active = function.regularization_mesh.Pac.T @ ( + grad_op @ function.regularization_mesh.Pac + ) + active_faces = grad_op_active.max(axis=1).toarray().ravel() > 0 + + setattr( + function.regularization_mesh, + f"_cell_gradient_{function.orientation}", + grad_op_active[active_faces, :], + ) + setattr( + function.regularization_mesh, + f"_aveCC2F{function.orientation}", + sdiag(np.ones(function.regularization_mesh.n_cells))[active_faces, :], + ) + + return function diff --git a/tests/models_test.py b/tests/models_test.py index de1969b3..b7855427 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -72,7 +72,7 @@ def test_zero_reference_model(tmp_path: Path): geoh5 = params.geoh5 with geoh5.open(): driver = MVIInversionDriver(params) - _ = InversionModel(driver, "reference") + _ = InversionModel(driver, "reference", is_vector=True) incl = np.unique(geoh5.get_entity("reference_inclination")[0].values) decl = np.unique(geoh5.get_entity("reference_declination")[0].values) assert len(incl) == 1 @@ -87,7 +87,7 @@ def test_collection(tmp_path: Path): driver = MVIInversionDriver(params) models = InversionModelCollection(driver) models.remove_air(driver.models.active_cells) - starting = InversionModel(driver, "starting") + starting = InversionModel(driver, "starting", is_vector=True) starting.remove_air(driver.models.active_cells) np.testing.assert_allclose(models.starting, starting.model, atol=1e-7) @@ -96,7 +96,7 @@ def test_initialize(tmp_path: Path): params = get_mvi_params(tmp_path) with params.geoh5.open(): driver = MVIInversionDriver(params) - starting_model = InversionModel(driver, "starting") + starting_model = InversionModel(driver, "starting", is_vector=True) assert len(starting_model.model) == 3 * driver.inversion_mesh.n_cells assert len(np.unique(starting_model.model)) == 3 @@ -117,7 +117,7 @@ def test_model_from_object(tmp_path: Path): point_object.add_data({"test_data": {"values": vals}}) data_object = geoh5.get_entity("test_data")[0] params.lower_bound = data_object - lower_bound = InversionModel(driver, "lower_bound") + lower_bound = InversionModel(driver, "lower_bound", is_vector=True) nc = int(len(lower_bound.model) / 3) A = driver.inversion_mesh.mesh.cell_centers b = lower_bound.model[:nc] diff --git a/tests/run_tests/driver_grav_test.py b/tests/run_tests/driver_grav_test.py index 4475dda6..38d1980d 100644 --- a/tests/run_tests/driver_grav_test.py +++ b/tests/run_tests/driver_grav_test.py @@ -34,7 +34,7 @@ # To test the full run and validate the inversion. # Move this file out of the test directory and run. -target_run = {"data_norm": 0.0028055269276044915, "phi_d": 8.32e-05, "phi_m": 0.00333} +target_run = {"data_norm": 0.0028055269276044915, "phi_d": 8.32e-05, "phi_m": 0.0038} def test_gravity_fwr_run( diff --git a/tests/run_tests/driver_rotated_gradients_test.py b/tests/run_tests/driver_rotated_gradients_test.py new file mode 100644 index 00000000..1858258a --- /dev/null +++ b/tests/run_tests/driver_rotated_gradients_test.py @@ -0,0 +1,158 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2025 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +import numpy as np +from geoapps_utils.utils.importing import GeoAppsError +from geoh5py.groups.property_group import PropertyGroup +from geoh5py.workspace import Workspace +from pytest import raises + +from simpeg_drivers.params import ActiveCellsOptions +from simpeg_drivers.potential_fields import ( + GravityForwardOptions, + GravityInversionOptions, +) +from simpeg_drivers.potential_fields.gravity.driver import ( + GravityForwardDriver, + GravityInversionDriver, +) +from simpeg_drivers.utils.testing import check_target, setup_inversion_workspace +from simpeg_drivers.utils.utils import get_inversion_output + + +# To test the full run and validate the inversion. +# Move this file out of the test directory and run. + +target_run = {"data_norm": 0.006830937520353864, "phi_d": 0.0276, "phi_m": 0.0288} + + +def test_gravity_rotated_grad_fwr_run( + tmp_path: Path, + n_grid_points=2, + refinement=(2,), +): + # Run the forward + geoh5, _, model, survey, topography = setup_inversion_workspace( + tmp_path, + background=0.0, + anomaly=0.75, + n_electrodes=n_grid_points, + n_lines=n_grid_points, + refinement=refinement, + center=(0.0, 0.0, 15.0), + flatten=False, + ) + + active_cells = ActiveCellsOptions(topography_object=topography) + params = GravityForwardOptions( + geoh5=geoh5, + mesh=model.parent, + active_cells=active_cells, + topography_object=topography, + data_object=survey, + starting_model=model, + gz_channel_bool=True, + ) + fwr_driver = GravityForwardDriver(params) + fwr_driver.run() + + +def test_rotated_grad_run( + tmp_path: Path, + max_iterations=1, + pytest=True, +): + workpath = tmp_path / "inversion_test.ui.geoh5" + if pytest: + workpath = ( + tmp_path.parent + / "test_gravity_rotated_grad_fwr_0" + / "inversion_test.ui.geoh5" + ) + + with Workspace(workpath) as geoh5: + gz = geoh5.get_entity("Iteration_0_gz")[0] + orig_gz = gz.values.copy() + mesh = geoh5.get_entity("mesh")[0] + + # Create property group with orientation + dip = np.ones(mesh.n_cells) * 45 + azimuth = np.ones(mesh.n_cells) * 90 + + data_list = mesh.add_data( + { + "azimuth": {"values": azimuth}, + "dip": {"values": dip}, + } + ) + pg = PropertyGroup( + mesh, properties=data_list, property_group_type="Dip direction & dip" + ) + topography = geoh5.get_entity("topography")[0] + + # Run the inverse + active_cells = ActiveCellsOptions(topography_object=topography) + params = GravityInversionOptions( + geoh5=geoh5, + mesh=mesh, + active_cells=active_cells, + data_object=gz.parent, + gradient_rotation=pg, + starting_model=1e-4, + reference_model=0.0, + s_norm=0.0, + x_norm=0.0, + y_norm=0.0, + z_norm=0.0, + gradient_type="components", + gz_channel=gz, + gz_uncertainty=2e-3, + lower_bound=0.0, + max_global_iterations=max_iterations, + initial_beta_ratio=1e-1, + percentile=95, + store_sensitivities="ram", + save_sensitivities=True, + ) + params.write_ui_json(path=tmp_path / "Inv_run.ui.json") + + driver = GravityInversionDriver.start(str(tmp_path / "Inv_run.ui.json")) + + with Workspace(driver.params.geoh5.h5file) as run_ws: + output = get_inversion_output( + driver.params.geoh5.h5file, driver.params.out_group.uid + ) + output["data"] = orig_gz + + if pytest: + check_target(output, target_run) + nan_ind = np.isnan(run_ws.get_entity("Iteration_0_model")[0].values) + inactive_ind = run_ws.get_entity("active_cells")[0].values == 0 + assert np.all(nan_ind == inactive_ind) + + +if __name__ == "__main__": + # Full run + test_gravity_rotated_grad_fwr_run( + Path("./"), + n_grid_points=10, + refinement=(4, 8), + ) + + test_rotated_grad_run( + Path("./"), + max_iterations=40, + pytest=False, + ) diff --git a/tests/utils_regularization_test.py b/tests/utils_regularization_test.py new file mode 100644 index 00000000..d0f0fefe --- /dev/null +++ b/tests/utils_regularization_test.py @@ -0,0 +1,92 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2025 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' + +import numpy as np +from discretize import TreeMesh + +from simpeg_drivers.utils.regularization import ( + cell_adjacent, + cell_neighbors_along_axis, + collect_all_neighbors, +) + + +def get_mesh(): + mesh = TreeMesh(h=[[10.0] * 4, [10.0] * 4, [10.0] * 4], diagonal_balance=False) + mesh.refine(2) + return mesh + + +def test_cell_neighbors_along_axis(): + mesh = get_mesh() + centers = mesh.cell_centers + neighbors = cell_neighbors_along_axis(mesh, "x") + assert np.allclose(centers[7], [15.0, 15.0, 15.0]) + assert np.allclose( + centers[neighbors[neighbors[:, 0] == 7][0][1]], [25.0, 15.0, 15.0] + ) + assert np.allclose( + centers[neighbors[neighbors[:, 1] == 7][0][0]], [5.0, 15.0, 15.0] + ) + neighbors = cell_neighbors_along_axis(mesh, "y") + assert np.allclose( + centers[neighbors[neighbors[:, 0] == 7][0][1]], [15.0, 25.0, 15.0] + ) + assert np.allclose( + centers[neighbors[neighbors[:, 1] == 7][0][0]], [15.0, 5.0, 15.0] + ) + neighbors = cell_neighbors_along_axis(mesh, "z") + assert np.allclose( + centers[neighbors[neighbors[:, 0] == 7][0][1]], [15.0, 15.0, 25.0] + ) + assert np.allclose( + centers[neighbors[neighbors[:, 1] == 7][0][0]], [15.0, 15.0, 5.0] + ) + + +def test_collect_all_neighbors(): + mesh = get_mesh() + centers = mesh.cell_centers + neighbors = [cell_neighbors_along_axis(mesh, k) for k in "xyz"] + neighbors_bck = [np.fliplr(k) for k in neighbors] + corners = cell_adjacent(neighbors) + corners_bck = cell_adjacent(neighbors_bck) + all_neighbors = collect_all_neighbors( + neighbors, neighbors_bck, corners, corners_bck + ) + assert np.allclose(centers[7], [15.0, 15.0, 15.0]) + neighbor_centers = centers[all_neighbors[all_neighbors[:, 0] == 7][:, 1]].tolist() + assert [5, 5, 5] in neighbor_centers + assert [15, 5, 5] in neighbor_centers + assert [25, 5, 5] in neighbor_centers + assert [5, 15, 5] in neighbor_centers + assert [15, 15, 5] in neighbor_centers + assert [25, 15, 5] in neighbor_centers + assert [5, 25, 5] in neighbor_centers + assert [15, 25, 5] in neighbor_centers + assert [25, 25, 5] in neighbor_centers + assert [5, 5, 15] in neighbor_centers + assert [15, 5, 15] in neighbor_centers + assert [25, 5, 15] in neighbor_centers + assert [5, 15, 15] in neighbor_centers + assert [25, 15, 15] in neighbor_centers + assert [5, 25, 15] in neighbor_centers + assert [15, 25, 15] in neighbor_centers + assert [25, 25, 15] in neighbor_centers + assert [5, 5, 25] in neighbor_centers + assert [15, 5, 25] in neighbor_centers + assert [25, 5, 25] in neighbor_centers + assert [5, 15, 25] in neighbor_centers + assert [15, 15, 25] in neighbor_centers + assert [25, 15, 25] in neighbor_centers + assert [5, 25, 25] in neighbor_centers + assert [15, 25, 25] in neighbor_centers + assert [25, 25, 25] in neighbor_centers + assert [15, 15, 15] not in neighbor_centers