Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions simpeg_drivers/components/factories/directives_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from abc import ABC
from logging import getLogger
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -29,6 +30,8 @@
if TYPE_CHECKING:
from simpeg_drivers.driver import InversionDriver

logger = getLogger(__name__)


class DirectivesFactory:
def __init__(self, driver: InversionDriver):
Expand Down Expand Up @@ -263,7 +266,15 @@ def scale_misfits(self):
def update_irls_directive(self):
"""Directive to update IRLS."""
if self._update_irls_directive is None:
has_chi_start = self.params.starting_chi_factor is not None
start_chi_fact = self.params.starting_chi_factor

if start_chi_fact is not None and self.params.chi_factor > start_chi_fact:
logger.warning(
"Starting chi factor is greater than target chi factor.\n"
"Setting the target chi factor to the starting chi factor."
)
start_chi_fact = self.params.chi_factor

self._update_irls_directive = directives.UpdateIRLS(
f_min_change=self.params.f_min_change,
max_irls_iterations=self.params.max_irls_iterations,
Expand All @@ -272,11 +283,7 @@ def update_irls_directive(self):
cooling_rate=self.params.cooling_rate,
cooling_factor=self.params.cooling_factor,
irls_cooling_factor=self.params.epsilon_cooling_factor,
chifact_start=(
self.params.starting_chi_factor
if has_chi_start
else self.params.chi_factor
),
chifact_start=start_chi_fact or self.params.chi_factor,
chifact_target=self.params.chi_factor,
)
return self._update_irls_directive
Expand Down
2 changes: 1 addition & 1 deletion simpeg_drivers/joint/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class BaseJointOptions(BaseData):
initial_beta: float | None = None
cooling_factor: float = 2.0

cooling_rate: float = 1.0
cooling_rate: int = 1
max_global_iterations: int = 50
max_line_search_iterations: int = 20
max_cg_iterations: int = 30
Expand Down
2 changes: 1 addition & 1 deletion simpeg_drivers/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ class BaseInversionOptions(CoreOptions):
initial_beta: float | None = None
cooling_factor: float = 2.0

cooling_rate: float = 1.0
cooling_rate: int = 1
max_global_iterations: int = 50
max_line_search_iterations: int = 20
max_cg_iterations: int = 30
Expand Down
37 changes: 37 additions & 0 deletions tests/driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,40 @@ def test_smallness_terms(tmp_path: Path):
params.alpha_s = None
driver = GravityInversionDriver(params)
assert driver.regularization.objfcts[0].alpha_s == 0.0


def test_target_chi(tmp_path: Path, caplog):
n_grid_points = 2
refinement = (2,)

geoh5, _, model, survey, topography = setup_inversion_workspace(
tmp_path,
background=0.0,
anomaly=0.75,
n_electrodes=n_grid_points,
n_lines=n_grid_points,
refinement=refinement,
flatten=False,
)

with geoh5.open():
gz = survey.add_data({"gz": {"values": np.ones(survey.n_vertices)}})
mesh = model.parent
active_cells = ActiveCellsOptions(topography_object=topography)
params = GravityInversionOptions(
geoh5=geoh5,
mesh=mesh,
active_cells=active_cells,
data_object=gz.parent,
gz_channel=gz,
gz_uncertainty=2e-3,
starting_model=1e-4,
starting_chi_factor=1.0,
chi_factor=2.0,
)
driver = GravityInversionDriver(params)

with caplog.at_level("WARNING"):
assert driver.directives.update_irls_directive.chifact_start == 2.0

assert "Starting chi factor is greater" in caplog.text