diff --git a/simpeg_drivers-assets/MT_fwr_UBC_example.geoh5 b/simpeg_drivers-assets/MT_fwr_UBC_example.geoh5 new file mode 100644 index 00000000..9bd051c0 --- /dev/null +++ b/simpeg_drivers-assets/MT_fwr_UBC_example.geoh5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd94d02140664dfa589048fd8bf13a8d2af97d4905a2073d7cef60beef7b1fc5 +size 1477617 diff --git a/simpeg_drivers/components/data.py b/simpeg_drivers/components/data.py index 588d125d..bac0e63b 100644 --- a/simpeg_drivers/components/data.py +++ b/simpeg_drivers/components/data.py @@ -16,20 +16,15 @@ from typing import TYPE_CHECKING, Any import numpy as np -from discretize import TensorMesh, TreeMesh -from geoh5py.objects import PotentialElectrode +from geoh5py.objects import LargeLoopGroundTEMReceivers, PotentialElectrode from scipy.sparse import csgraph, csr_matrix from scipy.spatial import cKDTree -from simpeg import maps from simpeg.electromagnetics.static.utils.static_utils import geometric_factor -from simpeg.simulation import BaseSimulation -from simpeg_drivers.utils.utils import create_nested_mesh, drape_2_tensor +from simpeg_drivers.utils.utils import drape_2_tensor from .factories import ( EntityFactory, - SaveDataGeoh5Factory, - SimulationFactory, SurveyFactory, ) from .locations import InversionLocations @@ -54,8 +49,6 @@ class InversionData(InversionLocations): mask : Mask accumulated by windowing and downsampling operations and applied to locations and data on initialization. - n_blocks : - Number of blocks if vector. components : Component names. observed : @@ -83,7 +76,6 @@ def __init__(self, workspace: Workspace, params: InversionBaseOptions): super().__init__(workspace, params) self.locations: np.ndarray | None = None self.mask: np.ndarray | None = None - self.n_blocks: int | None = None self._observed: dict[str, np.ndarray] | None = None self._uncertainties: dict[str, np.ndarray] | None = None @@ -97,7 +89,6 @@ def __init__(self, workspace: Workspace, params: InversionBaseOptions): def _initialize(self) -> None: """Extract data from the workspace using params data.""" - self.n_blocks = 3 if self.params.inversion_type == "magnetic vector" else 1 self.components = self.params.active_components self.has_tensor = InversionData.check_tensor(self.params.components) @@ -160,6 +151,9 @@ def parts(self): connections = csgraph.connected_components(edge_array)[1] return connections[self.entity.cells[:, 0]] + if isinstance(self.entity, LargeLoopGroundTEMReceivers): + return self.entity.tx_id_property.values + return getattr(self.entity, "parts", None) def drape_locations(self, locations: np.ndarray) -> np.ndarray: @@ -232,10 +226,6 @@ def save_data(self): if channels is None: continue - # Non-EM methods - if not has_channels: - channels = {None: channels} - for ind, (channel, values) in enumerate(channels.items()): suffix = f"_{component}" if has_channels: @@ -333,126 +323,31 @@ def get_normalizations(self): return normalizations - def create_survey( - self, - local_index: np.ndarray | None = None, - channel=None, - ): + def create_survey(self): """ Generates SimPEG survey object. - :param: local_index (Optional): Indices of the data belonging to a - particular tile in case of a tiled inversion. - :return: survey: SimPEG Survey class that covers all data or optionally the portion of the data indexed by the local_index argument. - :return: local_index: receiver indices belonging to a particular tile. """ survey_factory = SurveyFactory(self.params) - survey, local_index, ordering = survey_factory.build( - data=self, - local_index=local_index, - channel=channel, - ) + survey = survey_factory.build(data=self) + survey.ordering = survey_factory.ordering + survey.sorting = survey_factory.sorting + survey.locations = self.entity.vertices + # Save apparent resistivity in geoh5 order if "direct current" in self.params.inversion_type: survey.apparent_resistivity = 1 / ( - geometric_factor(survey)[np.argsort(np.hstack(local_index))] + 1e-10 - ) - - return survey, local_index, ordering - - def simulation( - self, - inversion_mesh: InversionMesh, - local_mesh: TreeMesh | TensorMesh | None, - active_cells: np.ndarray, - survey, - tile_id: int | None = None, - padding_cells: int = 6, - ) -> tuple[BaseSimulation, maps.IdentityMap]: - """ - Generates SimPEG simulation object. - - :param: mesh: inversion mesh. - :param: active_cells: Mask that reduces model to active (earth) cells. - :param: survey: SimPEG survey object. - :param: tile_id (Optional): Id associated with the tile covered by - the survey in case of a tiled inversion. - - :return: sim: SimPEG simulation object for full data or optionally - the portion of the data indexed by the local_index argument. - :return: map: If local_index and tile_id is provided, the returned - map will maps from local to global data. If no local_index or - tile_id is provided map will simply be an identity map with no - effect of the data. - """ - simulation_factory = SimulationFactory(self.params) - - if tile_id is None or "2d" in self.params.inversion_type: - mapping = maps.IdentityMap(nP=int(self.n_blocks * active_cells.sum())) - simulation = simulation_factory.build( - survey=survey, - global_mesh=inversion_mesh.mesh, - active_cells=active_cells, - mapping=mapping, - ) - elif "1d" in self.params.inversion_type: - slice_ind = np.arange( - tile_id, inversion_mesh.mesh.n_cells, inversion_mesh.mesh.shape_cells[0] - )[::-1] - mapping = maps.Projection(inversion_mesh.mesh.n_cells, slice_ind) - simulation = simulation_factory.build( - survey=survey, - receivers=self.entity, - global_mesh=inversion_mesh.mesh, - local_mesh=inversion_mesh.layers_mesh, - active_cells=active_cells, - mapping=mapping, - tile_id=tile_id, - ) - else: - if local_mesh is None: - local_mesh = create_nested_mesh( - survey, - inversion_mesh.mesh, - minimum_level=3, - padding_cells=padding_cells, - ) - mapping = maps.TileMap( - inversion_mesh.mesh, - active_cells, - local_mesh, - enforce_active=True, - components=self.n_blocks, - ) - simulation = simulation_factory.build( - survey=survey, - receivers=self.entity, - global_mesh=inversion_mesh.mesh, - local_mesh=local_mesh, - active_cells=mapping.local_active, - mapping=mapping, - tile_id=tile_id, + geometric_factor(survey)[np.argsort(survey.sorting)] + 1e-10 ) + survey.cells = self.entity.cells - return simulation, mapping - - def simulate(self, model, inverse_problem, sorting, ordering): - """Simulate fields for a particular model.""" - dpred = inverse_problem.get_dpred( - model, compute_J=False if self.params.forward_only else True - ) - if self.params.forward_only: - save_directive = SaveDataGeoh5Factory(self.params).build( - inversion_object=self, - sorting=np.argsort(np.hstack(sorting)), - ordering=ordering, - ) - save_directive.write(0, dpred) + if "induced polarization" in self.params.inversion_type: + survey.cells = self.entity.cells - inverse_problem.dpred = dpred + return survey @property def observed_data_types(self): @@ -495,7 +390,7 @@ def update_params(self, data_dict, uncert_dict): @property def survey(self): if self._survey is None: - self._survey, _, _ = self.create_survey() + self._survey = self.create_survey() return self._survey diff --git a/simpeg_drivers/components/factories/directives_factory.py b/simpeg_drivers/components/factories/directives_factory.py index f76cd0e9..6301d2ff 100644 --- a/simpeg_drivers/components/factories/directives_factory.py +++ b/simpeg_drivers/components/factories/directives_factory.py @@ -20,6 +20,9 @@ import numpy as np from geoh5py.groups.property_group import GroupTypeEnum +from geoh5py.objects.surveys.electromagnetics.base import FEMSurvey +from geoh5py.objects.surveys.electromagnetics.magnetotellurics import MTReceivers +from geoh5py.objects.surveys.electromagnetics.tipper import TipperReceivers from numpy import sqrt from simpeg import directives, maps from simpeg.utils.mat_utils import cartesian2amplitude_dip_azimuth @@ -180,8 +183,6 @@ def save_iteration_apparent_resistivity_directive(self): self.params ).build( inversion_object=self.driver.inversion_data, - active_cells=self.driver.models.active_cells, - sorting=np.argsort(np.hstack(self.driver.sorting)), name="Apparent Resistivity", ) return self._save_iteration_apparent_resistivity_directive @@ -223,10 +224,6 @@ def save_iteration_data_directive(self): self.params ).build( inversion_object=self.driver.inversion_data, - active_cells=self.driver.models.active_cells, - sorting=np.argsort(np.hstack(self.driver.sorting)), - ordering=self.driver.ordering, - global_misfit=self.driver.data_misfit, name="Data", ) return self._save_iteration_data_directive @@ -265,9 +262,6 @@ def save_iteration_residual_directive(self): self.params ).build( inversion_object=self.driver.inversion_data, - active_cells=self.driver.models.active_cells, - sorting=np.argsort(np.hstack(self.driver.sorting)), - ordering=self.driver.ordering, name="Residual", ) return self._save_iteration_residual_directive @@ -366,8 +360,6 @@ def assemble_arguments( self, inversion_object=None, active_cells=None, - sorting=None, - ordering=None, transform=None, global_misfit=None, name=None, @@ -386,8 +378,6 @@ def assemble_keyword_arguments( self, inversion_object=None, active_cells=None, - sorting=None, - ordering=None, transform=None, global_misfit=None, name=None, @@ -450,8 +440,6 @@ def assemble_keyword_arguments( self, inversion_object=None, active_cells=None, - sorting=None, - ordering=None, transform=None, global_misfit=None, name=None, @@ -497,103 +485,67 @@ class SaveDataGeoh5Factory(SaveGeoh5Factory): def assemble_keyword_arguments( self, inversion_object=None, - active_cells=None, - sorting=None, - ordering=None, - transform=None, - global_misfit=None, name=None, ): - if self.factory_type in [ - "fdem", - "fdem 1d", - "tdem", - "tdem 1d", - "magnetotellurics", - "tipper", - ]: - kwargs = self.assemble_data_keywords_em( - inversion_object=inversion_object, - active_cells=active_cells, - sorting=sorting, - ordering=ordering, - transform=transform, - global_misfit=global_misfit, - name=name, - ) + receivers = inversion_object.entity + channels = getattr(receivers, "channels", [None]) + components = list(inversion_object.observed) + ordering = inversion_object.survey.ordering + n_locations = len(np.unique(ordering[:, 2])) + + def reshape(values): + data = np.zeros((len(channels), len(components), n_locations)) + data[ordering[:, 0], ordering[:, 1], ordering[:, 2]] = values + return data - elif self.factory_type in [ + kwargs = { + "data_type": inversion_object.observed_data_types, + "association": "VERTEX", + "transforms": [ + np.hstack( + [ + 1 / inversion_object.normalizations[chan][comp] + for chan in channels + for comp in components + ], + ), + ], + "channels": channels, + "components": components, + "reshape": reshape, + } + + if self.factory_type in [ "direct current 3d", "direct current 2d", "induced polarization 3d", "induced polarization 2d", ]: kwargs = self.assemble_data_keywords_dcip( - inversion_object=inversion_object, - active_cells=active_cells, - sorting=sorting, - transform=transform, - global_misfit=global_misfit, - name=name, + inversion_object=inversion_object, name=name, **kwargs ) elif self.factory_type in ["gravity", "magnetic scalar", "magnetic vector"]: kwargs = self.assemble_data_keywords_potential_fields( inversion_object=inversion_object, - active_cells=active_cells, - sorting=sorting, - transform=transform, - global_misfit=global_misfit, name=name, + **kwargs, ) - else: - return None - - if transform is not None: - kwargs["transforms"].append(transform) return kwargs @staticmethod def assemble_data_keywords_potential_fields( inversion_object=None, - active_cells=None, - sorting=None, - transform=None, - global_misfit=None, name=None, + **kwargs, ): - components = list(inversion_object.observed) - channels = [None] - kwargs = { - "data_type": inversion_object.observed_data_types, - "transforms": [ - np.hstack( - [ - inversion_object.normalizations[chan][comp] - for chan in channels - for comp in components - ] - ) - ], - "channels": channels, - "components": components, - "association": "VERTEX", - "reshape": lambda x: x.reshape( - (len(channels), len(components), -1), order="F" - ), - } - - if sorting is not None: - kwargs["sorting"] = np.hstack(sorting) - if name == "Residual": kwargs["label"] = name data = inversion_object.normalize(inversion_object.observed) def potfield_transform(x): - data_stack = np.row_stack(list(data.values())) - data_stack = data_stack[:, np.argsort(sorting)] + data_stack = np.vstack([k[None] for k in data.values()]) return data_stack.ravel() - x kwargs.pop("data_type") @@ -604,48 +556,22 @@ def potfield_transform(x): def assemble_data_keywords_dcip( self, inversion_object=None, - active_cells=None, - sorting=None, - transform=None, - global_misfit=None, name=None, + **kwargs, ): components = list(inversion_object.observed) - channels = [None] - is_dc = True if "direct current" in self.factory_type else False - component = "dc" if is_dc else "ip" - kwargs = { - "data_type": inversion_object.observed_data_types, - "transforms": [ - np.hstack( - [ - inversion_object.normalizations[chan][comp] - for chan in channels - for comp in components - ] - ) - ], - "channels": channels, - "components": [component], - "reshape": lambda x: x.reshape( - (len(channels), len(components), -1), order="F" - ), - "association": "CELL", - } + kwargs["association"] = "CELL" - if sorting is not None: - kwargs["sorting"] = np.hstack(sorting) - - if is_dc and name == "Apparent Resistivity": + if "direct current" in self.factory_type and name == "Apparent Resistivity": kwargs["transforms"].insert( 0, - inversion_object.survey.apparent_resistivity[np.argsort(sorting)], + inversion_object.survey.apparent_resistivity, ) kwargs["channels"] = ["apparent_resistivity"] observed = self.params.geoh5.get_entity("Observed_apparent_resistivity")[0] if observed is not None: kwargs["data_type"] = { - component: {"apparent_resistivity": observed.entity_type} + components[0]: {"apparent_resistivity": observed.entity_type} } if name == "Residual": @@ -653,52 +579,10 @@ def assemble_data_keywords_dcip( data = inversion_object.normalize(inversion_object.observed) def dcip_transform(x): - data_stack = np.row_stack(list(data.values())).ravel() - sorting_stack = np.tile(np.argsort(sorting), len(data)) - return data_stack[sorting_stack] - x + data_stack = np.vstack([k[None] for k in data.values()]) + return data_stack.ravel() - x kwargs["transforms"].insert(0, dcip_transform) kwargs.pop("data_type") return kwargs - - def assemble_data_keywords_em( - self, - inversion_object=None, - active_cells=None, - sorting=None, - ordering=None, - transform=None, - global_misfit=None, - name=None, - ): - receivers = inversion_object.entity - channels = np.array(receivers.channels, dtype=float) - components = list(inversion_object.observed) - ordering = np.vstack(ordering) - channel_ids = ordering[:, 0] - component_ids = ordering[:, 1] - rx_ids = ordering[:, 2] - - def reshape(values): - data = np.zeros((len(channels), len(components), receivers.n_vertices)) - data[channel_ids, component_ids, rx_ids] = values - return data - - kwargs = { - "data_type": inversion_object.observed_data_types, - "association": "VERTEX", - "transforms": np.hstack( - [ - 1 / inversion_object.normalizations[chan][comp] - for chan in channels - for comp in components - ] - ), - "channels": channels, - "components": components, - "sorting": sorting, - "_reshape": reshape, - } - - return kwargs diff --git a/simpeg_drivers/components/factories/misfit_factory.py b/simpeg_drivers/components/factories/misfit_factory.py index ae3d18f6..439b3e0d 100644 --- a/simpeg_drivers/components/factories/misfit_factory.py +++ b/simpeg_drivers/components/factories/misfit_factory.py @@ -13,254 +13,90 @@ from typing import TYPE_CHECKING +import numpy as np +from simpeg import objective_function +from simpeg.simulation import BaseSimulation + +from simpeg_drivers.components.factories.simpeg_factory import SimPEGFactory +from simpeg_drivers.utils.nested import create_misfit + if TYPE_CHECKING: from geoapps_utils.driver.params import BaseParams - from simpeg_drivers.components.data import InversionData - from simpeg_drivers.components.meshes import InversionMesh from simpeg_drivers.options import BaseOptions -import numpy as np -from geoh5py.objects import Octree -from simpeg import data, data_misfit, maps, meta, objective_function - -from simpeg_drivers.components.factories.simpeg_factory import SimPEGFactory - class MisfitFactory(SimPEGFactory): """Build SimPEG global misfit function.""" - def __init__(self, params: BaseParams | BaseOptions, models=None): + def __init__(self, params: BaseParams | BaseOptions, simulation: BaseSimulation): """ :param params: Options object containing SimPEG object parameters. """ super().__init__(params) self.simpeg_object = self.concrete_object() self.factory_type = self.params.inversion_type - self.models = models - self.sorting = None - self.ordering = None + self.simulation = simulation def concrete_object(self): return objective_function.ComboObjectiveFunction - def build(self, tiles, split_list, inversion_data, inversion_mesh, active_cells): # pylint: disable=arguments-differ + def build(self, tiles, split_list): # pylint: disable=arguments-differ global_misfit = super().build( tiles=tiles, split_list=split_list, - inversion_data=inversion_data, - inversion_mesh=inversion_mesh, - active_cells=active_cells, ) - return global_misfit, self.sorting, self.ordering + return global_misfit def assemble_arguments( # pylint: disable=arguments-differ self, tiles, split_list, - inversion_data, - inversion_mesh, - active_cells, ): # Base slice over frequencies if self.factory_type in ["magnetotellurics", "tipper", "fdem"]: - channels = inversion_data.entity.channels + channels = self.simulation.survey.frequencies else: channels = [None] - local_misfits = [] - - self.sorting = [] - self.ordering = [] - tile_count = 0 - data_count = 0 - misfit_count = 0 - for local_index in tiles: - if len(local_index) == 0: - continue - - local_sim, _, _, _ = self.create_nested_simulation( - inversion_data, - inversion_mesh, - None, - active_cells, - local_index, - channel=None, - tile_id=tile_count, - padding_cells=self.params.padding_cells, - ) - - local_mesh = getattr(local_sim, "mesh", None) - - for count, channel in enumerate(channels): - n_split = split_list[misfit_count] - for split_ind in np.array_split(local_index, n_split): - local_sim, split_ind, ordering, mapping = ( - self.create_nested_simulation( - inversion_data, - inversion_mesh, - local_mesh, - active_cells, - split_ind, - channel=channel, - tile_id=tile_count, - padding_cells=self.params.padding_cells, - ) - ) - - if count == 0: - if self.factory_type in [ - "fdem", - "tdem", - "magnetotellurics", - "tipper", - ]: - self.sorting.append( - np.arange( - data_count, - data_count + len(split_ind), - dtype=int, - ) - ) - data_count += len(split_ind) - else: - self.sorting.append(split_ind) - - # TODO this should be done in the simulation factory - if "induced polarization" in self.params.inversion_type: - if "2d" in self.params.inversion_type: - proj = maps.InjectActiveCells( - inversion_mesh.mesh, active_cells, value_inactive=1e-8 - ) - else: - proj = maps.InjectActiveCells( - mapping.local_mesh, - mapping.local_active, - value_inactive=1e-8, - ) - - local_sim.sigma = ( - proj * mapping * self.models.conductivity_model - ) + futures = [] + # TODO bring back on GEOPY-2182 + # with ProcessPoolExecutor() as executor: + count = 0 + for channel in channels: + tile_count = 0 + for local_indices in tiles: + if len(local_indices) == 0: + continue - simulation = meta.MetaSimulation( - simulations=[local_sim], mappings=[mapping] + n_split = split_list[count] + futures.append( + # executor.submit( + create_misfit( + self.simulation, + local_indices, + channel, + tile_count, + n_split, + self.params.padding_cells, + self.params.inversion_type, + self.params.forward_only, ) + ) + tile_count += np.sum(n_split) + count += 1 - local_data = data.Data(local_sim.survey) - - if self.params.forward_only: - lmisfit = data_misfit.L2DataMisfit(local_data, simulation) - - else: - local_data.dobs = local_sim.survey.dobs - local_data.standard_deviation = local_sim.survey.std - lmisfit = data_misfit.L2DataMisfit( - local_data, - simulation, - ) - - name = self.params.inversion_type - - if len(tiles) > 1 or n_split > 1: - name += f": Tile {tile_count + 1}" - if len(channels) > 1: - name += f": Channel {channel}" - - lmisfit.name = f"{name}" - - local_misfits.append(lmisfit) - self.ordering.append(ordering) - - tile_count += 1 - - misfit_count += 1 + local_misfits = [] + local_orderings = [] + for future in futures: # as_completed(futures): + misfits, orderings = future # future.result() + local_misfits += misfits + local_orderings += orderings + self.simulation.survey.ordering = np.vstack(local_orderings) return [local_misfits] def assemble_keyword_arguments(self, **_): """Implementation of abstract method from SimPEGFactory.""" return {} - - @staticmethod - def create_nested_simulation( - inversion_data: InversionData, - inversion_mesh: InversionMesh, - local_mesh: Octree | None, - active_cells: np.ndarray, - indices: np.ndarray, - *, - channel: int | None = None, - tile_id: int | None = None, - padding_cells=100, - ): - """ - Generate a survey, mesh and simulation based on indices. - - :param inversion_data: InversionData object. - :param mesh: Octree mesh. - :param active_cells: Active cell model. - :param indices: Indices of receivers belonging to the tile. - :param channel: Channel number for frequency or time channels. - :param tile_id: Tile id stored on the simulation. - :param padding_cells: Number of padding cells around the local survey. - """ - survey, indices, ordering = inversion_data.create_survey( - local_index=indices, channel=channel - ) - local_sim, mapping = inversion_data.simulation( - inversion_mesh, - local_mesh, - active_cells, - survey, - tile_id=tile_id, - padding_cells=padding_cells, - ) - inv_type = inversion_data.params.inversion_type - if inv_type in ["fdem", "tdem"]: - compute_em_projections(inversion_data, local_sim) - elif ("current" in inv_type or "polarization" in inv_type) and ( - "2d" not in inv_type or "pseudo" in inv_type - ): - compute_dc_projections(inversion_data, local_sim, indices) - return local_sim, np.hstack(indices), ordering, mapping - - -def compute_em_projections(inversion_data, simulation): - """ - Pre-compute projections for the receivers for efficiency. - """ - rx_locs = inversion_data.entity.vertices - projections = {} - for component in "xyz": - projections[component] = simulation.mesh.get_interpolation_matrix( - rx_locs, "faces_" + component[0] - ) - - for source in simulation.survey.source_list: - for receiver in source.receiver_list: - projection = 0.0 - for orientation, comp in zip(receiver.orientation, "xyz", strict=True): - if orientation == 0: - continue - projection += orientation * projections[comp][receiver.local_index, :] - receiver.spatialP = projection - - -def compute_dc_projections(inversion_data, simulation, indices): - """ - Pre-compute projections for the receivers for efficiency. - """ - rx_locs = inversion_data.entity.vertices - mn_pairs = inversion_data.entity.cells - projection = simulation.mesh.get_interpolation_matrix(rx_locs, "nodes") - - for source, ind in zip(simulation.survey.source_list, indices, strict=True): - proj_mn = projection[mn_pairs[ind, 0], :] - - # Check if dipole receiver - if not np.all(mn_pairs[ind, 0] == mn_pairs[ind, 1]): - proj_mn -= projection[mn_pairs[ind, 1], :] - - source.receiver_list[0].spatialP = proj_mn # pylint: disable=protected-access diff --git a/simpeg_drivers/components/factories/receiver_factory.py b/simpeg_drivers/components/factories/receiver_factory.py index 1649ad25..69bbaa85 100644 --- a/simpeg_drivers/components/factories/receiver_factory.py +++ b/simpeg_drivers/components/factories/receiver_factory.py @@ -118,7 +118,7 @@ def assemble_arguments( ) else: - args.append(locations[local_index]) + args.append(locations) return args @@ -193,9 +193,4 @@ def _tdem_arguments(self, data=None, locations=None, local_index=None): ] def _magnetotellurics_arguments(self, locations=None, local_index=None): - args = [] - locs = locations[local_index] - - args.append(locs) - - return args + return [locations] diff --git a/simpeg_drivers/components/factories/simulation_factory.py b/simpeg_drivers/components/factories/simulation_factory.py index d15e6559..17b94873 100644 --- a/simpeg_drivers/components/factories/simulation_factory.py +++ b/simpeg_drivers/components/factories/simulation_factory.py @@ -120,35 +120,22 @@ def concrete_object(self): def assemble_arguments( self, survey=None, - receivers=None, - global_mesh=None, - local_mesh=None, - active_cells=None, - mapping=None, - tile_id=None, + mesh=None, + models=None, ): if "1d" in self.factory_type: return () - mesh = global_mesh if tile_id is None else local_mesh return [mesh] def assemble_keyword_arguments( self, survey=None, - receivers=None, - global_mesh=None, - local_mesh=None, - active_cells=None, - mapping=None, - tile_id=None, + mesh=None, + models=None, ): - mesh = global_mesh if tile_id is None else local_mesh - sensitivity_path = self._get_sensitivity_path(tile_id) - kwargs = {} kwargs["survey"] = survey - kwargs["sensitivity_path"] = sensitivity_path kwargs["max_chunk_size"] = self.params.compute.max_chunk_size kwargs["store_sensitivities"] = ( "forward_only" @@ -156,28 +143,31 @@ def assemble_keyword_arguments( else self.params.store_sensitivities ) kwargs["solver"] = self.solver - + active_cells = models.active_cells if self.factory_type == "magnetic vector": kwargs["active_cells"] = active_cells kwargs["chiMap"] = maps.IdentityMap(nP=int(active_cells.sum()) * 3) kwargs["model_type"] = "vector" - kwargs["chunk_format"] = "row" if self.factory_type == "magnetic scalar": kwargs["active_cells"] = active_cells kwargs["chiMap"] = maps.IdentityMap(nP=int(active_cells.sum())) - kwargs["chunk_format"] = "row" if self.factory_type == "gravity": kwargs["active_cells"] = active_cells kwargs["rhoMap"] = maps.IdentityMap(nP=int(active_cells.sum())) - kwargs["chunk_format"] = "row" if "induced polarization" in self.factory_type: etamap = maps.InjectActiveCells( mesh, active_cells=active_cells, value_inactive=0 ) kwargs["etaMap"] = etamap + kwargs["sigma"] = ( + maps.InjectActiveCells( + mesh, active_cells=active_cells, value_inactive=1e-8 + ) + * models.conductivity_model + ) if self.factory_type in [ "direct current 3d", @@ -193,16 +183,12 @@ def assemble_keyword_arguments( kwargs["sigmaMap"] = maps.ExpMap(mesh) * actmap if "tdem" in self.factory_type: - kwargs["t0"] = -receivers.timing_mark * self.params.unit_conversion - kwargs["time_steps"] = ( - np.round((np.diff(np.unique(receivers.waveform[:, 0]))), decimals=6) - * self.params.unit_conversion - ) + kwargs["t0"] = -self.params.timing_mark + kwargs["time_steps"] = self.params.time_steps if "1d" in self.factory_type: kwargs["sigmaMap"] = maps.ExpMap(mesh) - kwargs["thicknesses"] = local_mesh.h[0][1:][::-1] - kwargs["topo"] = active_cells[tile_id] + kwargs["thicknesses"] = mesh.h[1][1:][::-1] return kwargs diff --git a/simpeg_drivers/components/factories/survey_factory.py b/simpeg_drivers/components/factories/survey_factory.py index 5ba1b035..b98c47cd 100644 --- a/simpeg_drivers/components/factories/survey_factory.py +++ b/simpeg_drivers/components/factories/survey_factory.py @@ -13,6 +13,7 @@ from __future__ import annotations +from gc import is_finalized from typing import TYPE_CHECKING @@ -23,6 +24,7 @@ import numpy as np import simpeg.electromagnetics.time_domain as tdem +from geoh5py.objects.surveys.electromagnetics.airborne_fem import AirborneFEMReceivers from geoh5py.objects.surveys.electromagnetics.ground_tem import ( LargeLoopGroundTEMTransmitters, ) @@ -47,6 +49,7 @@ def __init__(self, params: BaseParams | BaseOptions): self.local_index = None self.survey = None self.ordering = None + self.sorting = None def concrete_object(self): if self.factory_type in ["magnetic vector", "magnetic scalar"]: @@ -75,35 +78,32 @@ def concrete_object(self): return survey.Survey - def assemble_arguments(self, data=None, local_index=None, channel=None): + def assemble_arguments(self, data=None): """Provides implementations to assemble arguments for receivers object.""" - receiver_entity = data.entity - - if local_index is None: - if "current" in self.factory_type or "polarization" in self.factory_type: - n_data = receiver_entity.n_cells - else: - n_data = receiver_entity.n_vertices - - self.local_index = np.arange(n_data) - else: - self.local_index = local_index - if "current" in self.factory_type or "polarization" in self.factory_type: - return self._dcip_arguments(data=data, local_index=local_index) + return self._dcip_arguments(data=data) elif "tdem" in self.factory_type: return self._tdem_arguments(data=data) elif self.factory_type in ["magnetotellurics", "tipper"]: - return self._naturalsource_arguments(data=data, frequency=channel) + return self._naturalsource_arguments(data=data) elif "fdem" in self.factory_type: - return self._fem_arguments(data=data, channel=channel) - else: + return self._fem_arguments(data=data) + else: # Gravity and Magnetic receivers = ReceiversFactory(self.params).build( locations=data.locations, data=data.observed, - local_index=self.local_index, ) sources = SourcesFactory(self.params).build(receivers=receivers) + n_rx = data.locations.shape[0] + sources.rx_ids = np.arange(n_rx, dtype=int) + n_comp = len(data.components) + self.ordering = np.c_[ + np.zeros(n_rx * n_comp), # Single channel + np.kron(np.ones(n_rx), np.arange(n_comp)), # Components + np.kron(np.arange(n_rx), np.ones(n_comp)), # Receivers + ].astype(int) + self.sorting = np.arange(n_rx, dtype=int) + return [sources] def assemble_keyword_arguments(self, **_): @@ -113,135 +113,45 @@ def assemble_keyword_arguments(self, **_): def build( self, data=None, - local_index=None, - indices=None, - channel=None, ): """Overloads base method to add dobs, std attributes to survey class instance.""" - survey = super().build( data=data, - local_index=local_index, - channel=channel, ) - + survey.n_channels = len( + data.normalizations + ) # Either time channels or frequencies + survey.n_components = len(data.components) if not self.params.forward_only: - self._add_data(survey, data, self.local_index, channel) + self._add_data(survey, data) survey.dummy = self.dummy - return survey, self.local_index, self.ordering - - def _get_local_data(self, data, channel, local_index): - local_data = {} - local_uncertainties = {} - - components = list(data.observed.keys()) - for comp in components: - comp_name = comp - if self.factory_type == "magnetotellurics": - comp_name = { - "zxx_real": "zyy_real", - "zxx_imag": "zyy_imag", - "zxy_real": "zyx_real", - "zxy_imag": "zyx_imag", - "zyx_real": "zxy_real", - "zyx_imag": "zxy_imag", - "zyy_real": "zxx_real", - "zyy_imag": "zxx_imag", - }[comp] - - key = "_".join([str(channel), str(comp_name)]) - local_data[key] = data.observed[comp][channel][local_index] - local_uncertainties[key] = data.uncertainties[comp][channel][local_index] - - return local_data, local_uncertainties - - def _add_data(self, survey, data, local_index, channel): - if isinstance(local_index, list): - local_index = np.hstack(local_index) - - if self.factory_type in ["fdem", "fdem 1d", "tdem", "tdem 1d"]: - dobs = [] - uncerts = [] - - data_stack = [np.vstack(list(k.values())) for k in data.observed.values()] - uncert_stack = [ - np.vstack(list(k.values())) for k in data.uncertainties.values() - ] - for order in self.ordering: - channel_id, component_id, rx_id = order - dobs.append(data_stack[component_id][channel_id, rx_id]) - uncerts.append(uncert_stack[component_id][channel_id, rx_id]) - - data_vec = np.vstack([dobs]).flatten() - uncertainty_vec = np.vstack([uncerts]).flatten() - - elif self.factory_type in ["magnetotellurics", "tipper"]: - local_data = {} - local_uncertainties = {} - - if channel is None: - channels = np.unique([list(v.keys()) for v in data.observed.values()]) - for chan in channels: - dat, unc = self._get_local_data(data, chan, local_index) - local_data.update(dat) - local_uncertainties.update(unc) - - else: - dat, unc = self._get_local_data(data, channel, local_index) - local_data.update(dat) - local_uncertainties.update(unc) - - data_vec = self._stack_channels(local_data, "row") - uncertainty_vec = self._stack_channels(local_uncertainties, "row") - - else: - local_data = {k: v[local_index] for k, v in data.observed.items()} - local_uncertainties = { - k: v[local_index] for k, v in data.uncertainties.items() - } - - data_vec = self._stack_channels(local_data, "column") - uncertainty_vec = self._stack_channels(local_uncertainties, "column") - - uncertainty_vec[np.isnan(data_vec)] = np.inf - data_vec[np.isnan(data_vec)] = self.dummy # Nan's handled by inf uncertainties - survey.dobs = data_vec - survey.std = uncertainty_vec - - def _stack_channels(self, channel_data: dict[str, np.ndarray], mode: str): - """ - Convert dictionary of data/uncertainties to stacked array. - - parameters: - ---------- - - channel_data: Array of data to stack - mode: Stacks rows or columns before flattening. Must be either 'row' or 'column'. - - - notes: - ------ - If mode is row the components will be clustered in the resulting 1D array. - Column stacking results in the locations being clustered. - - """ - if mode == "column": - return np.column_stack(list(channel_data.values())).ravel() - elif mode == "row": - return np.row_stack(list(channel_data.values())).ravel() - - def _dcip_arguments(self, data=None, local_index=None): + return survey + + def _add_data(self, survey, data): + # Stack the data by [channel, component, receiver] + data_stack = np.dstack( + [np.vstack(list(k.values())) for k in data.observed.values()] + ).transpose((0, 2, 1)) + uncert_stack = np.dstack( + [np.vstack(list(k.values())) for k in data.uncertainties.values()] + ).transpose((0, 2, 1)) + + uncert_stack[np.isnan(data_stack)] = np.inf + data_stack[np.isnan(data_stack)] = ( + self.dummy + ) # Nan's handled by inf uncertainties + survey.dobs = data_stack + survey.std = uncert_stack + + def _dcip_arguments(self, data=None): if getattr(data, "entity", None) is None: return None receiver_entity = data.entity - if "2d" in self.factory_type: - self.local_index = np.arange(receiver_entity.n_cells) - source_ids, order = np.unique( - receiver_entity.ab_cell_id.values[self.local_index], return_index=True + receiver_entity.ab_cell_id.values, return_index=True ) currents = receiver_entity.current_electrodes @@ -252,20 +162,17 @@ def _dcip_arguments(self, data=None, local_index=None): receiver_locations = receiver_entity.vertices source_locations = currents.vertices - # TODO hook up tile_spatial to handle local_index handling sources = [] - self.local_index = [] + sorting = [] for source_id in source_ids[np.argsort(order)]: # Cycle in original order receiver_indices = np.where(receiver_entity.ab_cell_id.values == source_id)[ 0 ] - if local_index is not None: - receiver_indices = list(set(receiver_indices).intersection(local_index)) - if len(receiver_indices) == 0: continue + sorting.append(receiver_indices) receivers = ReceiversFactory(self.params).build( locations=receiver_locations, local_index=receiver_entity.cells[receiver_indices], @@ -282,10 +189,15 @@ def _dcip_arguments(self, data=None, local_index=None): receivers=receivers, locations=source_locations[currents.cells[cell_ind].flatten()], ) - + source.rx_ids = np.asarray(receiver_indices) sources.append(source) - self.local_index.append(receiver_indices) + self.ordering = np.c_[ + np.zeros(receiver_entity.n_cells), # Single channel + np.zeros(receiver_entity.n_cells), # Single component + np.hstack(sorting), # Multi-receivers + ].astype(int) + self.sorting = np.hstack(sorting).astype(int) return [sources] def _tdem_arguments(self, data=None): @@ -306,12 +218,12 @@ def _tdem_arguments(self, data=None): "Transmitter ID property required for LargeLoopGroundTEMReceivers" ) - tx_rx = receivers.tx_id_property.values[self.local_index] + tx_rx = receivers.tx_id_property.values tx_ids = transmitters.tx_id_property.values - rx_lookup = [] + sorting = [] tx_locs = [] for tx_id in np.unique(tx_rx): - rx_lookup.append(self.local_index[tx_rx == tx_id]) + sorting.append(np.where(tx_rx == tx_id)[0]) tx_ind = tx_ids == tx_id loop_cells = transmitters.cells[ np.all(tx_ind[transmitters.cells], axis=1), : @@ -319,12 +231,14 @@ def _tdem_arguments(self, data=None): loop_ind = np.r_[loop_cells[:, 0], loop_cells[-1, 1]] tx_locs.append(transmitters.vertices[loop_ind, :]) else: - rx_lookup = self.local_index[:, np.newaxis].tolist() - tx_locs = [transmitters.vertices[k, :] for k in self.local_index] + # Assumes 1:1 mapping of tx to rx + sorting = np.arange(receivers.n_vertices).tolist() + tx_locs = transmitters.vertices wave_times = ( receivers.waveform[:, 0] - receivers.timing_mark ) * self.params.unit_conversion + if "1d" in self.factory_type: on_times = wave_times <= 0.0 waveform = tdem.sources.PiecewiseLinearWaveform( @@ -342,119 +256,137 @@ def _tdem_arguments(self, data=None): waveform_function=wave_function, offTime=0.0 ) - self.ordering = [] tx_list = [] rx_factory = ReceiversFactory(self.params) tx_factory = SourcesFactory(self.params) - for cur_tx_locs, rx_ids in zip(tx_locs, rx_lookup, strict=True): + ordering = [] + for cur_tx_locs, rx_ids in zip(tx_locs, sorting, strict=True): locs = receivers.vertices[rx_ids, :] - rx_list = [] - for component_id, component in enumerate(data.components): + + for comp_id, component in enumerate(data.components): rx_obj = rx_factory.build( locations=locs, - local_index=self.local_index, data=data, component=component, ) - rx_obj.local_index = rx_ids rx_list.append(rx_obj) + n_times = len(receivers.channels) + n_rx = len(rx_ids) if isinstance(rx_ids, np.ndarray) else 1 + ordering.append( + np.c_[ + np.kron(np.arange(n_times), np.ones(n_rx)), + np.ones(n_times * n_rx) * comp_id, + np.kron(np.ones(n_times), np.asarray(rx_ids)), + ] + ) - for time_id in range(len(receivers.channels)): - for rx_id in rx_ids: - self.ordering.append([time_id, component_id, rx_id]) - - tx_list.append( - tx_factory.build(rx_list, locations=cur_tx_locs, waveform=waveform) - ) + tx = tx_factory.build(rx_list, locations=cur_tx_locs, waveform=waveform) + tx.rx_ids = np.r_[rx_ids].astype(int) + tx_list.append(tx) + self.ordering = np.vstack(ordering).astype(int) + self.sorting = np.hstack(sorting).astype(int) return [tx_list] - def _fem_arguments(self, data=None, channel=None): + def _fem_arguments(self, data=None): channels = np.array(data.entity.channels) - frequencies = channels if channel is None else [channel] rx_locs = data.entity.vertices tx_locs = data.entity.transmitters.vertices - freqs = data.entity.transmitters.workspace.get_entity("Tx frequency")[0] - freqs = np.array([int(freqs.value_map[f]) for f in freqs.values]) + frequencies = data.entity.transmitters.workspace.get_entity("Tx frequency")[0] + frequencies = np.array( + [int(frequencies.value_map[f]) for f in frequencies.values] + ) - self.ordering = [] sources = [] rx_factory = ReceiversFactory(self.params) tx_factory = SourcesFactory(self.params) - - receiver_groups = {} - ordering = [] - for receiver_id in self.local_index: + receiver_groups = [] + block_ordering = [] + for rx_id, locs in enumerate(rx_locs): receivers = [] - for component_id, component in enumerate(data.components): + for comp_id, component in enumerate(data.components): receiver = rx_factory.build( - locations=rx_locs[receiver_id, :], + locations=locs, data=data, component=component, ) - - receiver.local_index = receiver_id + block_ordering.append([comp_id, rx_id]) receivers.append(receiver) - ordering.append([component_id, receiver_id]) - receiver_groups[receiver_id] = receivers - - ordering = np.vstack(ordering) - self.ordering = [] - for frequency in frequencies: - frequency_id = np.where(frequency == channels)[0][0] - self.ordering.append( - np.hstack([np.ones((ordering.shape[0], 1)) * frequency_id, ordering]) - ) - for receiver_id, receivers in receiver_groups.items(): - locs = tx_locs[frequency == freqs, :][receiver_id, :] - sources.append( - tx_factory.build( - receivers, - locations=locs, - frequency=frequency, - ) + receiver_groups.append(receivers) + + block_ordering = np.vstack(block_ordering) + ordering = [] + for freq_id, frequency in enumerate(channels): + for rx_id, receivers in enumerate(receiver_groups): + locs = tx_locs[frequency == frequencies, :][rx_id, :] + tx = tx_factory.build( + receivers, + locations=locs, + frequency=frequency, ) + tx.rx_ids = np.r_[rx_id] + sources.append(tx) - self.ordering = np.vstack(self.ordering).astype(int) + ordering.append( + np.hstack( + [ + np.ones((block_ordering.shape[0], 1)) * freq_id, + block_ordering, + ] + ) + ) + self.ordering = np.vstack(ordering).astype(int) + self.sorting = np.arange(rx_locs.shape[0], dtype=int) return [sources] - def _naturalsource_arguments(self, data=None, frequency=None): + def _naturalsource_arguments(self, data=None): + simpeg_mt_translate = { + "zxx_real": "zyy_real", + "zxx_imag": "zyy_imag", + "zxy_real": "zyx_real", + "zxy_imag": "zyx_imag", + "zyx_real": "zxy_real", + "zyx_imag": "zxy_imag", + "zyy_real": "zxx_real", + "zyy_imag": "zxx_imag", + } receivers = [] sources = [] rx_factory = ReceiversFactory(self.params) tx_factory = SourcesFactory(self.params) - ordering = [] - channels = np.array(data.entity.channels) - for component_id, comp in enumerate(data.components): + block_ordering = [] + self.sorting = np.arange(data.locations.shape[0], dtype=int) + for comp_id, comp in enumerate(data.components): receivers.append( rx_factory.build( locations=data.locations, - local_index=self.local_index, data=data, - component=comp, + component=simpeg_mt_translate.get(comp, comp), ) ) - ordering.append( - np.c_[np.ones_like(self.local_index) * component_id, self.local_index] + block_ordering.append( + np.c_[np.ones_like(self.sorting) * comp_id, self.sorting] ) - ordering = np.vstack(ordering) - self.ordering = [] - if frequency is None: - frequencies = channels - else: - frequencies = [frequency] if not isinstance(frequency, list) else frequency + block_ordering = np.vstack(block_ordering) + ordering = [] - for frequency in frequencies: - sources.append(tx_factory.build(receivers, frequency=frequency)) - frequency_id = np.where(frequency == channels)[0][0] - self.ordering.append( - np.hstack([np.ones((ordering.shape[0], 1)) * frequency_id, ordering]) + for freq_id, frequency in enumerate(data.entity.channels): + tx = tx_factory.build(receivers, frequency=frequency) + tx.rx_ids = np.arange(data.locations.shape[0], dtype=int) + sources.append(tx) + ordering.append( + np.hstack( + [ + np.ones((block_ordering.shape[0], 1)) * freq_id, + block_ordering, + ] + ) ) - self.ordering = np.vstack(self.ordering).astype(int) + self.ordering = np.vstack(ordering).astype(int) return [sources] diff --git a/simpeg_drivers/driver.py b/simpeg_drivers/driver.py index 4f3db2fc..4152ef1d 100644 --- a/simpeg_drivers/driver.py +++ b/simpeg_drivers/driver.py @@ -48,8 +48,9 @@ maps, objective_function, optimization, + simulation, ) - +from simpeg.potential_fields.base import BasePFSimulation from simpeg.regularization import ( BaseRegularization, RegularizationMesh, @@ -65,13 +66,17 @@ InversionTopography, InversionWindow, ) -from simpeg_drivers.components.factories import DirectivesFactory, MisfitFactory +from simpeg_drivers.components.factories import ( + DirectivesFactory, + MisfitFactory, + SimulationFactory, +) from simpeg_drivers.options import ( BaseForwardOptions, BaseInversionOptions, ) from simpeg_drivers.joint.options import BaseJointOptions -from simpeg_drivers.utils.utils import tile_locations +from simpeg_drivers.utils.nested import tile_locations from simpeg_drivers.utils.regularization import cell_neighbors, set_rotated_operators mlogger = logging.getLogger("distributed") @@ -100,6 +105,7 @@ def __init__(self, params: BaseForwardOptions | BaseInversionOptions): self._n_values: int | None = None self._optimization: optimization.ProjectedGNCG | None = None self._regularization: None = None + self._simulation: simulation.BaseSimulation | None = None self._sorting: list[np.ndarray] | None = None self._ordering: list[np.ndarray] | None = None self._mappings: list[maps.IdentityMap] | None = None @@ -163,17 +169,12 @@ def data_misfit(self): self.logger.write(f"Setting up {len(tiles)} tile(s) . . .\n") # Build tiled misfits and combine to form global misfit - self._data_misfit, self._sorting, self._ordering = MisfitFactory( - self.params, models=self.models - ).build( + self._data_misfit = MisfitFactory(self.params, self.simulation).build( tiles, self.split_list, - self.inversion_data, - self.inversion_mesh, - self.models.active_cells, ) self.logger.write("Saving data to file...\n") - + self._sorting = tiles if isinstance(self.params, BaseInversionOptions): self._data_misfit.multipliers = np.asarray( self._data_misfit.multipliers, dtype=float @@ -286,6 +287,13 @@ def models(self): return self._models + @property + def n_blocks(self): + """ + Number of model components in the inversion. + """ + return 3 if self.params.inversion_type == "magnetic vector" else 1 + @property def n_values(self): """Number of values in the model""" @@ -315,7 +323,7 @@ def optimization(self): @property def ordering(self): """List of ordering of the data.""" - return self._ordering + return self.inversion_data.survey.ordering @property def out_group(self): @@ -379,8 +387,28 @@ def regularization(self, regularization: objective_function.ComboObjectiveFuncti self._regularization = regularization @property - def sorting(self): - """List of arrays for sorting of data from tiles.""" + def simulation(self): + """ + The simulation object used in the inversion. + """ + if getattr(self, "_simulation", None) is None: + simulation_factory = SimulationFactory(self.params) + self._simulation = simulation_factory.build( + mesh=self.inversion_mesh.mesh, + models=self.models, + survey=self.inversion_data.survey, + ) + + if not hasattr(self._simulation, "active_cells"): + self._simulation.active_cells = self.models.active_cells + + return self._simulation + + @property + def sorting(self) -> list[np.ndarray] | None: + """ + Sorting of the data locations. + """ return self._sorting @property @@ -605,6 +633,7 @@ def get_tiles(self): self.inversion_data.locations, self.params.compute.tile_spatial, labels=self.inversion_data.parts, + sorting=self.simulation.survey.sorting, ) def configure_dask(self): diff --git a/simpeg_drivers/electromagnetics/base_1d_driver.py b/simpeg_drivers/electromagnetics/base_1d_driver.py index e5befd66..9aa2e1c1 100644 --- a/simpeg_drivers/electromagnetics/base_1d_driver.py +++ b/simpeg_drivers/electromagnetics/base_1d_driver.py @@ -21,7 +21,7 @@ from geoh5py.shared.merging.drape_model import DrapeModelMerger from geoh5py.ui_json.ui_json import fetch_active_workspace -from simpeg_drivers.components.factories import MisfitFactory +from simpeg_drivers.components.factories import MisfitFactory, SimulationFactory from simpeg_drivers.components.meshes import InversionMesh from simpeg_drivers.driver import InversionDriver from simpeg_drivers.utils.utils import topo_drape_elevation, xyz_2_drape_model @@ -69,7 +69,6 @@ def inversion_mesh(self) -> InversionMesh: self._inversion_mesh = InversionMesh( self.workspace, self.params, entity=entity ) - self._inversion_mesh.layers_mesh = self.layers_mesh return self._inversion_mesh @@ -87,37 +86,23 @@ def get_1d_mesh(self) -> TensorMesh: return layers_mesh @property - def data_misfit(self): - """The Simpeg.data_misfit class""" - if getattr(self, "_data_misfit", None) is None: - with fetch_active_workspace(self.workspace, mode="r+"): - # Tile locations - tiles = self.get_tiles() - - logger.info("Setting up %i tile(s) . . .", len(tiles)) - # Build tiled misfits and combine to form global misfit - self._data_misfit, self._sorting, self._ordering = MisfitFactory( - self.params, models=self.models - ).build( - tiles, - self.split_list, - self.inversion_data, - self.inversion_mesh, - self.topo_z_drape, - ) - self.models.active_cells = np.ones( - self.inversion_mesh.mesh.n_cells, dtype=bool - ) - logger.info("Done.") - - self._data_misfit.multipliers = np.asarray( - self._data_misfit.multipliers, dtype=float - ) + def simulation(self): + """ + The simulation object used in the inversion. + """ + if getattr(self, "_simulation", None) is None: + simulation_factory = SimulationFactory(self.params) + self._simulation = simulation_factory.build( + mesh=self.inversion_mesh.mesh, + models=self.models, + survey=self.inversion_data.survey, + ) - if self.client: - self.distributed_misfits() + self._simulation.mesh = self.inversion_mesh.mesh + self._simulation.layers_mesh = self.layers_mesh + self._simulation.active_cells = self.topo_z_drape - return self._data_misfit + return self._simulation @property def split_list(self): @@ -126,7 +111,4 @@ def split_list(self): """ n_misfits = self.inversion_data.mask.sum() - if isinstance(self.params.data_object, FEMSurvey): - n_misfits *= len(self.params.data_object.channels) - return [1] * n_misfits diff --git a/simpeg_drivers/electromagnetics/time_domain/driver.py b/simpeg_drivers/electromagnetics/time_domain/driver.py index c77c5e6c..e790e0c4 100644 --- a/simpeg_drivers/electromagnetics/time_domain/driver.py +++ b/simpeg_drivers/electromagnetics/time_domain/driver.py @@ -17,7 +17,6 @@ ) from simpeg_drivers.driver import InversionDriver -from simpeg_drivers.utils.utils import tile_locations from .options import ( TDEMForwardOptions, @@ -25,101 +24,15 @@ ) -def tile_large_group_transmitters( - survey: LargeLoopGroundTEMReceivers, n_tiles: int -) -> list[np.ndarray]: - """ - Tile the data based on the transmitters center locations. - - :param survey: LargeLoopGroundTEMReceivers object. - :param n_tiles: Number of tiles. - - :return: List of numpy arrays containing the indices of the receivers in each tile. - """ - if not isinstance(survey, LargeLoopGroundTEMReceivers): - raise TypeError("Data object must be of type LargeLoopGroundTEMReceivers") - - tx_ids = survey.transmitters.tx_id_property.values - unique_tile_ids = np.unique(tx_ids) - n_groups = np.min([len(unique_tile_ids), n_tiles]) - locations = [] - for uid in unique_tile_ids: - locations.append( - np.mean( - survey.transmitters.vertices[tx_ids == uid], - axis=0, - ) - ) - - # Tile transmitters spatially by loop center - tx_tiles = tile_locations( - np.vstack(locations), - n_groups, - ) - receivers_tx_ids = survey.tx_id_property.values - tiles = [] - for _t_id, group in enumerate(tx_tiles): - sub_group = [] - for value in group: - receiver_ind = receivers_tx_ids == unique_tile_ids[value] - sub_group.append(np.where(receiver_ind)[0]) - - tiles.append(np.hstack(sub_group)) - - # If number of tiles remaining, brake up receivers spatially per transmitter - while len(tiles) < n_tiles: - largest_group = np.argmax([len(tile) for tile in tiles]) - tile = tiles.pop(largest_group) - new_tiles = tile_locations( - survey.vertices[tile], - 2, - ) - tiles += [tile[new_tiles[0]], tile[new_tiles[1]]] - - return tiles - - class TDEMForwardDriver(InversionDriver): """Time Domain Electromagnetic forward driver.""" _options_class = TDEMForwardOptions _validations = None - def get_tiles(self) -> list[np.ndarray]: - """ - Special method to tile the data based on the transmitters center locations. - - First the transmitter locations are grouped into groups using kmeans clustering. - Second, if the number of groups is less than the number of 'tile_spatial' value, the groups are - further divided into groups based on the clustering of receiver locations. - """ - if not isinstance(self.params.data_object, LargeLoopGroundTEMReceivers): - return super().get_tiles() - - return tile_large_group_transmitters( - self.params.data_object, - self.params.compute.tile_spatial, - ) - class TDEMInversionDriver(InversionDriver): """Time Domain Electromagnetic inversion driver.""" _options_class = TDEMInversionOptions _validations = None - - def get_tiles(self) -> list[np.ndarray]: - """ - Special method to tile the data based on the transmitters center locations. - - First the transmitter locations are grouped into groups using kmeans clustering. - Second, if the number of groups is less than the number of 'tile_spatial' value, the groups are - further divided into groups based on the clustering of receiver locations. - """ - if not isinstance(self.params.data_object, LargeLoopGroundTEMReceivers): - return super().get_tiles() - - return tile_large_group_transmitters( - self.params.data_object, - self.params.compute.tile_spatial, - ) diff --git a/simpeg_drivers/electromagnetics/time_domain/options.py b/simpeg_drivers/electromagnetics/time_domain/options.py index 9ff341b6..6ec47162 100644 --- a/simpeg_drivers/electromagnetics/time_domain/options.py +++ b/simpeg_drivers/electromagnetics/time_domain/options.py @@ -14,6 +14,7 @@ from pathlib import Path from typing import ClassVar, TypeAlias +import numpy as np from geoh5py.groups import PropertyGroup from geoh5py.objects import ( AirborneTEMReceivers, @@ -56,6 +57,23 @@ def unit_conversion(self): } return conversion[self.data_object.unit] + @property + def timing_mark(self): + """ + Return the "zero time" mark of the TDEM data in the appropriate units. + """ + return self.data_object.timing_mark * self.unit_conversion + + @property + def time_steps(self): + """ + Return the time steps of the TDEM data in the appropriate units. + """ + return ( + np.round((np.diff(np.unique(self.data_object.waveform[:, 0]))), decimals=6) + * self.unit_conversion + ) + class TDEMForwardOptions(BaseTDEMOptions, BaseForwardOptions): """ diff --git a/simpeg_drivers/joint/driver.py b/simpeg_drivers/joint/driver.py index 16f162ac..6f698fc1 100644 --- a/simpeg_drivers/joint/driver.py +++ b/simpeg_drivers/joint/driver.py @@ -123,7 +123,7 @@ def initialize(self): global_actives, driver.inversion_mesh.mesh, enforce_active=False, - components=driver.inversion_data.n_blocks, + components=driver.n_blocks, ) driver.params.active_model = None driver.models.active_cells = projection.local_active @@ -211,7 +211,7 @@ def n_values(self): n_values = self.models.n_active count = [] for driver in self.drivers: - n_comp = driver.inversion_data.n_blocks # If vector of scalar model + n_comp = driver.n_blocks # If vector of scalar model count.append(n_values * n_comp) self._n_values = count @@ -239,8 +239,6 @@ def run(self): for sub, driver in zip(predicted, self.drivers, strict=True): SaveDataGeoh5Factory(driver.params).build( inversion_object=driver.inversion_data, - sorting=np.argsort(np.hstack(driver.sorting)), - ordering=driver.ordering, ).write(0, sub) else: # Run the inversion diff --git a/simpeg_drivers/options.py b/simpeg_drivers/options.py index efc8f0f6..42c9e2e8 100644 --- a/simpeg_drivers/options.py +++ b/simpeg_drivers/options.py @@ -466,21 +466,13 @@ def property_group_data(self, property_group: PropertyGroup): if property_group is None: return dict.fromkeys(frequencies) - data = {} group = next( k for k in self.data_object.property_groups if k.uid == property_group.uid ) - property_names = [self.geoh5.get_entity(p)[0].name for p in group.properties] - properties = [self.geoh5.get_entity(p)[0].values for p in group.properties] - for i, f in enumerate(frequencies): - try: - f_ind = property_names.index( - next(k for k in property_names if f"{f:.2e}" in k) - ) # Safer if data was saved with geoapps naming convention - data[f] = properties[f_ind] - except StopIteration: - data[f] = properties[i] # in case of other naming conventions - + data = { + freq: self.geoh5.get_entity(p)[0].values + for freq, p in zip(frequencies, group.properties, strict=False) + } return data @@ -633,7 +625,7 @@ def component_data(self, component: str) -> np.ndarray | None: data = getattr(self, "_".join([component, "channel"]), None) if isinstance(data, NumericData): data = data.values - return data + return {None: data} def component_uncertainty(self, component: str) -> np.ndarray | None: """ @@ -648,6 +640,6 @@ def component_uncertainty(self, component: str) -> np.ndarray | None: if isinstance(data, NumericData): data = data.values elif isinstance(data, float): - data *= np.ones_like(self.component_data(component)) + data *= np.ones_like(self.component_data(component)[None]) - return data + return {None: data} diff --git a/simpeg_drivers/utils/nested.py b/simpeg_drivers/utils/nested.py new file mode 100644 index 00000000..d50dced8 --- /dev/null +++ b/simpeg_drivers/utils/nested.py @@ -0,0 +1,436 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2025 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +from __future__ import annotations + +import warnings +from copy import copy +from pathlib import Path + +import numpy as np +from discretize import TensorMesh, TreeMesh +from scipy.optimize import linear_sum_assignment +from scipy.spatial import cKDTree +from scipy.spatial.distance import cdist +from simpeg import data, data_misfit, maps, meta +from simpeg.electromagnetics.base_1d import BaseEM1DSimulation +from simpeg.electromagnetics.frequency_domain.simulation import BaseFDEMSimulation +from simpeg.electromagnetics.frequency_domain.sources import ( + LineCurrent as FEMLineCurrent, +) +from simpeg.electromagnetics.natural_source import Simulation3DPrimarySecondary +from simpeg.electromagnetics.static.induced_polarization.simulation import ( + Simulation3DNodal as Simulation3DIP, +) +from simpeg.electromagnetics.static.resistivity.simulation import ( + Simulation3DNodal as Simulation3DRes, +) +from simpeg.electromagnetics.time_domain.simulation import BaseTDEMSimulation +from simpeg.electromagnetics.time_domain.sources import LineCurrent as TEMLineCurrent +from simpeg.simulation import BaseSimulation +from simpeg.survey import BaseSurvey + +from simpeg_drivers.utils.surveys import ( + compute_dc_projections, + compute_em_projections, + get_intersecting_cells, + get_unique_locations, +) + + +def create_mesh( + survey: BaseSurvey, + base_mesh: TreeMesh | TensorMesh, + padding_cells: int = 8, + minimum_level: int = 4, + finalize: bool = True, +) -> TreeMesh | TensorMesh: + """ + Create a nested mesh with the same extent as the input global mesh. + Refinement levels are preserved only around the input locations (local survey). + + + :param survey: SimPEG survey object. + :param base_mesh: Input global TreeMesh object. + :param padding_cells: Used for 'method'= 'padding_cells'. Number of cells in each concentric shell. + :param minimum_level: Minimum octree level to preserve everywhere outside the local survey area. + :param finalize: Return a finalized local treemesh. + + :return: A TreeMesh object with the same extent as the input global mesh. + """ + if not isinstance(base_mesh, TreeMesh): + return base_mesh + + locations = get_unique_locations(survey) + nested_mesh = TreeMesh( + [base_mesh.h[0], base_mesh.h[1], base_mesh.h[2]], + x0=base_mesh.x0, + diagonal_balance=False, + ) + base_level = base_mesh.max_level - minimum_level + base_refinement = base_mesh.cell_levels_by_index(np.arange(base_mesh.nC)) + base_refinement[base_refinement > base_level] = base_level + nested_mesh.insert_cells( + base_mesh.gridCC, + base_refinement, + finalize=False, + ) + base_cell = np.min([base_mesh.h[0][0], base_mesh.h[1][0]]) + tx_loops = [] + for source in survey.source_list: + if isinstance(source, TEMLineCurrent | FEMLineCurrent): + mesh_indices = get_intersecting_cells(source.location, base_mesh) + tx_loops.append(base_mesh.cell_centers[mesh_indices, :]) + + if tx_loops: + locations = np.vstack([locations, *tx_loops]) + + tree = cKDTree(locations[:, :2]) + rad, _ = tree.query(base_mesh.gridCC[:, :2]) + pad_distance = 0.0 + for ii in range(minimum_level): + pad_distance += base_cell * 2**ii * padding_cells + indices = np.where(rad < pad_distance)[0] + levels = base_mesh.cell_levels_by_index(indices) + levels[levels > (base_mesh.max_level - ii)] = base_mesh.max_level - ii + nested_mesh.insert_cells( + base_mesh.gridCC[indices, :], + levels, + finalize=False, + ) + + if finalize: + nested_mesh.finalize() + + return nested_mesh + + +def create_misfit( + simulation, + local_indices, + channel, + tile_count, + n_split, + padding_cells, + inversion_type, + forward_only, +): + """ + Create a list of local misfits based on the local indices. + + The local indices are further split into smaller chunks if requested, sharing + the same mesh. + + :param simulation: SimPEG simulation object. + :param local_indices: Indices of the receiver locations belonging to the tile. + :param channel: Channel of the simulationm, for frequency systems only. + :param tile_count: Current tile ID, used to name the file on disk and for sampling + of topography for 1D simulations. + :param n_split: Number of splits to create for the local indices. + :param padding_cells: Number of padding cells around the local survey. + :param inversion_type: Type of inversion, used to name the misfit (joint inversion). + :param forward_only: If False, data is transferred to the local simulation. + + :return: List of local misfits and data slices. + """ + local_sim, _, _ = create_simulation( + simulation, + None, + local_indices, + channel=channel, + tile_id=tile_count, + padding_cells=padding_cells, + ) + + local_mesh = getattr(local_sim, "mesh", None) + local_misfits = [] + data_slices = [] + for split_ind in np.array_split(local_indices, n_split): + local_sim, mapping, data_slice = create_simulation( + simulation, + local_mesh, + split_ind, + channel=channel, + tile_id=tile_count, + padding_cells=padding_cells, + ) + meta_simulation = meta.MetaSimulation( + simulations=[local_sim], mappings=[mapping] + ) + + local_data = data.Data(local_sim.survey) + lmisfit = data_misfit.L2DataMisfit(local_data, meta_simulation) + if not forward_only: + local_data.dobs = local_sim.survey.dobs + local_data.standard_deviation = local_sim.survey.std + name = inversion_type + name += f": Tile {tile_count + 1}" + if channel is not None: + name += f": Channel {channel}" + + lmisfit.name = f"{name}" + + local_misfits.append(lmisfit) + data_slices.append(data_slice) + + tile_count += 1 + + return local_misfits, data_slices + + +def create_simulation( + simulation: BaseSimulation, + local_mesh: TreeMesh | None, + indices: np.ndarray, + *, + channel: int | None = None, + tile_id: int | None = None, + padding_cells=100, +): + """ + Generate a survey, mesh and simulation based on indices. + + :param simulation: SimPEG.simulation object. + :param local_mesh: Local mesh for the simulation, else created. + :param indices: Indices of receivers belonging to the tile. + :param channel: Channel of the simulation, for frequency simulations only. + :param tile_id: Tile id stored on the simulation. + :param padding_cells: Number of padding cells around the local survey. + + :return: Local simulation, mapping and local ordering. + """ + local_survey, local_ordering = create_survey( + simulation.survey, indices=indices, channel=channel + ) + kwargs = {"survey": local_survey} + + if local_mesh is None: + local_mesh = create_mesh( + local_survey, + simulation.mesh, + minimum_level=3, + padding_cells=padding_cells, + ) + + args = (local_mesh,) + if isinstance(simulation, BaseEM1DSimulation): + local_mesh = simulation.layers_mesh + actives = np.ones(simulation.layers_mesh.n_cells, dtype=bool) + model_slice = np.arange( + tile_id, simulation.mesh.n_cells, simulation.mesh.shape_cells[0] + )[::-1] + mapping = maps.Projection(simulation.mesh.n_cells, model_slice) + kwargs["topo"] = simulation.active_cells[tile_id] + args = () + + elif isinstance(local_mesh, TreeMesh): + mapping = maps.TileMap( + simulation.mesh, + simulation.active_cells, + local_mesh, + enforce_active=True, + components=3 if getattr(simulation, "model_type", None) == "vector" else 1, + ) + actives = mapping.local_active + # For DCIP-2D + else: + actives = simulation.active_cells + mapping = maps.IdentityMap(nP=int(actives.sum())) + + n_actives = int(actives.sum()) + if getattr(simulation, "_chiMap", None) is not None: + if simulation.model_type == "vector": + kwargs["chiMap"] = maps.IdentityMap(nP=n_actives * 3) + kwargs["model_type"] = "vector" + else: + kwargs["chiMap"] = maps.IdentityMap(nP=n_actives) + + kwargs["active_cells"] = actives + kwargs["sensitivity_path"] = ( + Path(simulation.sensitivity_path).parent / f"Tile{tile_id}.zarr" + ) + + if getattr(simulation, "_rhoMap", None) is not None: + kwargs["rhoMap"] = maps.IdentityMap(nP=n_actives) + kwargs["active_cells"] = actives + kwargs["sensitivity_path"] = ( + Path(simulation.sensitivity_path).parent / f"Tile{tile_id}.zarr" + ) + + if getattr(simulation, "_sigmaMap", None) is not None: + kwargs["sigmaMap"] = maps.ExpMap(local_mesh) * maps.InjectActiveCells( + local_mesh, actives, value_inactive=np.log(1e-8) + ) + + if getattr(simulation, "_etaMap", None) is not None: + kwargs["etaMap"] = maps.InjectActiveCells(local_mesh, actives, value_inactive=0) + proj = maps.InjectActiveCells( + local_mesh, + actives, + value_inactive=1e-8, + ) + kwargs["sigma"] = proj * mapping * simulation.sigma[simulation.active_cells] + + for key in [ + "max_chunk_sizestore_sensitivities", + "solver", + "t0", + "time_steps", + "thicknesses", + ]: + if hasattr(simulation, key): + kwargs[key] = getattr(simulation, key) + + local_sim = type(simulation)(*args, **kwargs) + + if isinstance( + simulation, BaseFDEMSimulation | BaseTDEMSimulation + ) and not isinstance(simulation, Simulation3DPrimarySecondary): + compute_em_projections(simulation.survey.locations, local_sim) + elif isinstance(simulation, Simulation3DRes | Simulation3DIP): + compute_dc_projections( + simulation.survey.locations, simulation.survey.cells, local_sim + ) + return local_sim, mapping, local_ordering + + +def create_survey(survey, indices, channel=None): + """ + Extract source and receivers belonging to the indices. + + :param survey: SimPEG survey object. + :param indices: Indices of the receivers belonging to the tile. + :param channel: Channel of the survey, for frequency systems only. + """ + sources = [] + + # Return the subset of data that belongs to the tile + slice_inds = np.isin(survey.ordering[:, 2], indices) + if channel is not None: + ind = np.where(np.asarray(survey.frequencies) == channel)[0] + slice_inds *= np.isin(survey.ordering[:, 0], ind) + + for src in survey.source_list or [survey.source_field]: + if channel is not None and getattr(src, "frequency", None) != channel: + continue + + # Extract the indices of the receivers that belong to this source + _, intersect, _ = np.intersect1d(src.rx_ids, indices, return_indices=True) + + if len(intersect) == 0: + continue + + receivers = [] + for rx in src.receiver_list: + new_rx = copy(rx) + + # For MT and DC surveys with multiple locations per receiver + if isinstance(rx.locations, tuple | list): + new_rx.locations = tuple(loc[intersect] for loc in rx.locations) + else: + new_rx.locations = rx.locations[intersect] + + receivers.append(new_rx) + + if any(receivers): + new_src = copy(src) + new_src.rx_ids = src.rx_ids[intersect] + new_src.receiver_list = receivers + sources.append(new_src) + + if hasattr(survey, "source_field"): + new_survey = type(survey)(sources[0]) + else: + new_survey = type(survey)(sources) + + if hasattr(survey, "dobs") and survey.dobs is not None: + # For FEM surveys only + new_survey.dobs = survey.dobs[ + survey.ordering[slice_inds, 0], + survey.ordering[slice_inds, 1], + survey.ordering[slice_inds, 2], + ] + new_survey.std = survey.std[ + survey.ordering[slice_inds, 0], + survey.ordering[slice_inds, 1], + survey.ordering[slice_inds, 2], + ] + + return new_survey, survey.ordering[slice_inds, :] + + +def tile_locations( + locations: np.ndarray, + n_tiles: int, + labels: np.ndarray | None = None, + sorting: np.ndarray | None = None, +) -> list[np.ndarray]: + """ + Function to tile a survey points into smaller square subsets of points using + a k-means clustering approach. + + If labels are provided and the number of unique labels is less than or equal to + the number of tiles, the function will return an even split of the unique labels. + + :param locations: Array of locations. + :param n_tiles: Number of tiles (for 'cluster') + :param labels: Array of values to append to the locations + :param sorting: Array of indices to sort the locations before clustering. + + :return: List of arrays containing the indices of the points in each tile. + """ + grid_locs = locations[:, :2].copy() + + if labels is not None: + if len(labels) != grid_locs.shape[0]: + raise ValueError( + "Labels array must have the same length as the locations array." + ) + + if len(np.unique(labels)) >= n_tiles: + label_groups = np.array_split(np.unique(labels), n_tiles) + return [np.where(np.isin(labels, group))[0] for group in label_groups] + + # Normalize location coordinates to [0, 1] range + grid_locs -= grid_locs.min(axis=0) + max_val = grid_locs.max(axis=0) + grid_locs[:, max_val > 0] /= max_val[max_val > 0] + grid_locs = np.c_[grid_locs, labels] + + if sorting is not None: + grid_locs = grid_locs[sorting, :] + + # Cluster + # TODO turn off filter once sklearn has dealt with the issue causing the warning + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + from sklearn.cluster import KMeans + + kmeans = KMeans(n_clusters=n_tiles, random_state=0, n_init="auto") + cluster_size = int(np.ceil(grid_locs.shape[0] / n_tiles)) + kmeans.fit(grid_locs) + + if labels is not None: + cluster_id = kmeans.labels_ + else: + # Redistribute cluster centers to even out the number of points + centers = kmeans.cluster_centers_ + centers = ( + centers.reshape(-1, 1, grid_locs.shape[1]) + .repeat(cluster_size, 1) + .reshape(-1, grid_locs.shape[1]) + ) + distance_matrix = cdist(grid_locs, centers) + cluster_id = linear_sum_assignment(distance_matrix)[1] // cluster_size + + tiles = [] + for tid in set(cluster_id): + tiles += [np.where(cluster_id == tid)[0]] + + return tiles diff --git a/simpeg_drivers/utils/surveys.py b/simpeg_drivers/utils/surveys.py index 4daebdf4..a15e0d4c 100644 --- a/simpeg_drivers/utils/surveys.py +++ b/simpeg_drivers/utils/surveys.py @@ -143,3 +143,42 @@ def get_unique_locations(survey: BaseSurvey) -> np.ndarray: locations = survey.receiver_locations return np.unique(locations, axis=0) + + +def compute_em_projections(locations, simulation): + """ + Pre-compute projections for the receivers for efficiency. + """ + projections = {} + for component in "xyz": + projections[component] = simulation.mesh.get_interpolation_matrix( + locations, "faces_" + component[0] + ) + + for source in simulation.survey.source_list: + indices = source.rx_ids + for receiver in source.receiver_list: + projection = 0.0 + for orientation, comp in zip(receiver.orientation, "xyz", strict=True): + if orientation == 0: + continue + projection += orientation * projections[comp][indices, :] + receiver.spatialP = projection + + +def compute_dc_projections(locations, cells, simulation): + """ + Pre-compute projections for the receivers for efficiency. + """ + projection = simulation.mesh.get_interpolation_matrix(locations, "nodes") + + for source in simulation.survey.source_list: + indices = source.rx_ids + for receiver in source.receiver_list: + proj_mn = projection[cells[indices, 0], :] + + # Check if dipole receiver + if not np.all(cells[indices, 0] == cells[indices, 1]): + proj_mn -= projection[cells[indices, 1], :] + + receiver.spatialP = proj_mn # pylint: disable=protected-access diff --git a/simpeg_drivers/utils/tile_estimate.py b/simpeg_drivers/utils/tile_estimate.py index 59becb23..bd245a91 100644 --- a/simpeg_drivers/utils/tile_estimate.py +++ b/simpeg_drivers/utils/tile_estimate.py @@ -38,10 +38,10 @@ from simpeg_drivers.components.data import InversionData from simpeg_drivers.components.factories.misfit_factory import MisfitFactory from simpeg_drivers.driver import InversionDriver +from simpeg_drivers.utils.nested import create_simulation, tile_locations from simpeg_drivers.utils.utils import ( active_from_xyz, simpeg_group_to_driver, - tile_locations, ) @@ -100,11 +100,9 @@ def get_results(self, max_tiles: int = 13) -> dict: # Get the median tile ind = int(np.argsort([len(tile) for tile in tiles])[int(count / 2)]) self.driver.params.compute.tile_spatial = int(count) - sim, _, _, mapping = MisfitFactory.create_nested_simulation( - self.driver.inversion_data, - self.driver.inversion_mesh, + sim, mapping, _ = create_simulation( + self.driver.simulation, None, - self.active_cells, tiles[ind], tile_id=ind, padding_cells=self.driver.params.padding_cells, diff --git a/simpeg_drivers/utils/utils.py b/simpeg_drivers/utils/utils.py index 56f22591..bda79ea3 100644 --- a/simpeg_drivers/utils/utils.py +++ b/simpeg_drivers/utils/utils.py @@ -11,7 +11,6 @@ from __future__ import annotations -import warnings from copy import deepcopy from typing import TYPE_CHECKING @@ -30,20 +29,11 @@ from geoh5py.ui_json import InputFile from octree_creation_app.utils import octree_2_treemesh from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator, interp1d -from scipy.optimize import linear_sum_assignment from scipy.spatial import ConvexHull, Delaunay, cKDTree -from scipy.spatial.distance import cdist -from simpeg.electromagnetics.frequency_domain.sources import ( - LineCurrent as FEMLineCurrent, -) -from simpeg.electromagnetics.time_domain.sources import LineCurrent as TEMLineCurrent -from simpeg.survey import BaseSurvey from simpeg_drivers import DRIVER_MAP from simpeg_drivers.utils.surveys import ( compute_alongline_distance, - get_intersecting_cells, - get_unique_locations, ) @@ -140,70 +130,6 @@ def calculate_2D_trend( return data_trend, params -def create_nested_mesh( - survey: BaseSurvey, - base_mesh: TreeMesh, - padding_cells: int = 8, - minimum_level: int = 4, - finalize: bool = True, -): - """ - Create a nested mesh with the same extent as the input global mesh. - Refinement levels are preserved only around the input locations (local survey). - - Parameters - ---------- - - locations: Array of coordinates for the local survey shape(*, 3). - base_mesh: Input global TreeMesh object. - padding_cells: Used for 'method'= 'padding_cells'. Number of cells in each concentric shell. - minimum_level: Minimum octree level to preserve everywhere outside the local survey area. - finalize: Return a finalized local treemesh. - """ - locations = get_unique_locations(survey) - nested_mesh = TreeMesh( - [base_mesh.h[0], base_mesh.h[1], base_mesh.h[2]], - x0=base_mesh.x0, - diagonal_balance=False, - ) - base_level = base_mesh.max_level - minimum_level - base_refinement = base_mesh.cell_levels_by_index(np.arange(base_mesh.nC)) - base_refinement[base_refinement > base_level] = base_level - nested_mesh.insert_cells( - base_mesh.gridCC, - base_refinement, - finalize=False, - ) - base_cell = np.min([base_mesh.h[0][0], base_mesh.h[1][0]]) - tx_loops = [] - for source in survey.source_list: - if isinstance(source, TEMLineCurrent | FEMLineCurrent): - mesh_indices = get_intersecting_cells(source.location, base_mesh) - tx_loops.append(base_mesh.cell_centers[mesh_indices, :]) - - if tx_loops: - locations = np.vstack([locations, *tx_loops]) - - tree = cKDTree(locations[:, :2]) - rad, _ = tree.query(base_mesh.gridCC[:, :2]) - pad_distance = 0.0 - for ii in range(minimum_level): - pad_distance += base_cell * 2**ii * padding_cells - indices = np.where(rad < pad_distance)[0] - levels = base_mesh.cell_levels_by_index(indices) - levels[levels > (base_mesh.max_level - ii)] = base_mesh.max_level - ii - nested_mesh.insert_cells( - base_mesh.gridCC[indices, :], - levels, - finalize=False, - ) - - if finalize: - nested_mesh.finalize() - - return nested_mesh - - def drape_to_octree( octree: Octree, drape_model: DrapeModel | list[DrapeModel], @@ -480,72 +406,6 @@ def xyz_2_drape_model( return model -def tile_locations( - locations: np.ndarray, - n_tiles: int, - labels: np.ndarray | None = None, -) -> list[np.ndarray]: - """ - Function to tile a survey points into smaller square subsets of points using - a k-means clustering approach. - - If labels are provided and the number of unique labels is less than or equal to - the number of tiles, the function will return an even split of the unique labels. - - :param locations: Array of locations. - :param n_tiles: Number of tiles (for 'cluster') - :param labels: Array of values to append to the locations - - :return: List of arrays containing the indices of the points in each tile. - """ - grid_locs = locations[:, :2].copy() - - if labels is not None: - if len(labels) != grid_locs.shape[0]: - raise ValueError( - "Labels array must have the same length as the locations array." - ) - - if len(np.unique(labels)) >= n_tiles: - label_groups = np.array_split(np.unique(labels), n_tiles) - return [np.where(np.isin(labels, group))[0] for group in label_groups] - - # Normalize location coordinates to [0, 1] range - grid_locs -= grid_locs.min(axis=0) - max_val = grid_locs.max(axis=0) - grid_locs[:, max_val > 0] /= max_val[max_val > 0] - grid_locs = np.c_[grid_locs, labels] - - # Cluster - # TODO turn off filter once sklearn has dealt with the issue causing the warning - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=UserWarning) - from sklearn.cluster import KMeans - - kmeans = KMeans(n_clusters=n_tiles, random_state=0, n_init="auto") - cluster_size = int(np.ceil(grid_locs.shape[0] / n_tiles)) - kmeans.fit(grid_locs) - - if labels is not None: - cluster_id = kmeans.labels_ - else: - # Redistribute cluster centers to even out the number of points - centers = kmeans.cluster_centers_ - centers = ( - centers.reshape(-1, 1, grid_locs.shape[1]) - .repeat(cluster_size, 1) - .reshape(-1, grid_locs.shape[1]) - ) - distance_matrix = cdist(grid_locs, centers) - cluster_id = linear_sum_assignment(distance_matrix)[1] // cluster_size - - tiles = [] - for tid in set(cluster_id): - tiles += [np.where(cluster_id == tid)[0]] - - return tiles - - def get_containing_cells( mesh: TreeMesh | TensorMesh, data: InversionData ) -> np.ndarray: diff --git a/tests/data_test.py b/tests/data_test.py index d3aaa81e..f3606442 100644 --- a/tests/data_test.py +++ b/tests/data_test.py @@ -159,19 +159,22 @@ def test_survey_data(tmp_path: Path): # test locations np.testing.assert_array_equal( - verts[driver.sorting[0], :2], local_survey_a.receiver_locations[:, :2] - ) - np.testing.assert_array_equal( - verts[driver.sorting[1], :2], local_survey_b.receiver_locations[:, :2] + verts[np.hstack(driver.sorting), :2], + np.vstack( + [ + local_survey_a.receiver_locations[:, :2], + local_survey_b.receiver_locations[:, :2], + ] + ), ) + assert all(local_survey_a.receiver_locations[:, 2] == 0.0) assert all(local_survey_b.receiver_locations[:, 2] == 0.0) # test observed data - sorting = np.hstack(driver.sorting) expected_dobs = np.column_stack( [bxx_data.values, byy_data.values, bzz_data.values] - )[sorting].ravel() + )[np.hstack(driver.sorting)].ravel() survey_dobs = [local_survey_a.dobs, local_survey_b.dobs] np.testing.assert_array_equal(expected_dobs, np.hstack(survey_dobs)) @@ -221,9 +224,9 @@ def test_get_uncertainty_component(tmp_path: Path): with geoh5.open(): data = InversionData(geoh5, params) unc = params.uncertainties["tmi"] - assert len(np.unique(unc)) == 1 - assert np.unique(unc)[0] == 1 - assert len(unc) == data.entity.n_vertices + assert len(unc) == 1 + assert np.unique(unc[None])[0] == 1 + assert len(unc[None]) == data.entity.n_vertices def test_normalize(tmp_path: Path): @@ -233,7 +236,7 @@ def test_normalize(tmp_path: Path): data = InversionData(geoh5, params) data.normalizations = data.get_normalizations() test_data = data.normalize(data.observed) - assert all(test_data["tmi"] == params.data["tmi"]) + assert all(test_data["tmi"][None] == params.data["tmi"][None]) assert len(test_data) == 1 @@ -243,7 +246,7 @@ def test_get_survey(tmp_path: Path): with geoh5.open(): data = InversionData(geoh5, params) survey = data.create_survey() - assert isinstance(survey[0], simpeg.potential_fields.magnetics.Survey) + assert isinstance(survey, simpeg.potential_fields.magnetics.Survey) def test_data_parts(tmp_path: Path): diff --git a/tests/locations_test.py b/tests/locations_test.py index efdd6230..3b57951f 100644 --- a/tests/locations_test.py +++ b/tests/locations_test.py @@ -19,6 +19,7 @@ from simpeg_drivers.components.locations import InversionLocations from simpeg_drivers.potential_fields import MVIInversionOptions +from simpeg_drivers.utils.nested import tile_locations from simpeg_drivers.utils.synthetics.driver import SyntheticsComponents from simpeg_drivers.utils.synthetics.options import ( MeshOptions, @@ -26,7 +27,6 @@ SurveyOptions, SyntheticsComponentsOptions, ) -from simpeg_drivers.utils.utils import tile_locations def get_mvi_params(tmp_path: Path) -> MVIInversionOptions: diff --git a/tests/run_tests/driver_2d_rotated_gradients_test.py b/tests/run_tests/driver_2d_rotated_gradients_test.py index affd14d7..2107cdee 100644 --- a/tests/run_tests/driver_2d_rotated_gradients_test.py +++ b/tests/run_tests/driver_2d_rotated_gradients_test.py @@ -115,7 +115,7 @@ def test_dc2d_rotated_grad_run( ) with Workspace(workpath) as geoh5: - potential = geoh5.get_entity("Iteration_0_dc")[0] + potential = geoh5.get_entity("Iteration_0_potential")[0] components = SyntheticsComponents(geoh5) orig_potential = potential.values.copy() diff --git a/tests/run_tests/driver_dc_2d_test.py b/tests/run_tests/driver_dc_2d_test.py index 2a4dc370..6e32ec8f 100644 --- a/tests/run_tests/driver_dc_2d_test.py +++ b/tests/run_tests/driver_dc_2d_test.py @@ -94,7 +94,7 @@ def test_dc_2d_run(tmp_path: Path, max_iterations=1, pytest=True): workpath = tmp_path.parent / "test_dc_2d_fwr_run0" / "inversion_test.ui.geoh5" with Workspace(workpath) as geoh5: - potential = geoh5.get_entity("Iteration_0_dc")[0] + potential = geoh5.get_entity("Iteration_0_potential")[0] components = SyntheticsComponents(geoh5) # Run the inverse diff --git a/tests/run_tests/driver_dc_b2d_rotated_gradients_test.py b/tests/run_tests/driver_dc_b2d_rotated_gradients_test.py index 259d54ea..0318d766 100644 --- a/tests/run_tests/driver_dc_b2d_rotated_gradients_test.py +++ b/tests/run_tests/driver_dc_b2d_rotated_gradients_test.py @@ -112,7 +112,7 @@ def test_dc_rotated_gradient_p3d_run( with Workspace(workpath) as geoh5: components = SyntheticsComponents(geoh5) - potential = geoh5.get_entity("Iteration_0_dc")[0] + potential = geoh5.get_entity("Iteration_0_potential")[0] # Create property group with orientation dip = np.ones(components.mesh.n_cells) * 45 diff --git a/tests/run_tests/driver_dc_b2d_test.py b/tests/run_tests/driver_dc_b2d_test.py index 9e7ffe83..99027fd0 100644 --- a/tests/run_tests/driver_dc_b2d_test.py +++ b/tests/run_tests/driver_dc_b2d_test.py @@ -101,7 +101,7 @@ def test_dc_p3d_run( with Workspace(workpath) as geoh5: components = SyntheticsComponents(geoh5) - potential = geoh5.get_entity("Iteration_0_dc")[0] + potential = geoh5.get_entity("Iteration_0_potential")[0] # Run the inverse params = DCBatch2DInversionOptions.build( diff --git a/tests/run_tests/driver_dc_test.py b/tests/run_tests/driver_dc_test.py index a4ec28a3..7db008fa 100644 --- a/tests/run_tests/driver_dc_test.py +++ b/tests/run_tests/driver_dc_test.py @@ -99,7 +99,7 @@ def test_dc_3d_run( with Workspace(workpath) as geoh5: components = SyntheticsComponents(geoh5) - potential = geoh5.get_entity("Iteration_0_dc")[0] + potential = geoh5.get_entity("Iteration_0_potential")[0] # Run the inverse params = DC3DInversionOptions.build( diff --git a/tests/run_tests/driver_ground_tem_test.py b/tests/run_tests/driver_ground_tem_test.py index 9eedbf08..9437539d 100644 --- a/tests/run_tests/driver_ground_tem_test.py +++ b/tests/run_tests/driver_ground_tem_test.py @@ -90,10 +90,12 @@ def test_tiling_ground_tem( y_channel_bool=True, z_channel_bool=True, tile_spatial=4, + solver_type="Mumps", ) fwr_driver = TDEMForwardDriver(params) - tiles = fwr_driver.get_tiles() + with geoh5.open(): + tiles = fwr_driver.get_tiles() assert len(tiles) == 4 diff --git a/tests/run_tests/driver_ip_2d_test.py b/tests/run_tests/driver_ip_2d_test.py index 6ca51ffb..b70c55b8 100644 --- a/tests/run_tests/driver_ip_2d_test.py +++ b/tests/run_tests/driver_ip_2d_test.py @@ -89,7 +89,7 @@ def test_ip_2d_run( with Workspace(workpath) as geoh5: components = SyntheticsComponents(geoh5) - chargeability = geoh5.get_entity("Iteration_0_ip")[0] + chargeability = geoh5.get_entity("Iteration_0_chargeability")[0] # Run the inverse params = IP2DInversionOptions.build( diff --git a/tests/run_tests/driver_ip_b2d_test.py b/tests/run_tests/driver_ip_b2d_test.py index 52082182..62d00cf6 100644 --- a/tests/run_tests/driver_ip_b2d_test.py +++ b/tests/run_tests/driver_ip_b2d_test.py @@ -106,7 +106,7 @@ def test_ip_p3d_run( with Workspace(workpath) as geoh5: components = SyntheticsComponents(geoh5) - chargeability = geoh5.get_entity("Iteration_0_ip")[0] + chargeability = geoh5.get_entity("Iteration_0_chargeability")[0] # Run the inverse params = IPBatch2DInversionOptions.build( diff --git a/tests/run_tests/driver_ip_test.py b/tests/run_tests/driver_ip_test.py index b3d444d8..eaf2279a 100644 --- a/tests/run_tests/driver_ip_test.py +++ b/tests/run_tests/driver_ip_test.py @@ -82,7 +82,7 @@ def test_ip_3d_run( workpath = tmp_path.parent / "test_ip_3d_fwr_run0" / "inversion_test.ui.geoh5" with Workspace(workpath) as geoh5: - potential = geoh5.get_entity("Iteration_0_ip")[0] + potential = geoh5.get_entity("Iteration_0_chargeability")[0] mesh = geoh5.get_entity("mesh")[0] topography = geoh5.get_entity("topography")[0] diff --git a/tests/run_tests/driver_tile_estimator_test.py b/tests/run_tests/driver_tile_estimator_test.py index ea3edcc6..70b40e84 100644 --- a/tests/run_tests/driver_tile_estimator_test.py +++ b/tests/run_tests/driver_tile_estimator_test.py @@ -67,9 +67,8 @@ def test_tile_estimator_run( tile_params = TileParameters(geoh5=geoh5, simulation=driver.out_group) estimator = TileEstimator(tile_params) - assert len(estimator.get_results(max_tiles=32)) == 8 - with geoh5.open(): + assert len(estimator.get_results(max_tiles=32)) == 8 simpeg_group = estimator.run() driver = simpeg_group_to_driver(simpeg_group, geoh5)