diff --git a/simpeg_drivers-assets/uijson/direct_current_2d_forward.ui.json b/simpeg_drivers-assets/uijson/direct_current_2d_forward.ui.json index 121b959e..a166847d 100644 --- a/simpeg_drivers-assets/uijson/direct_current_2d_forward.ui.json +++ b/simpeg_drivers-assets/uijson/direct_current_2d_forward.ui.json @@ -217,6 +217,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": 1, "max_chunk_size": { "min": 0, diff --git a/simpeg_drivers-assets/uijson/direct_current_2d_inversion.ui.json b/simpeg_drivers-assets/uijson/direct_current_2d_inversion.ui.json index 97f6c6fc..2ae3f547 100644 --- a/simpeg_drivers-assets/uijson/direct_current_2d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/direct_current_2d_inversion.ui.json @@ -546,6 +546,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": 1, "store_sensitivities": { "choiceList": [ diff --git a/simpeg_drivers-assets/uijson/direct_current_3d_forward.ui.json b/simpeg_drivers-assets/uijson/direct_current_3d_forward.ui.json index bd0bc3a6..1d7c233d 100644 --- a/simpeg_drivers-assets/uijson/direct_current_3d_forward.ui.json +++ b/simpeg_drivers-assets/uijson/direct_current_3d_forward.ui.json @@ -122,6 +122,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/direct_current_3d_inversion.ui.json b/simpeg_drivers-assets/uijson/direct_current_3d_inversion.ui.json index 5ba10bda..3f2da74b 100644 --- a/simpeg_drivers-assets/uijson/direct_current_3d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/direct_current_3d_inversion.ui.json @@ -494,6 +494,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/direct_current_batch2d_forward.ui.json b/simpeg_drivers-assets/uijson/direct_current_batch2d_forward.ui.json index 494eba23..850d01d7 100644 --- a/simpeg_drivers-assets/uijson/direct_current_batch2d_forward.ui.json +++ b/simpeg_drivers-assets/uijson/direct_current_batch2d_forward.ui.json @@ -175,6 +175,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": 1, "max_chunk_size": { "min": 0, diff --git a/simpeg_drivers-assets/uijson/direct_current_batch2d_inversion.ui.json b/simpeg_drivers-assets/uijson/direct_current_batch2d_inversion.ui.json index c095e5be..9b37d4cb 100644 --- a/simpeg_drivers-assets/uijson/direct_current_batch2d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/direct_current_batch2d_inversion.ui.json @@ -503,6 +503,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": 1, "store_sensitivities": { "choiceList": [ diff --git a/simpeg_drivers-assets/uijson/fem_forward.ui.json b/simpeg_drivers-assets/uijson/fem_forward.ui.json index dd61bf37..5045ce80 100644 --- a/simpeg_drivers-assets/uijson/fem_forward.ui.json +++ b/simpeg_drivers-assets/uijson/fem_forward.ui.json @@ -135,6 +135,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/fem_inversion.ui.json b/simpeg_drivers-assets/uijson/fem_inversion.ui.json index 5c8f2815..317f5420 100644 --- a/simpeg_drivers-assets/uijson/fem_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/fem_inversion.ui.json @@ -530,6 +530,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/induced_polarization_2d_forward.ui.json b/simpeg_drivers-assets/uijson/induced_polarization_2d_forward.ui.json index b1bef415..411310f5 100644 --- a/simpeg_drivers-assets/uijson/induced_polarization_2d_forward.ui.json +++ b/simpeg_drivers-assets/uijson/induced_polarization_2d_forward.ui.json @@ -228,6 +228,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": 1, "max_chunk_size": { "min": 0, diff --git a/simpeg_drivers-assets/uijson/induced_polarization_2d_inversion.ui.json b/simpeg_drivers-assets/uijson/induced_polarization_2d_inversion.ui.json index eb0313f7..89bb2891 100644 --- a/simpeg_drivers-assets/uijson/induced_polarization_2d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/induced_polarization_2d_inversion.ui.json @@ -556,6 +556,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": 1, "max_ram": "", "store_sensitivities": { diff --git a/simpeg_drivers-assets/uijson/induced_polarization_3d_forward.ui.json b/simpeg_drivers-assets/uijson/induced_polarization_3d_forward.ui.json index edea4436..b4839acf 100644 --- a/simpeg_drivers-assets/uijson/induced_polarization_3d_forward.ui.json +++ b/simpeg_drivers-assets/uijson/induced_polarization_3d_forward.ui.json @@ -162,6 +162,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/induced_polarization_3d_inversion.ui.json b/simpeg_drivers-assets/uijson/induced_polarization_3d_inversion.ui.json index 1e5a55b1..a2ea78f4 100644 --- a/simpeg_drivers-assets/uijson/induced_polarization_3d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/induced_polarization_3d_inversion.ui.json @@ -534,6 +534,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/induced_polarization_batch2d_forward.ui.json b/simpeg_drivers-assets/uijson/induced_polarization_batch2d_forward.ui.json index 6944322e..76415339 100644 --- a/simpeg_drivers-assets/uijson/induced_polarization_batch2d_forward.ui.json +++ b/simpeg_drivers-assets/uijson/induced_polarization_batch2d_forward.ui.json @@ -212,6 +212,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": 1, "max_chunk_size": { "min": 0, diff --git a/simpeg_drivers-assets/uijson/induced_polarization_batch2d_inversion.ui.json b/simpeg_drivers-assets/uijson/induced_polarization_batch2d_inversion.ui.json index 94e27c4d..11d17248 100644 --- a/simpeg_drivers-assets/uijson/induced_polarization_batch2d_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/induced_polarization_batch2d_inversion.ui.json @@ -540,6 +540,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": 1, "store_sensitivities": { "choiceList": [ diff --git a/simpeg_drivers-assets/uijson/magnetotellurics_forward.ui.json b/simpeg_drivers-assets/uijson/magnetotellurics_forward.ui.json index 06ad1317..1d0634ef 100644 --- a/simpeg_drivers-assets/uijson/magnetotellurics_forward.ui.json +++ b/simpeg_drivers-assets/uijson/magnetotellurics_forward.ui.json @@ -202,6 +202,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/magnetotellurics_inversion.ui.json b/simpeg_drivers-assets/uijson/magnetotellurics_inversion.ui.json index 05b124c9..c68922fc 100644 --- a/simpeg_drivers-assets/uijson/magnetotellurics_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/magnetotellurics_inversion.ui.json @@ -741,6 +741,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/tdem_forward.ui.json b/simpeg_drivers-assets/uijson/tdem_forward.ui.json index 40987aa8..b8afa77c 100644 --- a/simpeg_drivers-assets/uijson/tdem_forward.ui.json +++ b/simpeg_drivers-assets/uijson/tdem_forward.ui.json @@ -154,6 +154,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/tdem_inversion.ui.json b/simpeg_drivers-assets/uijson/tdem_inversion.ui.json index 979b6ae2..a47a67d9 100644 --- a/simpeg_drivers-assets/uijson/tdem_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/tdem_inversion.ui.json @@ -567,6 +567,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/tipper_forward.ui.json b/simpeg_drivers-assets/uijson/tipper_forward.ui.json index 4b760f57..8ffcd7f3 100644 --- a/simpeg_drivers-assets/uijson/tipper_forward.ui.json +++ b/simpeg_drivers-assets/uijson/tipper_forward.ui.json @@ -178,6 +178,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers-assets/uijson/tipper_inversion.ui.json b/simpeg_drivers-assets/uijson/tipper_inversion.ui.json index 81cff4bd..654182fa 100644 --- a/simpeg_drivers-assets/uijson/tipper_inversion.ui.json +++ b/simpeg_drivers-assets/uijson/tipper_inversion.ui.json @@ -621,6 +621,16 @@ "value": 1, "visible": false }, + "solver_type": { + "choiceList": [ + "Pardiso", + "Mumps" + ], + "group": "Compute", + "label": "Direct solver", + "tooltip": "Direct solver to use for the forward calculations", + "value": "Pardiso" + }, "tile_spatial": { "group": "Compute", "label": "Number of tiles", diff --git a/simpeg_drivers/components/factories/simulation_factory.py b/simpeg_drivers/components/factories/simulation_factory.py index 2abd35f8..f62304b8 100644 --- a/simpeg_drivers/components/factories/simulation_factory.py +++ b/simpeg_drivers/components/factories/simulation_factory.py @@ -53,7 +53,7 @@ def __init__(self, params: BaseParams | BaseOptions): ]: import pymatsolver.direct as solver_module - self.solver = solver_module.Pardiso + self.solver = getattr(solver_module, params.solver_type) def concrete_object(self): if self.factory_type in ["magnetic scalar", "magnetic vector"]: diff --git a/simpeg_drivers/params.py b/simpeg_drivers/params.py index 01d74697..52954334 100644 --- a/simpeg_drivers/params.py +++ b/simpeg_drivers/params.py @@ -12,6 +12,7 @@ from __future__ import annotations import multiprocessing +from enum import Enum from logging import getLogger from pathlib import Path from typing import ClassVar, TypeAlias @@ -56,6 +57,15 @@ def at_least_one(cls, data): return data +class SolverType(str, Enum): + """ + Supported solvers. + """ + + Pardiso = "Pardiso" + Mumps = "Mumps" + + class CoreOptions(BaseData): """ Core parameters shared by inverse and forward operations. @@ -94,6 +104,7 @@ class CoreOptions(BaseData): mesh: Octree | DrapeModel | None starting_model: float | FloatData active_cells: ActiveCellsOptions + solver_type: SolverType = SolverType.Pardiso tile_spatial: int = 1 parallelized: bool = True n_cpu: int | None = None diff --git a/tests/run_tests/driver_airborne_tem_test.py b/tests/run_tests/driver_airborne_tem_test.py index 819fb8ff..1cc9c333 100644 --- a/tests/run_tests/driver_airborne_tem_test.py +++ b/tests/run_tests/driver_airborne_tem_test.py @@ -10,14 +10,12 @@ from __future__ import annotations -import sys from pathlib import Path import numpy as np from geoh5py.groups import SimPEGGroup from geoh5py.workspace import Workspace -from pymatsolver.direct import Mumps -from pytest import mark, raises +from pytest import raises from simpeg_drivers.electromagnetics.time_domain import ( TDEMForwardOptions, @@ -100,17 +98,14 @@ def test_airborne_tem_fwr_run( x_channel_bool=True, y_channel_bool=True, z_channel_bool=True, + solver_type="Mumps", ) fwr_driver = TDEMForwardDriver(params) - fwr_driver.data_misfit.objfcts[0].simulation.solver = Mumps fwr_driver.run() -@mark.skipif( - sys.platform.startswith("win"), reason="Skipping windows-only tests due to mkl 2024" -) def test_airborne_tem_run(tmp_path: Path, max_iterations=1, pytest=True): workpath = tmp_path / "inversion_test.ui.geoh5" if pytest: @@ -190,12 +185,12 @@ def test_airborne_tem_run(tmp_path: Path, max_iterations=1, pytest=True): prctile=5, sens_wts_threshold=1.0, store_sensitivities="ram", + solver_type="Mumps", **data_kwargs, ) params.write_ui_json(path=tmp_path / "Inv_run.ui.json") driver = TDEMInversionDriver(params) - driver.data_misfit.objfcts[0].simulation.solver = Mumps driver.run() with geoh5.open() as run_ws: diff --git a/tests/run_tests/driver_ground_tem_test.py b/tests/run_tests/driver_ground_tem_test.py index 1aa6260e..c92273ca 100644 --- a/tests/run_tests/driver_ground_tem_test.py +++ b/tests/run_tests/driver_ground_tem_test.py @@ -10,14 +10,12 @@ from __future__ import annotations -import sys from logging import INFO, getLogger from pathlib import Path import numpy as np from geoh5py.workspace import Workspace from pymatsolver.direct import Mumps -from pytest import mark from simpeg_drivers.electromagnetics.time_domain import ( TDEMForwardOptions, @@ -118,6 +116,7 @@ def test_ground_tem_fwr_run( x_channel_bool=True, y_channel_bool=True, z_channel_bool=True, + solver_type="Mumps", ) fwr_driver = TDEMForwardDriver(params) @@ -134,13 +133,10 @@ def test_ground_tem_fwr_run( assert "closed" in caplog.records[0].message - fwr_driver.data_misfit.objfcts[0].simulation.solver = Mumps + assert fwr_driver.data_misfit.objfcts[0].simulation.simulations[0].solver == Mumps fwr_driver.run() -@mark.skipif( - sys.platform.startswith("win"), reason="Skipping windows-only tests due to mkl 2024" -) def test_ground_tem_run(tmp_path: Path, max_iterations=1, pytest=True): workpath = tmp_path / "inversion_test.ui.geoh5" if pytest: @@ -218,12 +214,12 @@ def test_ground_tem_run(tmp_path: Path, max_iterations=1, pytest=True): prctile=100, sens_wts_threshold=1.0, store_sensitivities="ram", + solver_type="Mumps", **data_kwargs, ) params.write_ui_json(path=tmp_path / "Inv_run.ui.json") driver = TDEMInversionDriver(params) - driver.data_misfit.objfcts[0].simulation.solver = Mumps driver.run() with geoh5.open() as run_ws: