Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions src/rules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,13 +1006,12 @@ fn assess_treatment_failure(
}

// Check if there's a current infection and bacteria level recorded at drug start
if individual.level[bacteria_idx] <= 0.0
|| individual.bacteria_level_at_drug_start[bacteria_idx].is_none()
{
let Some(bacteria_initial_level) = individual.bacteria_level_at_drug_start[bacteria_idx] else {
return false;
};
if individual.level[bacteria_idx] <= 0.0 {
return false;
}

let bacteria_initial_level = individual.bacteria_level_at_drug_start[bacteria_idx].unwrap();
let current_level = individual.level[bacteria_idx];

// Get failure threshold (default 0.5 = 50% of initial level)
Expand Down Expand Up @@ -2038,8 +2037,7 @@ pub(crate) fn apply_rules(
// --- end region travel updates ---

// --- sepsis risk ---
for &bacteria in BACTERIA_LIST.iter() {
let b_idx = *bacteria_indices.get(bacteria).unwrap();
for (b_idx, &bacteria) in BACTERIA_LIST.iter().enumerate() {
let current_level = individual.level[b_idx];

if current_level > 0.0 {
Expand Down Expand Up @@ -4454,9 +4452,7 @@ pub(crate) fn apply_rules(

// --- sepsis recovery logic (applied after death risk, only if individual is alive) ---
if individual.date_of_death.is_none() {
for &bacteria in BACTERIA_LIST.iter() {
let b_idx = *bacteria_indices.get(bacteria).unwrap();

for b_idx in 0..BACTERIA_LIST.len() {
// Only consider recovery if individual currently has sepsis from this bacteria
if individual.sepsis[b_idx] {
// Drop lingering sepsis once the triggering infection has cleared
Expand Down Expand Up @@ -5241,8 +5237,7 @@ pub(crate) fn apply_rules(
// Provenance bookkeeping disabled for memory-saving runs: keep the
// underlying biology, but do not store dense per-drug source labels.
if crate::simulation::population::TRACK_RESISTANCE_ACQUISITION_PROVENANCE {
for drug_name_static in DRUG_SHORT_NAMES.iter() {
let d_idx = *drug_indices.get(drug_name_static).unwrap();
for d_idx in 0..DRUG_SHORT_NAMES.len() {
// Check if any mechanism applicable to this drug is now set
let has_any_relevant_mechanism = ResistanceMechanism::all()
.iter()
Expand All @@ -5262,8 +5257,7 @@ pub(crate) fn apply_rules(
// Resistance floor: apply minimum resistance level for rare bacteria
// by ensuring mechanism_any is set where prevalence floor applies.
// (This preserves the floor semantics without injecting float values.)
for drug_name_static in DRUG_SHORT_NAMES.iter() {
let d_idx = *drug_indices.get(drug_name_static).unwrap();
for (d_idx, drug_name_static) in DRUG_SHORT_NAMES.iter().enumerate() {
let floor_level = calculate_resistance_floor(
bacteria,
drug_name_static,
Expand Down Expand Up @@ -6774,5 +6768,3 @@ impl FastMath for f64 {
fast_math::log2(self as f32) as f64 * std::f64::consts::LN_2
}
}


57 changes: 26 additions & 31 deletions src/simulation/journey_logger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, OnceLock};
use std::sync::{Mutex, MutexGuard, OnceLock};

use crate::config::parameter_store;

Expand All @@ -41,21 +41,27 @@ fn activity_cache() -> &'static Mutex<HashMap<(usize, usize), CachedActivitySnap
ACTIVITY_SNAPSHOT_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

fn lock_recovering_poison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
mutex
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}

/// Enable snapshot tracking and clear any stale cache entries from prior runs.
pub(crate) fn enable_activity_snapshots() {
ACTIVITY_SNAPSHOT_ENABLED.store(true, Ordering::Relaxed);
activity_registry().lock().unwrap().clear();
activity_cache().lock().unwrap().clear();
lock_recovering_poison(activity_registry()).clear();
lock_recovering_poison(activity_cache()).clear();
}

/// Disable snapshot tracking and drop any cached state.
pub(crate) fn disable_activity_snapshots() {
ACTIVITY_SNAPSHOT_ENABLED.store(false, Ordering::Relaxed);
if let Some(registry) = ACTIVITY_SNAPSHOT_REGISTRY.get() {
registry.lock().unwrap().clear();
lock_recovering_poison(registry).clear();
}
if let Some(cache) = ACTIVITY_SNAPSHOT_CACHE.get() {
cache.lock().unwrap().clear();
lock_recovering_poison(cache).clear();
}
}

Expand All @@ -64,22 +70,16 @@ pub(crate) fn register_activity_snapshot_tracking(individual_id: usize, bacteria
if !ACTIVITY_SNAPSHOT_ENABLED.load(Ordering::Relaxed) {
return;
}
activity_registry()
.lock()
.unwrap()
.insert((individual_id, bacteria_idx));
lock_recovering_poison(activity_registry()).insert((individual_id, bacteria_idx));
}

/// Remove tracking metadata when a journey completes or is aborted.
pub(crate) fn unregister_activity_snapshot_tracking(individual_id: usize, bacteria_idx: usize) {
if let Some(registry) = ACTIVITY_SNAPSHOT_REGISTRY.get() {
registry
.lock()
.unwrap()
.remove(&(individual_id, bacteria_idx));
lock_recovering_poison(registry).remove(&(individual_id, bacteria_idx));
}
if let Some(cache) = ACTIVITY_SNAPSHOT_CACHE.get() {
cache.lock().unwrap().remove(&(individual_id, bacteria_idx));
lock_recovering_poison(cache).remove(&(individual_id, bacteria_idx));
}
}

Expand All @@ -94,10 +94,7 @@ pub(crate) fn should_cache_pre_clearance_activity(
ACTIVITY_SNAPSHOT_REGISTRY
.get()
.map(|registry| {
registry
.lock()
.unwrap()
.contains(&(individual_id, bacteria_idx))
lock_recovering_poison(registry).contains(&(individual_id, bacteria_idx))
})
.unwrap_or(false)
}
Expand All @@ -112,7 +109,7 @@ pub(crate) fn cache_pre_clearance_activity(
if !should_cache_pre_clearance_activity(individual_id, bacteria_idx) {
return;
}
activity_cache().lock().unwrap().insert(
lock_recovering_poison(activity_cache()).insert(
(individual_id, bacteria_idx),
CachedActivitySnapshot {
time_step,
Expand All @@ -130,7 +127,7 @@ pub(crate) fn take_cached_activity_snapshot(
if !ACTIVITY_SNAPSHOT_ENABLED.load(Ordering::Relaxed) {
return None;
}
let mut cache = activity_cache().lock().unwrap();
let mut cache = lock_recovering_poison(activity_cache());
if let Some(entry) = cache.remove(&(individual_id, bacteria_idx)) {
if entry.time_step == expected_time_step {
return Some(entry.values);
Expand Down Expand Up @@ -428,17 +425,15 @@ impl JourneyLogger {

// Write initial snapshot to CSV
if let Some(ref mut writer) = self.csv_writer {
if let Some(snapshot) = self
.active_journeys
.get(&individual.id)
.unwrap()
.snapshots
.last()
{
if let Err(e) = writeln!(writer, "{}", JourneyLogger::snapshot_to_csv(snapshot)) {
eprintln!("Error writing snapshot: {}", e);
} else {
self.total_snapshots_logged += 1;
if let Some(journey) = self.active_journeys.get(&individual.id) {
if let Some(snapshot) = journey.snapshots.last() {
if let Err(e) =
writeln!(writer, "{}", JourneyLogger::snapshot_to_csv(snapshot))
{
eprintln!("Error writing snapshot: {}", e);
} else {
self.total_snapshots_logged += 1;
}
}
}
} else {
Expand Down
52 changes: 48 additions & 4 deletions src/simulation/simulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,20 +862,28 @@ impl MechanismCache {
}

if let Some(weights) = weights {
if weights.len() != slot.len()
|| weights
.iter()
.any(|weight| *weight < 0.0 || !weight.is_finite())
{
return None;
}

let total_weight: f64 = weights.iter().sum();
if total_weight <= 0.0 {
if !(total_weight > 0.0 && total_weight.is_finite()) {
return None;
}

let roll = rng.gen_range(0.0..total_weight);
let mut cumulative = 0.0_f64;
for (i, &weight) in weights.iter().enumerate() {
for (&profile, &weight) in slot.iter().zip(weights.iter()) {
cumulative += weight;
if roll < cumulative {
return Some(slot[i]);
return Some(profile);
}
}
return Some(*slot.last().unwrap());
return slot.last().copied();
}

let idx = rng.gen_range(0..slot.len());
Expand Down Expand Up @@ -5992,3 +6000,39 @@ impl Simulation {
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::MechanismCache;
use rand::rngs::SmallRng;
use rand::SeedableRng;

#[test]
fn weighted_profile_sampling_rejects_mismatched_or_invalid_weights() {
let slot = [1_u64, 2, 3];
let mut rng = SmallRng::seed_from_u64(42);

assert_eq!(
MechanismCache::sample_from_slot(&slot, Some(&[1.0, 2.0]), &mut rng),
None
);
assert_eq!(
MechanismCache::sample_from_slot(&slot, Some(&[1.0, f64::NAN, 1.0]), &mut rng),
None
);
assert_eq!(
MechanismCache::sample_from_slot(&slot, Some(&[1.0, -1.0, 1.0]), &mut rng),
None
);
}

#[test]
fn weighted_profile_sampling_accepts_valid_weights() {
let slot = [10_u64, 20, 30];
let mut rng = SmallRng::seed_from_u64(7);

let sampled = MechanismCache::sample_from_slot(&slot, Some(&[0.0, 0.0, 1.0]), &mut rng);

assert_eq!(sampled, Some(30));
}
}
39 changes: 36 additions & 3 deletions tests/dimension_invariants.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use amr_project::simulation::population::{
Individual, InfectionResolutionType, Population, Region, ResistanceMechanism,
BACTERIA_CARRIAGE_COMPARTMENTS, BACTERIA_COUNT, BACTERIA_GROUPS, BACTERIA_LIST,
DRUG_SHORT_NAMES, MICROBIOME_RESISTANCE_LEVEL_COUNT, TRACK_RESISTANCE_ACQUISITION_PROVENANCE,
drug_class_for_drug, DrugClass, Individual, InfectionResolutionType, Population, Region,
ResistanceMechanism, BACTERIA_CARRIAGE_COMPARTMENTS, BACTERIA_COUNT, BACTERIA_GROUPS,
BACTERIA_LIST, DRUG_CLASS_LOOKUP, DRUG_SHORT_NAMES, MICROBIOME_RESISTANCE_LEVEL_COUNT,
TRACK_RESISTANCE_ACQUISITION_PROVENANCE,
};
use amr_project::simulation::simulation::{CalibrationMode, Simulation};
use rand::rngs::SmallRng;
Expand Down Expand Up @@ -280,6 +281,38 @@ fn static_model_dimensions_are_unique_and_consistent() {
assert_unique("BACTERIA_LIST", &BACTERIA_LIST);
assert_unique("DRUG_SHORT_NAMES", DRUG_SHORT_NAMES);

assert_len(
"DrugClass::all",
DrugClass::all().len(),
DrugClass::NUM_CLASSES,
);
let drug_class_names: Vec<&str> = DrugClass::all().iter().map(DrugClass::as_str).collect();
assert_unique("DrugClass::all", &drug_class_names);
for (expected_idx, drug_class) in DrugClass::all().iter().enumerate() {
assert_eq!(
drug_class.index(),
expected_idx,
"DrugClass::all order should match enum indices"
);
}

assert_len(
"DRUG_CLASS_LOOKUP",
DRUG_CLASS_LOOKUP.len(),
DRUG_SHORT_NAMES.len(),
);
for (drug_idx, class_idx) in DRUG_CLASS_LOOKUP.iter().enumerate() {
assert!(
*class_idx < DrugClass::NUM_CLASSES,
"DRUG_CLASS_LOOKUP[{drug_idx}] contains out-of-range drug class index {class_idx}"
);
assert_eq!(
*class_idx,
drug_class_for_drug(drug_idx).index(),
"DRUG_CLASS_LOOKUP[{drug_idx}] should match drug_class_for_drug"
);
}

let mechanism_names: Vec<&str> = ResistanceMechanism::all()
.iter()
.map(ResistanceMechanism::as_str)
Expand Down
Loading