diff --git a/src/rules/mod.rs b/src/rules/mod.rs index 53584c2..6ecdd9e 100644 --- a/src/rules/mod.rs +++ b/src/rules/mod.rs @@ -1006,13 +1006,12 @@ fn assess_treatment_failure( } // Check if there's a current infection and bacteria level recorded at drug start - if individual.level[bacteria_idx] <= 0.0 - || individual.bacteria_level_at_drug_start[bacteria_idx].is_none() - { + let Some(bacteria_initial_level) = individual.bacteria_level_at_drug_start[bacteria_idx] else { + return false; + }; + if individual.level[bacteria_idx] <= 0.0 { return false; } - - let bacteria_initial_level = individual.bacteria_level_at_drug_start[bacteria_idx].unwrap(); let current_level = individual.level[bacteria_idx]; // Get failure threshold (default 0.5 = 50% of initial level) @@ -2038,8 +2037,7 @@ pub(crate) fn apply_rules( // --- end region travel updates --- // --- sepsis risk --- - for &bacteria in BACTERIA_LIST.iter() { - let b_idx = *bacteria_indices.get(bacteria).unwrap(); + for (b_idx, &bacteria) in BACTERIA_LIST.iter().enumerate() { let current_level = individual.level[b_idx]; if current_level > 0.0 { @@ -4454,9 +4452,7 @@ pub(crate) fn apply_rules( // --- sepsis recovery logic (applied after death risk, only if individual is alive) --- if individual.date_of_death.is_none() { - for &bacteria in BACTERIA_LIST.iter() { - let b_idx = *bacteria_indices.get(bacteria).unwrap(); - + for b_idx in 0..BACTERIA_LIST.len() { // Only consider recovery if individual currently has sepsis from this bacteria if individual.sepsis[b_idx] { // Drop lingering sepsis once the triggering infection has cleared @@ -5241,8 +5237,7 @@ pub(crate) fn apply_rules( // Provenance bookkeeping disabled for memory-saving runs: keep the // underlying biology, but do not store dense per-drug source labels. if crate::simulation::population::TRACK_RESISTANCE_ACQUISITION_PROVENANCE { - for drug_name_static in DRUG_SHORT_NAMES.iter() { - let d_idx = *drug_indices.get(drug_name_static).unwrap(); + for d_idx in 0..DRUG_SHORT_NAMES.len() { // Check if any mechanism applicable to this drug is now set let has_any_relevant_mechanism = ResistanceMechanism::all() .iter() @@ -5262,8 +5257,7 @@ pub(crate) fn apply_rules( // Resistance floor: apply minimum resistance level for rare bacteria // by ensuring mechanism_any is set where prevalence floor applies. // (This preserves the floor semantics without injecting float values.) - for drug_name_static in DRUG_SHORT_NAMES.iter() { - let d_idx = *drug_indices.get(drug_name_static).unwrap(); + for (d_idx, drug_name_static) in DRUG_SHORT_NAMES.iter().enumerate() { let floor_level = calculate_resistance_floor( bacteria, drug_name_static, @@ -6774,5 +6768,3 @@ impl FastMath for f64 { fast_math::log2(self as f32) as f64 * std::f64::consts::LN_2 } } - - diff --git a/src/simulation/journey_logger.rs b/src/simulation/journey_logger.rs index ab9c66b..71605fc 100644 --- a/src/simulation/journey_logger.rs +++ b/src/simulation/journey_logger.rs @@ -17,7 +17,7 @@ use std::fs::File; use std::io::{BufWriter, Write}; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Mutex, OnceLock}; +use std::sync::{Mutex, MutexGuard, OnceLock}; use crate::config::parameter_store; @@ -41,21 +41,27 @@ fn activity_cache() -> &'static Mutex(mutex: &Mutex) -> MutexGuard<'_, T> { + mutex + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + /// Enable snapshot tracking and clear any stale cache entries from prior runs. pub(crate) fn enable_activity_snapshots() { ACTIVITY_SNAPSHOT_ENABLED.store(true, Ordering::Relaxed); - activity_registry().lock().unwrap().clear(); - activity_cache().lock().unwrap().clear(); + lock_recovering_poison(activity_registry()).clear(); + lock_recovering_poison(activity_cache()).clear(); } /// Disable snapshot tracking and drop any cached state. pub(crate) fn disable_activity_snapshots() { ACTIVITY_SNAPSHOT_ENABLED.store(false, Ordering::Relaxed); if let Some(registry) = ACTIVITY_SNAPSHOT_REGISTRY.get() { - registry.lock().unwrap().clear(); + lock_recovering_poison(registry).clear(); } if let Some(cache) = ACTIVITY_SNAPSHOT_CACHE.get() { - cache.lock().unwrap().clear(); + lock_recovering_poison(cache).clear(); } } @@ -64,22 +70,16 @@ pub(crate) fn register_activity_snapshot_tracking(individual_id: usize, bacteria if !ACTIVITY_SNAPSHOT_ENABLED.load(Ordering::Relaxed) { return; } - activity_registry() - .lock() - .unwrap() - .insert((individual_id, bacteria_idx)); + lock_recovering_poison(activity_registry()).insert((individual_id, bacteria_idx)); } /// Remove tracking metadata when a journey completes or is aborted. pub(crate) fn unregister_activity_snapshot_tracking(individual_id: usize, bacteria_idx: usize) { if let Some(registry) = ACTIVITY_SNAPSHOT_REGISTRY.get() { - registry - .lock() - .unwrap() - .remove(&(individual_id, bacteria_idx)); + lock_recovering_poison(registry).remove(&(individual_id, bacteria_idx)); } if let Some(cache) = ACTIVITY_SNAPSHOT_CACHE.get() { - cache.lock().unwrap().remove(&(individual_id, bacteria_idx)); + lock_recovering_poison(cache).remove(&(individual_id, bacteria_idx)); } } @@ -94,10 +94,7 @@ pub(crate) fn should_cache_pre_clearance_activity( ACTIVITY_SNAPSHOT_REGISTRY .get() .map(|registry| { - registry - .lock() - .unwrap() - .contains(&(individual_id, bacteria_idx)) + lock_recovering_poison(registry).contains(&(individual_id, bacteria_idx)) }) .unwrap_or(false) } @@ -112,7 +109,7 @@ pub(crate) fn cache_pre_clearance_activity( if !should_cache_pre_clearance_activity(individual_id, bacteria_idx) { return; } - activity_cache().lock().unwrap().insert( + lock_recovering_poison(activity_cache()).insert( (individual_id, bacteria_idx), CachedActivitySnapshot { time_step, @@ -130,7 +127,7 @@ pub(crate) fn take_cached_activity_snapshot( if !ACTIVITY_SNAPSHOT_ENABLED.load(Ordering::Relaxed) { return None; } - let mut cache = activity_cache().lock().unwrap(); + let mut cache = lock_recovering_poison(activity_cache()); if let Some(entry) = cache.remove(&(individual_id, bacteria_idx)) { if entry.time_step == expected_time_step { return Some(entry.values); @@ -428,17 +425,15 @@ impl JourneyLogger { // Write initial snapshot to CSV if let Some(ref mut writer) = self.csv_writer { - if let Some(snapshot) = self - .active_journeys - .get(&individual.id) - .unwrap() - .snapshots - .last() - { - if let Err(e) = writeln!(writer, "{}", JourneyLogger::snapshot_to_csv(snapshot)) { - eprintln!("Error writing snapshot: {}", e); - } else { - self.total_snapshots_logged += 1; + if let Some(journey) = self.active_journeys.get(&individual.id) { + if let Some(snapshot) = journey.snapshots.last() { + if let Err(e) = + writeln!(writer, "{}", JourneyLogger::snapshot_to_csv(snapshot)) + { + eprintln!("Error writing snapshot: {}", e); + } else { + self.total_snapshots_logged += 1; + } } } } else { diff --git a/src/simulation/simulation.rs b/src/simulation/simulation.rs index f55eddb..71bf0b7 100644 --- a/src/simulation/simulation.rs +++ b/src/simulation/simulation.rs @@ -862,20 +862,28 @@ impl MechanismCache { } if let Some(weights) = weights { + if weights.len() != slot.len() + || weights + .iter() + .any(|weight| *weight < 0.0 || !weight.is_finite()) + { + return None; + } + let total_weight: f64 = weights.iter().sum(); - if total_weight <= 0.0 { + if !(total_weight > 0.0 && total_weight.is_finite()) { return None; } let roll = rng.gen_range(0.0..total_weight); let mut cumulative = 0.0_f64; - for (i, &weight) in weights.iter().enumerate() { + for (&profile, &weight) in slot.iter().zip(weights.iter()) { cumulative += weight; if roll < cumulative { - return Some(slot[i]); + return Some(profile); } } - return Some(*slot.last().unwrap()); + return slot.last().copied(); } let idx = rng.gen_range(0..slot.len()); @@ -5992,3 +6000,39 @@ impl Simulation { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::MechanismCache; + use rand::rngs::SmallRng; + use rand::SeedableRng; + + #[test] + fn weighted_profile_sampling_rejects_mismatched_or_invalid_weights() { + let slot = [1_u64, 2, 3]; + let mut rng = SmallRng::seed_from_u64(42); + + assert_eq!( + MechanismCache::sample_from_slot(&slot, Some(&[1.0, 2.0]), &mut rng), + None + ); + assert_eq!( + MechanismCache::sample_from_slot(&slot, Some(&[1.0, f64::NAN, 1.0]), &mut rng), + None + ); + assert_eq!( + MechanismCache::sample_from_slot(&slot, Some(&[1.0, -1.0, 1.0]), &mut rng), + None + ); + } + + #[test] + fn weighted_profile_sampling_accepts_valid_weights() { + let slot = [10_u64, 20, 30]; + let mut rng = SmallRng::seed_from_u64(7); + + let sampled = MechanismCache::sample_from_slot(&slot, Some(&[0.0, 0.0, 1.0]), &mut rng); + + assert_eq!(sampled, Some(30)); + } +} diff --git a/tests/dimension_invariants.rs b/tests/dimension_invariants.rs index 4fc0b0a..186e4dc 100644 --- a/tests/dimension_invariants.rs +++ b/tests/dimension_invariants.rs @@ -1,7 +1,8 @@ use amr_project::simulation::population::{ - Individual, InfectionResolutionType, Population, Region, ResistanceMechanism, - BACTERIA_CARRIAGE_COMPARTMENTS, BACTERIA_COUNT, BACTERIA_GROUPS, BACTERIA_LIST, - DRUG_SHORT_NAMES, MICROBIOME_RESISTANCE_LEVEL_COUNT, TRACK_RESISTANCE_ACQUISITION_PROVENANCE, + drug_class_for_drug, DrugClass, Individual, InfectionResolutionType, Population, Region, + ResistanceMechanism, BACTERIA_CARRIAGE_COMPARTMENTS, BACTERIA_COUNT, BACTERIA_GROUPS, + BACTERIA_LIST, DRUG_CLASS_LOOKUP, DRUG_SHORT_NAMES, MICROBIOME_RESISTANCE_LEVEL_COUNT, + TRACK_RESISTANCE_ACQUISITION_PROVENANCE, }; use amr_project::simulation::simulation::{CalibrationMode, Simulation}; use rand::rngs::SmallRng; @@ -280,6 +281,38 @@ fn static_model_dimensions_are_unique_and_consistent() { assert_unique("BACTERIA_LIST", &BACTERIA_LIST); assert_unique("DRUG_SHORT_NAMES", DRUG_SHORT_NAMES); + assert_len( + "DrugClass::all", + DrugClass::all().len(), + DrugClass::NUM_CLASSES, + ); + let drug_class_names: Vec<&str> = DrugClass::all().iter().map(DrugClass::as_str).collect(); + assert_unique("DrugClass::all", &drug_class_names); + for (expected_idx, drug_class) in DrugClass::all().iter().enumerate() { + assert_eq!( + drug_class.index(), + expected_idx, + "DrugClass::all order should match enum indices" + ); + } + + assert_len( + "DRUG_CLASS_LOOKUP", + DRUG_CLASS_LOOKUP.len(), + DRUG_SHORT_NAMES.len(), + ); + for (drug_idx, class_idx) in DRUG_CLASS_LOOKUP.iter().enumerate() { + assert!( + *class_idx < DrugClass::NUM_CLASSES, + "DRUG_CLASS_LOOKUP[{drug_idx}] contains out-of-range drug class index {class_idx}" + ); + assert_eq!( + *class_idx, + drug_class_for_drug(drug_idx).index(), + "DRUG_CLASS_LOOKUP[{drug_idx}] should match drug_class_for_drug" + ); + } + let mechanism_names: Vec<&str> = ResistanceMechanism::all() .iter() .map(ResistanceMechanism::as_str)