diff --git a/simpeg_drivers/components/factories/survey_factory.py b/simpeg_drivers/components/factories/survey_factory.py index 2f7c2c2e..7f26049e 100644 --- a/simpeg_drivers/components/factories/survey_factory.py +++ b/simpeg_drivers/components/factories/survey_factory.py @@ -203,13 +203,12 @@ def _dcip_arguments(self, data=None): def _tdem_arguments(self, data=None): receivers = data.entity transmitters = receivers.transmitters + channels = np.array(receivers.channels) * self.params.unit_conversion - if receivers.channels[-1] > ( - receivers.waveform[:, 0].max() - receivers.timing_mark - ): + if any(channels > (self.params.time_steps.sum() - self.params.timing_mark)): raise GeoAppsError( f"The latest time channel {receivers.channels[-1]} exceeds " - f"the waveform discretization. Revise waveform." + f"the waveform discretization. Check waveform sampling from start to end." ) if isinstance(transmitters, LargeLoopGroundTEMTransmitters): @@ -239,6 +238,16 @@ def _tdem_arguments(self, data=None): receivers.waveform[:, 0] - receivers.timing_mark ) * self.params.unit_conversion + # Check single channel per time gate + _, count = np.unique( + np.searchsorted(wave_times, channels, side="right"), return_counts=True + ) + if np.any(count > 1): + raise GeoAppsError( + "Multiple channels found within single time step. " + "Check waveform sampling on the off-times." + ) + if "1d" in self.factory_type: on_times = wave_times <= 0.0 waveform = tdem.sources.PiecewiseLinearWaveform( diff --git a/tests/run_tests/driver_airborne_tem_test.py b/tests/run_tests/driver_airborne_tem_test.py index fe5b426b..4d9cf777 100644 --- a/tests/run_tests/driver_airborne_tem_test.py +++ b/tests/run_tests/driver_airborne_tem_test.py @@ -74,6 +74,12 @@ def test_bad_waveform(tmp_path: Path): with raises(GeoAppsError, match="The latest time"): _ = fwr_driver.inversion_data.survey + with geoh5: + params.data_object.channels[-1] = 0.7 + + with raises(GeoAppsError, match="Multiple channels found within single time step"): + _ = fwr_driver.inversion_data.survey + def test_airborne_tem_fwr_run( tmp_path: Path,