From c205da871937d3afdbcb2a8813fb70074c344f3a Mon Sep 17 00:00:00 2001 From: nnoori-IDM <42287387+nnoori-IDM@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:49:32 -0700 Subject: [PATCH 1/5] use subtraction instead of ~ on uids --- hiud_acceptance.py | 13 ++++++++++--- interventions.py | 2 +- run_care_hiud_sensitivity.py | 11 +++++++++-- run_care_rate_sensitivity.py | 11 +++++++++-- 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/hiud_acceptance.py b/hiud_acceptance.py index 764b2e1..d325ccd 100644 --- a/hiud_acceptance.py +++ b/hiud_acceptance.py @@ -24,11 +24,18 @@ INTV_YEAR = 2026 N_SEEDS = 5 -# Three care-seeking scenarios +# Care-seeking scenarios +# CARE_SCENARIOS = { +# '10%': sc.objdict(base=0.10, anemic=1.43, pain=0.61), +# '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), +# '35%': sc.objdict(base=0.35, anemic=0.32, pain=0.14), +# } + +# anemic and pain coefficients held fixed; only base rate varies CARE_SCENARIOS = { - '10%': sc.objdict(base=0.10, anemic=1.43, pain=0.61), + '10%': sc.objdict(base=0.10, anemic=0.86, pain=0.37), '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), - '35%': sc.objdict(base=0.35, anemic=0.32, pain=0.14), + '35%': sc.objdict(base=0.35, anemic=0.86, pain=0.37), } CARE_COLORS = { diff --git a/interventions.py b/interventions.py index ed37c23..054f6b7 100644 --- a/interventions.py +++ b/interventions.py @@ -386,7 +386,7 @@ def check_continuation(self): self._p_continue.set(p_continue) continuers = self._p_continue.filter(on_treatment_uids) - stops = on_treatment_uids & ~continuers + stops = on_treatment_uids - continuers # Update last cycle tracking: 1.0 = resolved, 0.0 = persisted self.hmb_last_cycle[on_treatment_uids] = (~hmb_this_cycle).astype(float) diff --git a/run_care_hiud_sensitivity.py b/run_care_hiud_sensitivity.py index 4889049..94bb20a 100644 --- a/run_care_hiud_sensitivity.py +++ b/run_care_hiud_sensitivity.py @@ -44,10 +44,17 @@ FIXED_ACCEPT = 0.50 # Care-seeking scenarios +# CARE_SCENARIOS = { +# '10%': sc.objdict(base=0.10, anemic=1.43, pain=0.61), +# '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), +# '35%': sc.objdict(base=0.35, anemic=0.32, pain=0.14), +# } + +# anemic and pain coefficients held fixed; only base rate varies CARE_SCENARIOS = { - '10%': sc.objdict(base=0.10, anemic=1.43, pain=0.61), + '10%': sc.objdict(base=0.10, anemic=0.86, pain=0.37), '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), - '35%': sc.objdict(base=0.35, anemic=0.32, pain=0.14), + '35%': sc.objdict(base=0.35, anemic=0.86, pain=0.37), } # hIUD uptake scenarios (acceptance probabilities from calibration with 50% fixed accept) diff --git a/run_care_rate_sensitivity.py b/run_care_rate_sensitivity.py index 32a8aa4..582390c 100644 --- a/run_care_rate_sensitivity.py +++ b/run_care_rate_sensitivity.py @@ -40,10 +40,17 @@ INTV_YEAR = 2026 # Care-seeking scenarios: combined (anemia + pain) reaches 46% for all +# CARE_SCENARIOS = { +# '10%': sc.objdict(base=0.10, anemic=1.43, pain=0.61), +# '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), +# '35%': sc.objdict(base=0.35, anemic=0.32, pain=0.14), +# } + +# anemic and pain coefficients held fixed; only base rate varies CARE_SCENARIOS = { - '10%': sc.objdict(base=0.10, anemic=1.43, pain=0.61), + '10%': sc.objdict(base=0.10, anemic=0.86, pain=0.37), '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), - '35%': sc.objdict(base=0.35, anemic=0.32, pain=0.14), + '35%': sc.objdict(base=0.35, anemic=0.86, pain=0.37), } SCENARIO_LABELS = { From 5bc7db5644eb71acfcf4f21dffb2d56846ce1761 Mon Sep 17 00:00:00 2001 From: nnoori-IDM <42287387+nnoori-IDM@users.noreply.github.com> Date: Mon, 27 Apr 2026 08:58:17 -0700 Subject: [PATCH 2/5] develop a pool model of HMB treatment --- calibrate_p_hmb.py | 147 +++ hiud_acceptance.py | 258 ------ interventions.py | 16 +- interventions_pool.py | 1549 ++++++++++++++++++++++++++++++++ menstruation.py | 6 +- run_anemia_risk_sensitivity.py | 848 ++++++++--------- run_care_hiud_sensitivity.py | 1140 ----------------------- run_cascade.py | 6 +- run_scenarios.py | 649 +++++++++++++ stats_interventions2.py | 1019 +++++++++++++++++++++ 10 files changed, 3804 insertions(+), 1834 deletions(-) create mode 100644 calibrate_p_hmb.py delete mode 100644 hiud_acceptance.py create mode 100644 interventions_pool.py delete mode 100644 run_care_hiud_sensitivity.py create mode 100644 run_scenarios.py create mode 100644 stats_interventions2.py diff --git a/calibrate_p_hmb.py b/calibrate_p_hmb.py new file mode 100644 index 0000000..3ac819d --- /dev/null +++ b/calibrate_p_hmb.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Apr 23 09:58:52 2026 + +@author: navidehno +""" + +""" +Calibrate p_hmb_prone so that with status quo treatment, +observed HMB prevalence ≈ 48%. +""" + +import numpy as np +import sciris as sc +import starsim as ss +import fpsim as fp +import matplotlib.pyplot as plt + +from menstruation import Menstruation +from education import Education +from interventions_pool import HMBCounterfactual +from analyzers import track_hmb_anemia + +# ── Settings ── +START = 2020 +STOP = 2030 +N_SEEDS = 5 +TARGET_PREV = 0.48 + +# Values to sweep +P_HMB_PRONE_VALUES = np.arange(0.48, 0.61, 0.01) + + +def make_sim(p_hmb_prone, with_treatment=True, seed=0): + """Build sim with a specific p_hmb_prone value.""" + mens = Menstruation(pars=dict( + p_hmb_prone=ss.bernoulli(p=p_hmb_prone), + )) + edu = Education() + + sim_kwargs = dict( + start=START, stop=STOP, + n_agents=10000, total_pop=55_000_000, + location='kenya', + education_module=edu, + connectors=[mens], + analyzers=[track_hmb_anemia()], + rand_seed=seed, verbose=0, + ) + + if with_treatment: + # Status quo: 10% ever-seek, no hIUD, runs from 2020 + counterfactual = HMBCounterfactual() + sim_kwargs['interventions'] = [counterfactual] + + return fp.Sim(**sim_kwargs) + + +def run_calibration(): + results = {} + + for p_val in P_HMB_PRONE_VALUES: + p_key = f'{p_val:.2f}' + print(f"\np_hmb_prone = {p_val:.2f}") + prevs_with_tx = [] + prevs_no_tx = [] + + for seed in range(N_SEEDS): + print(f" seed {seed}...", end=" ", flush=True) + + # With status quo treatment + sim_tx = make_sim(p_val, with_treatment=True, seed=seed) + sim_tx.run() + hmb_prev_tx = np.asarray( + sim_tx.results.menstruation['hmb_prev'] + ) + # Take mean of last 12 months as the steady-state prevalence + prevs_with_tx.append(hmb_prev_tx[-12:].mean()) + + # Without any treatment (for comparison) + sim_no = make_sim(p_val, with_treatment=False, seed=seed) + sim_no.run() + hmb_prev_no = np.asarray( + sim_no.results.menstruation['hmb_prev'] + ) + prevs_no_tx.append(hmb_prev_no[-12:].mean()) + + del sim_tx, sim_no + print("done") + + results[p_key] = { + 'with_tx_mean': np.mean(prevs_with_tx), + 'with_tx_std': np.std(prevs_with_tx), + 'no_tx_mean': np.mean(prevs_no_tx), + 'no_tx_std': np.std(prevs_no_tx), + } + + print(f" With treatment: {results[p_key]['with_tx_mean']:.3f} " + f"± {results[p_key]['with_tx_std']:.3f}") + print(f" No treatment: {results[p_key]['no_tx_mean']:.3f} " + f"± {results[p_key]['no_tx_std']:.3f}") + + return results + + +def plot_calibration(results): + p_vals = [float(k) for k in results.keys()] + with_tx_means = [results[k]['with_tx_mean'] for k in results.keys()] + with_tx_stds = [results[k]['with_tx_std'] for k in results.keys()] + no_tx_means = [results[k]['no_tx_mean'] for k in results.keys()] + + fig, ax = plt.subplots(figsize=(10, 6)) + + ax.errorbar(p_vals, with_tx_means, yerr=with_tx_stds, + marker='o', capsize=4, lw=2, color='#d62728', + label='With status quo treatment') + ax.plot(p_vals, no_tx_means, + marker='s', lw=2, color='#1f77b4', alpha=0.5, + label='No treatment') + + ax.axhline(TARGET_PREV, color='black', ls='--', lw=1.5, + label=f'Target: {TARGET_PREV:.0%}') + + # Find closest match + diffs = [abs(m - TARGET_PREV) for m in with_tx_means] + best_idx = np.argmin(diffs) + best_p = p_vals[best_idx] + best_prev = with_tx_means[best_idx] + ax.axvline(best_p, color='green', ls=':', lw=1.5, + label=f'Best fit: p={best_p:.2f} → {best_prev:.3f}') + + ax.set_xlabel('p_hmb_prone') + ax.set_ylabel('Observed HMB prevalence') + ax.set_title('Calibrating p_hmb_prone with status quo treatment') + ax.legend(frameon=False) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + plt.tight_layout() + fig.savefig('hmb_prone_calibration.png', dpi=300, bbox_inches='tight') + print(f"\nBest fit: p_hmb_prone = {best_p:.2f} " + f"→ observed prevalence = {best_prev:.3f}") + return fig + + +if __name__ == '__main__': + results = run_calibration() + plot_calibration(results) \ No newline at end of file diff --git a/hiud_acceptance.py b/hiud_acceptance.py deleted file mode 100644 index d325ccd..0000000 --- a/hiud_acceptance.py +++ /dev/null @@ -1,258 +0,0 @@ -""" -Calibration sweep: find hIUD acceptance probabilities that produce -~5%, 10%, 15% of HMB women ever using hIUD. - -NSAID, TXA, and Pill acceptance all fixed at 50% or 25%. -Runs for 10%, 20%, and 35% base care-seeking. -""" - -import numpy as np -import sciris as sc -import starsim as ss -import fpsim as fp -import os -import matplotlib.pyplot as plt - -from menstruation import Menstruation -from education import Education -from interventions import HMBCascade -from analyzers import track_hmb_anemia - -# ── Settings ─────────────────────────────────────────────────────────────────── -START = 2020 -STOP = 2030 -INTV_YEAR = 2026 -N_SEEDS = 5 - -# Care-seeking scenarios -# CARE_SCENARIOS = { -# '10%': sc.objdict(base=0.10, anemic=1.43, pain=0.61), -# '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), -# '35%': sc.objdict(base=0.35, anemic=0.32, pain=0.14), -# } - -# anemic and pain coefficients held fixed; only base rate varies -CARE_SCENARIOS = { - '10%': sc.objdict(base=0.10, anemic=0.86, pain=0.37), - '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), - '35%': sc.objdict(base=0.35, anemic=0.86, pain=0.37), -} - -CARE_COLORS = { - '10%': '#d62728', - '20%': '#ff7f0e', - '35%': '#2196F3', -} - -# Fixed acceptance for NSAID, TXA, Pill -FIXED_ACCEPT = 0.50 - -# Sweep hIUD acceptance from 0.1 to 1.0 -HIUD_ACCEPT_VALUES = np.arange(0.1, 1.05, 0.1) - -OUTFOLDER = 'results_calibration/' -os.makedirs(OUTFOLDER, exist_ok=True) - - -def make_sim(care_behavior, hiud_accept, seed=0): - """Build sim with all treatment acceptance at 50% except hIUD which is varied.""" - mens = Menstruation() - edu = Education() - cascade = HMBCascade( - pars=dict( - year=INTV_YEAR, - time_to_assess=ss.months(3), - care_behavior=care_behavior, - nsaid=sc.objdict( - efficacy=0.5, - adherence=0.7, - prob_offer=ss.bernoulli(p=0.9), - prob_accept=ss.bernoulli(p=FIXED_ACCEPT), # was 0.7 - ), - txa=sc.objdict( - efficacy=0.6, - adherence=0.6, - prob_offer=ss.bernoulli(p=0.9), - prob_accept=ss.bernoulli(p=FIXED_ACCEPT), # was 0.6 - ), - pill=sc.objdict( - efficacy=0.7, - adherence=0.75, - prob_offer=ss.bernoulli(p=0.9), - prob_accept=ss.bernoulli(p=FIXED_ACCEPT), # was 0.5 (unchanged) - ), - hiud=sc.objdict( - efficacy=0.8, - adherence=0.85, - prob_offer=ss.bernoulli(p=0.9), - prob_accept=ss.bernoulli(p=hiud_accept), # swept - ), - ) - ) - - sim = fp.Sim( - start=START, - stop=STOP, - n_agents=10000, - total_pop=55_000_000, - location='kenya', - education_module=edu, - connectors=[mens], - interventions=[cascade], - analyzers=[track_hmb_anemia()], - rand_seed=seed, - verbose=0, - ) - return sim - - -def run_calibration(force_rerun=True): - """Sweep hIUD acceptance for each care-seeking scenario.""" - results_file = OUTFOLDER + 'hiud_calibration_accept50.obj' - - if not force_rerun and os.path.exists(results_file): - print("Loading saved calibration...") - return sc.loadobj(results_file) - - results = {} - - for care_label, care_behavior in CARE_SCENARIOS.items(): - print(f"\n{'='*60}") - print(f"Care-seeking: {care_label} | NSAID/TXA/Pill accept = {FIXED_ACCEPT}") - print(f"{'='*60}") - - results[care_label] = {} - - for accept_val in HIUD_ACCEPT_VALUES: - accept_key = f'{accept_val:.1f}' - print(f"\n hIUD accept = {accept_val:.1f}") - uptakes = [] - - for seed in range(N_SEEDS): - print(f" seed {seed}...", end=" ", flush=True) - sim = make_sim(care_behavior, hiud_accept=accept_val, seed=seed) - sim.run() - - cascade_intv = sim.interventions.hmb_cascade - menstruating = sim.people.menstruation.menstruating - hmb = sim.people.menstruation.hmb - - n_treatments = ( - np.array(cascade_intv.tried_nsaid, dtype=int) + - np.array(cascade_intv.tried_txa, dtype=int) + - np.array(cascade_intv.tried_pill, dtype=int) + - np.array(cascade_intv.tried_hiud, dtype=int) - ) - - # % of menstruating women with underlying HMB (including those on treatment) - hmb_underlying = (hmb | cascade_intv.on_any_treatment) & menstruating - hmb_menstruating = hmb_underlying - - n_hmb = np.count_nonzero(hmb_menstruating) - tried_hiud = cascade_intv.treatments['hiud'].tried_treatment & hmb_menstruating - n_hiud = np.count_nonzero(tried_hiud) - pct_hmb = 100 * n_hiud / n_hmb if n_hmb > 0 else 0 - - # % of those who tried any treatment who tried hIUD - tried_any = (n_treatments >= 1) & menstruating - n_tried_any = np.count_nonzero(tried_any) - tried_hiud_any = cascade_intv.treatments['hiud'].tried_treatment & tried_any - n_hiud_any = np.count_nonzero(tried_hiud_any) - pct_treated = 100 * n_hiud_any / n_tried_any if n_tried_any > 0 else 0 - - # % of HMB women who ever sought care (offered NSAID) - ever_offered = cascade_intv.treatments['nsaid'].offered & hmb_menstruating - n_seekers = np.count_nonzero(ever_offered) - tried_hiud_seekers = cascade_intv.treatments['hiud'].tried_treatment & ever_offered - n_hiud_seekers = np.count_nonzero(tried_hiud_seekers) - pct_seekers = 100 * n_hiud_seekers / n_seekers if n_seekers > 0 else 0 - - uptakes.append({ - 'pct_of_hmb': pct_hmb, - 'pct_of_treated': pct_treated, - 'pct_of_seekers': pct_seekers, - 'n_hmb': n_hmb, - 'n_hiud': n_hiud, - 'n_tried_any': n_tried_any, - 'n_seekers': n_seekers, - }) - print("done") - - results[care_label][accept_key] = uptakes - - sc.saveobj(results_file, results) - print(f"\nSaved: {results_file}") - return results - - -def plot_calibration(results): - """Plot calibration curves for all three care-seeking scenarios.""" - fig, axes = plt.subplots(1, 3, figsize=(20, 6)) - fig.suptitle(f'hIUD acceptance calibration (NSAID/TXA/Pill accept = {int(FIXED_ACCEPT*100)}%)', - fontsize=14) - - panels = [ - ('pct_of_hmb', '% of HMB women\nwho ever tried hIUD', axes[0]), - ('pct_of_seekers', '% of HMB care-seekers\nwho tried hIUD', axes[1]), - ('pct_of_treated', '% of treated women\nwho tried hIUD', axes[2]), - ] - - for metric_key, ylabel, ax in panels: - for care_label in CARE_SCENARIOS: - accept_vals = [] - means = [] - stds = [] - - for accept_key in sorted(results[care_label].keys(), key=float): - accept_vals.append(float(accept_key)) - vals = [u[metric_key] for u in results[care_label][accept_key]] - means.append(np.mean(vals)) - stds.append(np.std(vals)) - - ax.errorbar(accept_vals, means, yerr=stds, - marker='o', capsize=4, lw=2, - color=CARE_COLORS[care_label], - label=f'Base {care_label}') - - # Target lines - for target, color, ls in [(5, '#2196F3', ':'), (10, '#4CAF50', '--'), (15, '#F44336', '-.')]: - ax.axhline(target, ls=ls, color=color, lw=1, alpha=0.5, - label=f'Target: {target}%') - - ax.set_xlabel('hIUD acceptance probability') - ax.set_ylabel(ylabel) - ax.legend(frameon=False, fontsize=8) - ax.grid(alpha=0.3) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - - plt.tight_layout() - outpath = OUTFOLDER + 'hiud_calibration_accept50.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - - # Print lookup tables - for care_label in CARE_SCENARIOS: - print(f"\n{'─'*70}") - print(f" Care-seeking: {care_label} | NSAID/TXA/Pill accept = {int(FIXED_ACCEPT*100)}%") - print(f"{'─'*70}") - print(f" {'Accept':>8} {'% of HMB':>10} {'% of seekers':>14} {'% of treated':>14}") - print(f"{'─'*70}") - - for accept_key in sorted(results[care_label].keys(), key=float): - vals_hmb = [u['pct_of_hmb'] for u in results[care_label][accept_key]] - vals_seek = [u['pct_of_seekers'] for u in results[care_label][accept_key]] - vals_treat = [u['pct_of_treated'] for u in results[care_label][accept_key]] - print(f" {float(accept_key):>8.1f} " - f"{np.mean(vals_hmb):>9.1f}% " - f"{np.mean(vals_seek):>13.1f}% " - f"{np.mean(vals_treat):>13.1f}%") - - print(f"{'─'*70}") - - return fig - - -if __name__ == '__main__': - results = run_calibration(force_rerun=True) - plot_calibration(results) \ No newline at end of file diff --git a/interventions.py b/interventions.py index 054f6b7..8f9790e 100644 --- a/interventions.py +++ b/interventions.py @@ -422,7 +422,7 @@ def __init__(self, pars=None, eligibility=None, **kwargs): care_seeking_dist = ss.normal(1, 1), # Treatment parameters - efficacy=0.5, # 50% responder rate + efficacy=0.33, # 33% responder rate adherence=0.7, prob_offer=ss.bernoulli(p=0.9), prob_accept=ss.bernoulli(p=0.7), @@ -480,7 +480,7 @@ def __init__(self, pars=None, eligibility=None, **kwargs): care_seeking_dist = ss.normal(1, 1), # Treatment parameters - efficacy=0.6, # 60% responder rate + efficacy=0.45, # 45% responder rate adherence=0.6, prob_offer=ss.bernoulli(p=0.9), prob_accept=ss.bernoulli(p=0.6), @@ -537,7 +537,7 @@ def __init__(self, pars=None, eligibility=None, **kwargs): ), care_seeking_dist = ss.normal(1, 1), - efficacy=0.7, + efficacy=0.59, adherence=0.75, prob_offer=ss.bernoulli(p=0.9), prob_accept=ss.bernoulli(p=0.5), @@ -604,7 +604,7 @@ def __init__(self, pars=None, eligibility=None, **kwargs): ), care_seeking_dist = ss.normal(1, 1), - efficacy=0.8, + efficacy=0.88, adherence=0.85, prob_offer=ss.bernoulli(p=0.9), prob_accept=ss.bernoulli(p=0.5), @@ -691,25 +691,25 @@ def __init__(self, pars=None, **kwargs): # Treatment-specific parameters nsaid=sc.objdict( - efficacy=0.5, + efficacy=0.33, #from HIUD HMB Clinical Value Prop FINAL 110125, matching values Lauren used in her model adherence=0.7, prob_offer=ss.bernoulli(p=0.9), prob_accept=ss.bernoulli(p=0.7), ), txa=sc.objdict( - efficacy=0.6, + efficacy=0.45, # adherence=0.6, prob_offer=ss.bernoulli(p=0.9), prob_accept=ss.bernoulli(p=0.6), ), pill=sc.objdict( - efficacy=0.7, + efficacy=0.59, adherence=0.75, prob_offer=ss.bernoulli(p=0.9), prob_accept=ss.bernoulli(p=0.5), ), hiud=sc.objdict( - efficacy=0.8, + efficacy=0.88, adherence=0.85, prob_offer=ss.bernoulli(p=0.9), prob_accept=ss.bernoulli(p=0.5), diff --git a/interventions_pool.py b/interventions_pool.py new file mode 100644 index 0000000..bf60a4e --- /dev/null +++ b/interventions_pool.py @@ -0,0 +1,1549 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Apr 23 09:49:40 2026 + +@author: navidehno + +============================================================================= +FILE OVERVIEW +============================================================================= + +This single file contains everything needed for the HMB treatment model: + + 1. HMBTreatmentBase — shared logic for all four treatments + 2. NSAIDTreatment, TXATreatment, PillTreatment, hIUDTreatment + 3. HMBPool — pool-based treatment assignment + 4. HMBCounterfactual — status-quo pool (no hIUD, runs from 2020) + 5. HMBCascade — sequential cascade (optional, kept for comparison) + 6. Factory functions — convenience sim builders + +============================================================================= +CARE-SEEKING: +============================================================================= + +Layer 1 — Ever-seeker (one-time draw at HMB onset): + Bernoulli(p_ever_seek). Women who get False NEVER enter the care + system. This controls the "% who will ever seek care" (10–35%). + +Layer 2 — Monthly arrival (memoryless Bernoulli each month): + Among ever-seekers, a fresh draw each eligible month determines + whether she shows up at a clinic. Different rates for first vs + repeat visits, and condition-specific rates for anemia/pain: + + First visit: + Base: p_monthly_first (default 1/36 ≈ 3yr lag) + With pain: p_monthly_first_pain (default 1/24 ≈ 2yr lag) + With anemia: p_monthly_first_anemic (default 1/12 ≈ 1yr lag) + Both anemia+pain: takes the faster rate (anemia) + + Repeat visits: p_monthly_repeat (default 1/6 ≈ 6mo lag) + Always the same regardless of condition. + +Layer 3 — Repeat cap (max 3 care-seeking episodes): + After 3 episodes, the next re-entry attempt routes: + 90% → give up (permanently stop seeking) + 10% → hysterectomy (HMB removed, permanently stop seeking) + +============================================================================= +ADHERENCE MODEL +============================================================================= + +Adherence is a RETENTION mechanism, not a responder-rate modifier. +The chain is: + efficacy → responder draw (at treatment start, unchanged) + adherence → determines if she keeps taking it each month + continuation → for use-at-will treatments, voluntary stop based on HMB + +Behavior by treatment type: + Use-at-will (NSAID, TXA): + Non-adherent woman has 10% monthly probability of discontinuing. + She may remain non-adherent for months before actually stopping. + Continuous (Pill, hIUD): + Non-adherent → immediate cessation next timestep. + (Missing a daily pill or removing IUD ends the treatment.) + +Adherence values (from Darcy): + NSAID: 80% | TXA: 70% | Pill: 80% | hIUD: 100% + +============================================================================= +POOL vs CASCADE +============================================================================= + +Pool (primary): + At each care-seeking visit, one treatment is randomly sampled from a + weighted distribution (e.g., 50% NSAID, 25% TXA, 25% Pill). + Women who tried-and-failed a treatment cannot be re-assigned it. + Women who tried-and-it-worked CAN be re-assigned it. + +Cascade (optional, for comparison): + Sequential: NSAID → TXA → Pill → hIUD. + Each tier requires failure/refusal of the previous tier. + +============================================================================= +""" + +import numpy as np +import starsim as ss +import sciris as sc + + +# ============================================================================ +# SECTION 1: BASE TREATMENT CLASS +# ============================================================================ +# This is the foundation that all four treatments inherit from. +# It provides: +# - State definitions (seeking_care, on_treatment, responder, etc.) +# - Three-layer care-seeking logic +# - Treatment effectiveness assessment +# - Adherence checking (retention model) +# - Continuation logic (for use-at-will treatments) +# - Two-phase stepping (for use inside orchestrators) +# ============================================================================ + +class HMBTreatmentBase(ss.Intervention): + """ + Base class for all HMB treatments. + + Not used directly — subclassed by NSAIDTreatment, TXATreatment, etc. + """ + + def __init__(self, name, pars=None, eligibility=None, **kwargs): + super().__init__(name=name, eligibility=eligibility) + + # ── Starsim distributions (reused each timestep) ── + self._p_care = ss.bernoulli(p=0) # Layer 2: monthly arrival draw + self._p_accept = ss.bernoulli(p=0) + self._p_continue = ss.bernoulli(p=0) # Continuation check + self._p_ever_seek = ss.bernoulli(p=0) # Layer 1: ever-seeker draw + self._p_adherent = ss.bernoulli(p=0) # Adherence check + self._p_discontinue = ss.bernoulli(p=0) # Discontinuation for non-adherent + + # ── Default parameters (overridden by subclasses) ── + self.define_pars( + use_at_will=False, + p_discontinue_nonadherent=0.10, # 10% monthly chance of stopping when non-adherent + ) + self.update_pars(pars, **kwargs) + + return + + # ------------------------------------------------------------------ + # State definitions + # ------------------------------------------------------------------ + + def _define_common_states(self): + """ + Define states shared by all treatments. + + Called in each subclass __init__() AFTER define_pars(), so that + self.pars.efficacy is available for the responder Bernoulli. + + Note: responder rate = efficacy ONLY (adherence is separate). + """ + self.define_states( + # ── Care seeking ── + ss.BoolState('seeking_care'), # Is she seeking care THIS month? + ss.BoolState('ever_sought_care'), # Has she EVER sought care? + ss.BoolState('ever_seeker'), # Layer 1: will she ever seek care? + ss.BoolState('ever_seeker_assigned'), # Has the Layer 1 draw been made? + + # ── Treatment status ── + ss.BoolState('on_treatment'), # Currently receiving this treatment? + ss.BoolState('tried_treatment'), # Has she ever tried this treatment? + ss.BoolState('offered'), # Was this treatment ever offered? + ss.BoolState('accepted'), # Did she ever accept this treatment? + ss.FloatArr('accept_propensity', default=ss.uniform()), # Fixed U(0,1) for acceptance + + # ── Treatment response ── + # Responder rate = efficacy only. Adherence is handled separately. + ss.BoolState('responder', default=ss.bernoulli(p=self.pars.efficacy)), + ss.BoolState('treatment_effective'), # Did treatment resolve HMB? + ss.BoolState('treatment_assessed'), # Has 3-month assessment happened? + ss.BoolState('was_effective', default=False), # Ever effective (persists after stopping) + + # ── Adherence ── + ss.BoolState('adherent'), # Is she adherent this month? + ss.FloatArr('ti_nonadherent'), # When did non-adherence start? + + # ── Timing ── + ss.FloatArr('ti_start_treatment'), + ss.FloatArr('dur_treatment'), + ss.FloatArr('ti_stop_treatment'), + ) + + # Use-at-will treatments track HMB status for continuation decisions + if self.pars.use_at_will: + self.define_states( + ss.FloatArr('hmb_last_cycle', default=float('nan')), + ) + + # ------------------------------------------------------------------ + # Helper: find the orchestrator (pool, counterfactual, or cascade) + # ------------------------------------------------------------------ + + def _get_hmb_intervention(self): + """ + Return the orchestrator intervention (pool/counterfactual/cascade) + that this treatment belongs to, or None if standalone. + + This lets the treatment class reference shared states (ever_seeker, + on_any_treatment, gave_up, had_hysterectomy) without hardcoding + which orchestrator type it's inside. + """ + for name in ['hmb_pool', 'hmb_counterfactual', 'hmb_cascade']: + if hasattr(self.sim.interventions, name): + return getattr(self.sim.interventions, name) + return None + + # ------------------------------------------------------------------ + # Layer 1: Ever-seeker assignment (standalone mode only) + # ------------------------------------------------------------------ + + def _assign_ever_seekers_standalone(self): + """ + Assign ever_seeker for newly-HMB women when running standalone. + + When this treatment is inside an orchestrator (pool/cascade), + the orchestrator handles assignment centrally. This method only + runs for standalone single-treatment sims. + """ + if self._get_hmb_intervention() is not None: + return # Orchestrator handles it + + ppl = self.sim.people + new_hmb = (ppl.menstruation.hmb & + ppl.menstruation.menstruating & + ~self.ever_seeker_assigned).uids + + if len(new_hmb) == 0: + return + + self.ever_seeker_assigned[new_hmb] = True + self._p_ever_seek.set(self.pars.care_behavior.p_ever_seek) + assigned_true = self._p_ever_seek.filter(new_hmb) + self.ever_seeker[assigned_true] = True + + # ------------------------------------------------------------------ + # Layer 2: Monthly arrival (memoryless process) + # ------------------------------------------------------------------ + + def determine_care_seeking(self, uids=None): + """ + Determine who seeks care this month. + + Layer 1 (ever-seeker) is already assigned before this runs. + This method implements Layer 2: among eligible ever-seekers, + draw Bernoulli(p) where p depends on: + - First visit vs repeat visit + - Anemia/pain status (first visits only) + + First-visit rates (condition-specific): + Base: p_monthly_first (default 1/36) + Pain: p_monthly_first_pain (default 1/24) + Anemia: p_monthly_first_anemic (default 1/12) + Both: takes the fastest (anemia) + + Repeat-visit rate: + Always p_monthly_repeat (default 1/6), regardless of condition. + """ + self.seeking_care[:] = False + + if uids is None: + ppl = self.sim.people + orch = self._get_hmb_intervention() + + # Use orchestrator-level states if available + if orch is not None: + on_any = orch.on_any_treatment + is_ever_seeker = orch.ever_seeker + else: + on_any = self.on_treatment + is_ever_seeker = self.ever_seeker + + # Build eligibility mask + eligible = (is_ever_seeker & + ppl.menstruation.hmb & + ppl.menstruation.menstruating & + ~ppl.fp.pregnant & + ~ppl.fp.postpartum & + ~on_any) + + # Exclude women who gave up or had hysterectomy + if orch is not None: + eligible = eligible & ~orch.gave_up & ~orch.had_hysterectomy + + uids = eligible.uids + + if len(uids) == 0: + return + + # ── Compute per-person monthly probability ── + cb = self.pars.care_behavior + + # Start with base first-visit rate for everyone + p_monthly = np.full(len(uids), cb.p_monthly_first, dtype=float) + + # Determine who is a first-time seeker vs repeat + ever_sought = np.asarray(self.ever_sought_care[uids]) + first_time = ~ever_sought + + # For first-time seekers: apply condition-specific rates + # Pain gives a faster rate than base + pain_mask = first_time & np.asarray(self.sim.people.menstruation.pain[uids]) + p_monthly[pain_mask] = np.maximum( + p_monthly[pain_mask], cb.p_monthly_first_pain) + + # Anemia gives an even faster rate (applied second so it wins over pain) + anemic_mask = first_time & np.asarray(self.sim.people.menstruation.anemic[uids]) + p_monthly[anemic_mask] = np.maximum( + p_monthly[anemic_mask], cb.p_monthly_first_anemic) + + # Repeat seekers: flat rate regardless of condition + p_monthly[~first_time] = cb.p_monthly_repeat + + # Cap at 1.0 + np.minimum(p_monthly, 1.0, out=p_monthly) + + # ── Fresh Bernoulli draw (memoryless) ── + self._p_care.set(p_monthly) + seeks_care_uids = self._p_care.filter(uids) + + self.seeking_care[seeks_care_uids] = True + self.ever_sought_care[seeks_care_uids] = True + + return + + # ------------------------------------------------------------------ + # Treatment effectiveness assessment + # ------------------------------------------------------------------ + + def assess_treatment_effectiveness(self): + """ + Assess whether treatment resolved HMB after time_to_assess (3 months). + + For responders currently on treatment: set HMB = False. + After 3 months, check if HMB is still resolved: + - If resolved: mark treatment_effective and was_effective + - If not: stop treatment (she re-enters care-seeking pool) + """ + # Apply treatment effect for responders + on_treatment_responders = self.on_treatment & self.responder + self.sim.people.menstruation.hmb[on_treatment_responders.uids] = False + + # Find those ready to assess + on_treatment_uids = (self.on_treatment & ~self.treatment_assessed).uids + if len(on_treatment_uids) == 0: + return + + time_on = self.ti - self.ti_start_treatment[on_treatment_uids] + ready = on_treatment_uids[time_on >= self.pars.time_to_assess] + self.treatment_assessed[ready] = True + + if len(ready) == 0: + return + + hmb = self.sim.people.menstruation.hmb + has_hmb = ready & hmb + no_hmb = ready & ~hmb + + self.treatment_effective[no_hmb] = True + self.was_effective[no_hmb] = True + self.treatment_effective[has_hmb] = False + self.ti_stop_treatment[has_hmb] = self.ti + 1 # Stop next timestep + + # ------------------------------------------------------------------ + # Adherence checking (retention model) + # ------------------------------------------------------------------ + + def check_adherence(self): + """ + Check adherence for women currently on treatment. + + Each month, each woman on treatment draws Bernoulli(adherence). + If non-adherent: + Use-at-will (NSAID/TXA): + 10% monthly probability of actually discontinuing. + She may stay non-adherent for months before stopping. + Continuous (Pill/hIUD): + Immediate cessation — stops next timestep. + + This is SEPARATE from efficacy. A responder who is non-adherent + may stop a treatment that was working for her. + """ + on_treatment_uids = self.on_treatment.uids + if len(on_treatment_uids) == 0: + return + + # Draw adherence for everyone on treatment + adherence_rate = self.pars.get('adherence', 1.0) + self._p_adherent.set(adherence_rate) + is_adherent = self._p_adherent.filter(on_treatment_uids) + self.adherent[:] = False + self.adherent[is_adherent] = True + + # Get non-adherent individuals + nonadherent = on_treatment_uids - is_adherent + if len(nonadherent) == 0: + return + + if self.pars.use_at_will: + # Gradual: 10% monthly probability of discontinuing + self._p_discontinue.set(self.pars.p_discontinue_nonadherent) + discontinue = self._p_discontinue.filter(nonadherent) + if len(discontinue) > 0: + self.ti_stop_treatment[discontinue] = self.ti + 1 + else: + # Immediate: stop next timestep + self.ti_stop_treatment[nonadherent] = self.ti + 1 + + # ------------------------------------------------------------------ + # Stop treatment + # ------------------------------------------------------------------ + + def stop_treatment(self): + """Stop treatment for women whose ti_stop_treatment == current time.""" + stoppers = (self.ti_stop_treatment == self.ti).uids + if len(stoppers) == 0: + return + + self.on_treatment[stoppers] = False + self.dur_treatment[stoppers] = np.nan + self.treatment_effective[stoppers] = False + self.treatment_assessed[stoppers] = False + self.adherent[stoppers] = False + self.ti_nonadherent[stoppers] = np.nan + + # ------------------------------------------------------------------ + # Continuation (use-at-will only) + # ------------------------------------------------------------------ + + def check_continuation(self): + """ + Cycle-by-cycle voluntary continuation for use-at-will treatments. + + Each month, women on NSAID/TXA decide whether to continue: + First cycle (no history): 60% continue + HMB resolved last cycle: 90% continue + HMB persisted last cycle: 20% continue + + Not used for Pill/hIUD — those follow FPsim duration. + """ + if not self.pars.use_at_will: + return + + on_treatment_uids = self.on_treatment.uids + if len(on_treatment_uids) == 0: + return + + hmb_this = self.sim.people.menstruation.hmb[on_treatment_uids] + last = self.hmb_last_cycle[on_treatment_uids] + + p_continue = np.where( + np.isnan(last), + self.pars.p_continue_first_cycle, + np.where(last == 1.0, + self.pars.p_continue_if_resolved, + self.pars.p_continue_if_persists)) + + self._p_continue.set(p_continue) + continuers = self._p_continue.filter(on_treatment_uids) + stops = on_treatment_uids - continuers + + self.hmb_last_cycle[on_treatment_uids] = (~hmb_this).astype(float) + + if len(stops) > 0: + self.ti_stop_treatment[stops] = self.ti + 1 + self.hmb_last_cycle[stops] = np.nan + + # ------------------------------------------------------------------ + # Convenience properties + # ------------------------------------------------------------------ + + @property + def anemic(self): + return self.sim.people.menstruation.anemic + + @property + def pain(self): + return self.sim.people.menstruation.pain + + # ------------------------------------------------------------------ + # Two-phase stepping (used by orchestrators) + # ------------------------------------------------------------------ + + def step_seek(self): + """Phase 1: housekeeping + care-seeking. Called by orchestrator.""" + if self.sim.t.now() < self.pars.year: + return + self._pre_step_hook() + self.stop_treatment() + self.determine_care_seeking() + + def step_treat(self): + """Phase 2: offer/assess/adherence/continuation. Called by orchestrator.""" + if self.sim.t.now() < self.pars.year: + return + self.offer_treatment() + self.assess_treatment_effectiveness() + self.check_adherence() + self.check_continuation() + + def step(self): + """Full step for standalone use (not inside orchestrator).""" + if self.sim.t.now() < self.pars.year: + return + self._pre_step_hook() + self.stop_treatment() + self._assign_ever_seekers_standalone() + self.determine_care_seeking() + self.offer_treatment() + self.assess_treatment_effectiveness() + self.check_adherence() + self.check_continuation() + + def _pre_step_hook(self): + """Override in subclasses for pre-step setup (e.g., hIUD set_states).""" + pass + + # ------------------------------------------------------------------ + # Treatment offering (used by cascade; pool overrides this) + # ------------------------------------------------------------------ + + def offer_treatment(self): + """ + Offer this treatment to eligible care seekers. + + Used by the cascade model. The pool model bypasses this entirely + and uses its own _assign_treatments() instead. + """ + care_seekers = self.seeking_care.uids + if len(care_seekers) == 0: + return + + can_try = self._get_eligible_for_treatment(care_seekers) + + offered = self.pars.prob_offer.filter(can_try) + if len(offered) > 0: + self.offered[offered] = True + + threshold = self.pars.prob_accept.pars['p'] + accepts = self.accept_propensity[offered] < threshold + accepted = offered[accepts] + + if len(accepted) > 0: + self.accepted[accepted] = True + self.tried_treatment[accepted] = True + self._start_treatment(accepted) + + def _get_eligible_for_treatment(self, care_seekers): + """ + Eligibility: seeking care AND (never tried OR was effective) AND not on treatment. + """ + return care_seekers & (~self.tried_treatment | self.was_effective) & ~self.on_treatment + + +# ============================================================================ +# SECTION 2: INDIVIDUAL TREATMENT CLASSES +# ============================================================================ +# Each subclass sets treatment-specific defaults and implements +# _start_treatment(). Pill and hIUD also integrate with FPsim. +# ============================================================================ + +class NSAIDTreatment(HMBTreatmentBase): + """ + NSAID treatment for HMB. + + Prostaglandin inhibition. Taken during menstruation only (use-at-will). + Efficacy: 33% responder rate. Adherence: 80%. + """ + + def __init__(self, pars=None, eligibility=None, **kwargs): + super().__init__(name='nsaid_treatment', eligibility=eligibility) + + self.define_pars( + year=2020, + + # Care-seeking (defaults; orchestrator overrides these) + care_behavior=sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + + # Treatment + efficacy=0.33, + adherence=0.80, + prob_offer=ss.bernoulli(p=0.7), + prob_accept=ss.bernoulli(p=0.5), + + # Timing + time_to_assess=ss.months(3), + dur_treatment=ss.uniform(ss.months(10), ss.months(14)), + + # Use-at-will + use_at_will=True, + p_discontinue_nonadherent=0.10, + + # Continuation + p_continue_first_cycle=0.6, + p_continue_if_resolved=0.9, + p_continue_if_persists=0.2, + ) + self.update_pars(pars, **kwargs) + self._define_common_states() + + def _start_treatment(self, uids): + self.on_treatment[uids] = True + self.ti_start_treatment[uids] = self.ti + self.treatment_assessed[uids] = False + self.dur_treatment[uids] = self.pars.dur_treatment.rvs(uids) + self.ti_stop_treatment[uids] = self.ti + self.dur_treatment[uids] + + +class TXATreatment(HMBTreatmentBase): + """ + Tranexamic acid (TXA) treatment for HMB. + + Antifibrinolytic. Taken during menstruation only (use-at-will). + Efficacy: 45% responder rate. Adherence: 70% (higher pill burden). + """ + + def __init__(self, pars=None, eligibility=None, **kwargs): + super().__init__(name='txa_treatment', eligibility=eligibility) + + self.define_pars( + year=2020, + care_behavior=sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + efficacy=0.45, + adherence=0.70, + prob_offer=ss.bernoulli(p=0.7), + prob_accept=ss.bernoulli(p=0.5), + time_to_assess=ss.months(3), + dur_treatment=ss.uniform(ss.months(10), ss.months(14)), + use_at_will=True, + p_discontinue_nonadherent=0.10, + p_continue_first_cycle=0.6, + p_continue_if_resolved=0.9, + p_continue_if_persists=0.2, + ) + self.update_pars(pars, **kwargs) + self._define_common_states() + + def _start_treatment(self, uids): + self.on_treatment[uids] = True + self.ti_start_treatment[uids] = self.ti + self.treatment_assessed[uids] = False + self.dur_treatment[uids] = self.pars.dur_treatment.rvs(uids) + self.ti_stop_treatment[uids] = self.ti + self.dur_treatment[uids] + + +class PillTreatment(HMBTreatmentBase): + """ + Combined oral contraceptive pill for HMB. + + Hormonal. Continuous use (not use-at-will). Integrates with FPsim. + Efficacy: 59% responder rate. Adherence: 80%. + Requires no fertility intent (it's a contraceptive). + """ + + def __init__(self, pars=None, eligibility=None, **kwargs): + super().__init__(name='pill_treatment', eligibility=eligibility) + + self.define_pars( + year=2020, + care_behavior=sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + efficacy=0.59, + adherence=0.80, + prob_offer=ss.bernoulli(p=0.7), + prob_accept=ss.bernoulli(p=0.5), + time_to_assess=ss.months(3), + # use_at_will defaults to False (continuous) + ) + self.update_pars(pars, **kwargs) + self._define_common_states() + + @property + def pill_idx(self): + return self.sim.connectors.contraception.get_method_by_label('Pill').idx + + def _get_eligible_for_treatment(self, care_seekers): + """Pill requires no fertility intent.""" + base = super()._get_eligible_for_treatment(care_seekers) + return base & ~self.sim.people.fp.fertility_intent + + def _start_treatment(self, uids): + self.on_treatment[uids] = True + self.ti_start_treatment[uids] = self.ti + self.treatment_assessed[uids] = False + + # Set as contraceptive method via FPsim + self.sim.people.fp.method[uids] = self.pill_idx + self.sim.people.fp.on_contra[uids] = True + self.sim.people.fp.ever_used_contra[uids] = True + + # Get duration from FPsim + method_dur = self.sim.connectors.contraception.set_dur_method(uids) + self.sim.people.fp.ti_contra[uids] = self.ti + method_dur + self.dur_treatment[uids] = method_dur + self.ti_stop_treatment[uids] = self.ti + method_dur + + +class hIUDTreatment(HMBTreatmentBase): + """ + Hormonal IUD treatment for HMB. + + Long-acting hormonal. Continuous. Integrates with FPsim. + Efficacy: 88% responder rate. Adherence: 100% (until removal). + Requires no fertility intent. + """ + + def __init__(self, pars=None, eligibility=None, **kwargs): + super().__init__(name='hiud_treatment', eligibility=eligibility) + + self.define_pars( + year=2020, + care_behavior=sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + efficacy=0.88, + adherence=1.00, + prob_offer=ss.bernoulli(p=0.7), + prob_accept=ss.bernoulli(p=0.5), + time_to_assess=ss.months(3), + p_hiud=ss.bernoulli(p=0.17), + ) + self.update_pars(pars, **kwargs) + self._define_common_states() + + self.define_states( + ss.BoolState('hiud_prone', label="Prone to use hormonal IUD"), + ) + + @property + def iud_idx(self): + return self.sim.connectors.contraception.get_method_by_label('IUDs').idx + + def set_states(self): + uids = ss.uids(self.hiud_prone.isnan) + self.hiud_prone[uids] = self.pars.p_hiud.rvs(uids) + + def _pre_step_hook(self): + self.set_states() + + def _get_eligible_for_treatment(self, care_seekers): + """hIUD requires no fertility intent.""" + base = super()._get_eligible_for_treatment(care_seekers) + return base & ~self.sim.people.fp.fertility_intent + + def _start_treatment(self, uids): + self.on_treatment[uids] = True + self.ti_start_treatment[uids] = self.ti + self.treatment_assessed[uids] = False + + self.sim.people.fp.method[uids] = self.iud_idx + self.sim.people.fp.on_contra[uids] = True + self.sim.people.fp.ever_used_contra[uids] = True + + method_dur = self.sim.connectors.contraception.set_dur_method(uids) + self.sim.people.fp.ti_contra[uids] = self.ti + method_dur + self.dur_treatment[uids] = method_dur + self.ti_stop_treatment[uids] = self.ti + method_dur + + +# ============================================================================ +# SECTION 3: POOL-BASED TREATMENT ASSIGNMENT (PRIMARY MODEL) +# ============================================================================ +# When a woman seeks care, she draws from a weighted distribution to get +# one treatment. This replaces the sequential cascade as the default. +# +# Supports pre/post treatment weights for status quo → intervention shift. +# ============================================================================ + +class HMBPool(ss.Intervention): + """ + Pool-based HMB treatment assignment with three-layer care-seeking. + + At each care-seeking visit, one treatment is randomly sampled from + a weighted distribution. Failed treatments are blocked from re-draw. + + Supports two sets of weights (tx_weights_pre / tx_weights_post) to + model an immediate shift at intv_year (e.g., introducing hIUD in 2026). + """ + + def __init__(self, pars=None, **kwargs): + super().__init__(name='hmb_pool', **kwargs) + + self.define_pars( + year=2020, # When this intervention starts running + intv_year=2026, # When tx_weights shift from pre to post + time_to_assess=ss.months(3), + + # Treatment weights BEFORE intervention year (status quo) + # These are relative weights among treated women. + # "Not treated" is handled by prob_offer (70% = 30% not treated). + tx_weights_pre=sc.objdict( + nsaid=0.50, + txa=0.25, + pill=0.25, + hiud=0.0, # No hIUD in status quo + ), + + # Treatment weights AFTER intervention year + # Override this for different hIUD scenarios. + tx_weights_post=sc.objdict( + nsaid=0.50, + txa=0.25, + pill=0.25, + hiud=0.0, # Override to introduce hIUD + ), + + # Per-treatment parameters + nsaid=sc.objdict( + efficacy=0.33, + adherence=0.80, + ), + txa=sc.objdict( + efficacy=0.45, + adherence=0.70, + ), + pill=sc.objdict( + efficacy=0.59, + adherence=0.80, + ), + hiud=sc.objdict( + efficacy=0.88, + adherence=1.00, + ), + + prob_offer_pre=0.70, # Status quo: 30% not treated + prob_offer_post=0.70, # Override per scenario + + # Shared acceptance + prob_accept=1.00, + + # Care-seeking BEFORE intervention year (status quo) + care_behavior_pre=sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + + # Care-seeking AFTER intervention year (demand creation) + care_behavior_post=sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + + # Repeat cap (Layer 3) + max_care_episodes=3, + p_hysterectomy_at_cap=0.10, + ) + self.update_pars(pars, **kwargs) + + # ── Orchestrator-level states ── + self.define_states( + # Layer 1 + ss.BoolState('ever_seeker'), + ss.BoolState('ever_seeker_assigned'), + # Layer 3 + ss.FloatArr('care_episodes', default=0), + ss.BoolState('gave_up'), + ss.BoolState('had_hysterectomy'), + ) + + # Distributions + self._p_ever_seek = ss.bernoulli(p=0) + self._p_hysterectomy = ss.bernoulli(p=0) + self._p_offer = ss.bernoulli(p=0) + + self.treatments = {} + return + + # ------------------------------------------------------------------ init + + def init_pre(self, sim): + super().init_pre(sim) + + shared = dict( + year=self.pars.year, + time_to_assess=self.pars.time_to_assess, + care_behavior=self.pars.care_behavior_pre, # Start with pre-intervention + ) + accept_p = self.pars.prob_accept + + self.treatments['nsaid'] = NSAIDTreatment(pars=dict( + **shared, + efficacy=self.pars.nsaid.efficacy, + adherence=self.pars.nsaid.adherence, + prob_accept=ss.bernoulli(p=accept_p), + )) + self.treatments['txa'] = TXATreatment(pars=dict( + **shared, + efficacy=self.pars.txa.efficacy, + adherence=self.pars.txa.adherence, + prob_accept=ss.bernoulli(p=accept_p), + )) + self.treatments['pill'] = PillTreatment(pars=dict( + **shared, + efficacy=self.pars.pill.efficacy, + adherence=self.pars.pill.adherence, + prob_accept=ss.bernoulli(p=accept_p), + )) + self.treatments['hiud'] = hIUDTreatment(pars=dict( + **shared, + efficacy=self.pars.hiud.efficacy, + adherence=self.pars.hiud.adherence, + prob_accept=ss.bernoulli(p=accept_p), + )) + + for tx in self.treatments.values(): + tx.init_pre(sim) + + # Pre-compute weight arrays + self._tx_names = list(self.treatments.keys()) + self._n_tx = len(self._tx_names) + self._tx_probs_pre = self._normalize_weights(self.pars.tx_weights_pre) + self._tx_probs_post = self._normalize_weights(self.pars.tx_weights_post) + + def _normalize_weights(self, weights): + """Normalize weight dict to probability array. Zero-sum returns uniform.""" + raw = np.array([weights[k] for k in self._tx_names], dtype=float) + total = raw.sum() + if total == 0: + # All weights zero — no treatment can be offered + return np.zeros(self._n_tx) + return raw / total + + def init_post(self): + super().init_post() + for tx in self.treatments.values(): + tx.init_post() + # For episode detection (Layer 3) + self._prev_any_seeking_uids = set() + + # ------------------------------------------------------------------ step + + def step(self): + """ + One timestep of the pool model. + + Order: + 1. Pre-step hooks (e.g., hIUD set_states) + 2. Stop expired treatments + 3. Layer 1: assign ever-seekers to new HMB women + 4. Layer 2: centralised care-seeking (one determination, shared) + 5. Layer 3: enforce episode cap + 6. Random treatment assignment from pool + 7. Assess treatment effectiveness + 8. Check adherence + 9. Continuation checks (use-at-will) + """ + if self.sim.t.now() < self.pars.year: + return + + # 0. Select time-appropriate care behavior and push to child treatments + if self.sim.t.now() >= self.pars.intv_year: + active_cb = self.pars.care_behavior_post + else: + active_cb = self.pars.care_behavior_pre + + for tx in self.treatments.values(): + tx.pars.care_behavior = active_cb + + # 1. Pre-step hooks + for tx in self.treatments.values(): + tx._pre_step_hook() + + # 2. Stop expired treatments + for tx in self.treatments.values(): + tx.stop_treatment() + + # 3. Layer 1: ever-seeker assignment + self._assign_ever_seekers() + + # 4. Layer 2: centralised care-seeking + # Use NSAID's logic as the shared determination, then copy to all. + ref_tx = self.treatments['nsaid'] + ref_tx.determine_care_seeking() + seeking_uids = ref_tx.seeking_care.uids + + for name, tx in self.treatments.items(): + if name != 'nsaid': + tx.seeking_care[:] = False + tx.seeking_care[seeking_uids] = True + tx.ever_sought_care[seeking_uids] = True + + # 5. Layer 3: episode cap + self._enforce_episode_cap() + + # 6. Treatment assignment (re-read seeking after cap zeroed some out) + seeking_uids = ref_tx.seeking_care.uids + self._assign_treatments(seeking_uids) + + # 7–9. Assess, adherence, continuation + for tx in self.treatments.values(): + tx.assess_treatment_effectiveness() + for tx in self.treatments.values(): + tx.check_adherence() + for tx in self.treatments.values(): + tx.check_continuation() + + # ------------------------------------------------ Layer 1: ever-seeker + + def _assign_ever_seekers(self): + """ + One-time draw for new HMB women: will she ever seek care? + + Bernoulli(p_ever_seek). Women who get False never enter care. + Runs each step to catch new HMB cases (onset, births aging in). + """ + ppl = self.sim.people + new_hmb = (ppl.menstruation.hmb & + ppl.menstruation.menstruating & + ~self.ever_seeker_assigned).uids + + if len(new_hmb) == 0: + return + + self.ever_seeker_assigned[new_hmb] = True + + if self.sim.t.now() >= self.pars.intv_year: + p_es = self.pars.care_behavior_post.p_ever_seek + else: + p_es = self.pars.care_behavior_pre.p_ever_seek + + self._p_ever_seek.set(p_es) + + assigned_true = self._p_ever_seek.filter(new_hmb) + self.ever_seeker[assigned_true] = True + + # ------------------------------------------------ Layer 3: episode cap + + def _get_anyone_seeking_uids(self): + """Set of UIDs seeking care in any treatment.""" + uids = set() + all_uids = np.asarray(self.sim.people.uid) + for tx in self.treatments.values(): + seeking = np.asarray(tx.seeking_care) + uids.update(all_uids[seeking]) + return uids + + def _enforce_episode_cap(self): + """ + Detect new care-seeking episodes and enforce repeat cap. + + New episode = transition from not-seeking to seeking. + After max_care_episodes: + 90% give up permanently + 10% get hysterectomy (HMB removed permanently) + """ + now_seeking = self._get_anyone_seeking_uids() + + newly_seeking = np.array( + list(now_seeking - self._prev_any_seeking_uids), dtype=int + ) + self._prev_any_seeking_uids = set(now_seeking) + + if len(newly_seeking) == 0: + return + + self.care_episodes[newly_seeking] += 1 + + over_cap = newly_seeking[ + self.care_episodes[newly_seeking] > self.pars.max_care_episodes + ] + if len(over_cap) == 0: + return + + # Route: hysterectomy vs give up + self._p_hysterectomy.set(self.pars.p_hysterectomy_at_cap) + hyst_uids = self._p_hysterectomy.filter(over_cap) + giveup_uids = np.setdiff1d(over_cap, hyst_uids) + + self.had_hysterectomy[hyst_uids] = True + self.gave_up[giveup_uids] = True + + # Zero out seeking so they don't enter treatment + for tx in self.treatments.values(): + tx.seeking_care[over_cap] = False + self._prev_any_seeking_uids -= set(over_cap) + + # Hysterectomy removes HMB + self.sim.people.menstruation.hmb[hyst_uids] = False + + # ------------------------------------------------ Treatment assignment + + def _assign_treatments(self, seeking_uids): + """ + For each care-seeker not on treatment: + 1. Global prob_offer gate + 2. Among those offered, random draw from tx_weights + 3. If assigned treatment is ineligible, redraw from remaining + 4. Accept based on propensity + """ + if len(seeking_uids) == 0: + return + + on_any = self.on_any_treatment + candidates = seeking_uids[~on_any[seeking_uids]] + if len(candidates) == 0: + return + + # ── Global prob_offer gate ── + if self.sim.t.now() >= self.pars.intv_year: + offer_p = self.pars.prob_offer_post + else: + offer_p = self.pars.prob_offer_pre + + self._p_offer.set(offer_p) + offered_any = self._p_offer.filter(candidates) + if len(offered_any) == 0: + return + + # Select weight set + if self.sim.t.now() >= self.pars.intv_year: + tx_probs = self._tx_probs_post + else: + tx_probs = self._tx_probs_pre + + if tx_probs.sum() == 0: + return + + # ── Per-person eligible weights and draw ── + # Build eligibility mask per treatment per person + n = len(offered_any) + eligible_mask = np.ones((n, self._n_tx), dtype=bool) + + for tx_idx in range(self._n_tx): + tx_name = self._tx_names[tx_idx] + tx = self.treatments[tx_name] + + # Tried-and-failed → ineligible + tried = np.asarray(tx.tried_treatment[offered_any]) + was_eff = np.asarray(tx.was_effective[offered_any]) + tried_and_failed = tried & ~was_eff + eligible_mask[tried_and_failed, tx_idx] = False + + # Pill/hIUD need no fertility intent + if tx_name in ('pill', 'hiud'): + fert = np.asarray(self.sim.people.fp.fertility_intent[offered_any]) + eligible_mask[fert, tx_idx] = False + + # Build per-person probability vectors + # Start with global weights, zero out ineligible, renormalize + person_probs = np.tile(tx_probs, (n, 1)) # (n, n_tx) + person_probs[~eligible_mask] = 0.0 + + # Renormalize per person + row_sums = person_probs.sum(axis=1, keepdims=True) + has_options = (row_sums > 0).flatten() + + if not has_options.any(): + return + + # Only draw for people who have at least one eligible treatment + draw_uids = offered_any[has_options] + draw_probs = person_probs[has_options] + draw_probs = draw_probs / draw_probs.sum(axis=1, keepdims=True) + + # Draw one treatment per person + draws = np.array([ + np.random.choice(self._n_tx, p=p) for p in draw_probs + ]) + + # ── Route to treatments ── + for tx_idx in range(self._n_tx): + tx_name = self._tx_names[tx_idx] + tx = self.treatments[tx_name] + assigned_mask = draws == tx_idx + assigned_uids = draw_uids[assigned_mask] + + if len(assigned_uids) == 0: + continue + + tx.offered[assigned_uids] = True + + # Accept (propensity < threshold) + threshold = tx.pars.prob_accept.pars['p'] + accepts = tx.accept_propensity[assigned_uids] < threshold + accepted = assigned_uids[accepts] + if len(accepted) == 0: + continue + + tx.accepted[accepted] = True + tx.tried_treatment[accepted] = True + tx._start_treatment(accepted) + + # ------------------------------------------------ Properties + + @property + def on_any_treatment(self): + on_any = self.treatments['nsaid'].on_treatment.copy() + for name in ['txa', 'pill', 'hiud']: + on_any |= self.treatments[name].on_treatment + return on_any + + @property + def tried_nsaid(self): + return self.treatments['nsaid'].tried_treatment + + @property + def tried_txa(self): + return self.treatments['txa'].tried_treatment + + @property + def tried_pill(self): + return self.treatments['pill'].tried_treatment + + @property + def tried_hiud(self): + return self.treatments['hiud'].tried_treatment + + @property + def n_gave_up(self): + return np.count_nonzero(self.gave_up) + + @property + def n_hysterectomy(self): + return np.count_nonzero(self.had_hysterectomy) + + def finalize(self): + super().finalize() + for tx in self.treatments.values(): + tx.finalize() + + +# ============================================================================ +# SECTION 4: COUNTERFACTUAL (STATUS QUO, NO hIUD) +# ============================================================================ +# Runs from 2020. Same pool model but with hIUD weight = 0 always. +# This is the baseline that intervention scenarios are compared against. +# ============================================================================ + +class HMBCounterfactual(HMBPool): + """ + Status quo: pool model with no hIUD, running from start of sim. + + Uses Darcy's baseline parameters: + 10% ever-seek, 1/36 first-visit rate, 70% receipt, + 50/25/25 NSAID/TXA/Pill split, no hIUD. + """ + + def __init__(self, pars=None, **kwargs): + default_pars = dict( + year=2020, + intv_year=9999, + prob_offer_pre=0.70, + prob_offer_post=0.70, + care_behavior_pre=sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + care_behavior_post=sc.objdict( # Same as pre — never changes + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + tx_weights_pre=sc.objdict(nsaid=0.50, txa=0.25, pill=0.25, hiud=0.0), + tx_weights_post=sc.objdict(nsaid=0.50, txa=0.25, pill=0.25, hiud=0.0), + ) + merged = sc.mergedicts(default_pars, pars) + merged.setdefault('tx_weights_pre', {})['hiud'] = 0.0 + merged.setdefault('tx_weights_post', {})['hiud'] = 0.0 + super().__init__(pars=merged, **kwargs) + self.name = 'hmb_counterfactual' + + +# ============================================================================ +# SECTION 5: CASCADE MODEL (OPTIONAL, FOR COMPARISON) +# ============================================================================ +# Sequential: NSAID → TXA → Pill → hIUD. +# Each tier requires failure/refusal of the previous. +# Uses same three-layer care-seeking and adherence as the pool. +# ============================================================================ + +# class HMBCascade(ss.Intervention): +# """ +# Sequential cascade with three-layer care-seeking. + +# Treatment order: NSAID → TXA → Pill → hIUD. +# Each tier's eligibility function requires trying/refusing the previous. +# """ + +# def __init__(self, pars=None, **kwargs): +# super().__init__(name='hmb_cascade', **kwargs) + +# self.define_pars( +# year=2020, +# time_to_assess=ss.months(3), + +# nsaid=sc.objdict(efficacy=0.33, adherence=0.80, +# prob_offer=ss.bernoulli(p=0.7), +# prob_accept=ss.bernoulli(p=0.7)), +# txa=sc.objdict(efficacy=0.45, adherence=0.70, +# prob_offer=ss.bernoulli(p=0.7), +# prob_accept=ss.bernoulli(p=0.6)), +# pill=sc.objdict(efficacy=0.59, adherence=0.80, +# prob_offer=ss.bernoulli(p=0.7), +# prob_accept=ss.bernoulli(p=0.5)), +# hiud=sc.objdict(efficacy=0.88, adherence=1.00, +# prob_offer=ss.bernoulli(p=0.7), +# prob_accept=ss.bernoulli(p=0.5)), + +# care_behavior=sc.objdict( +# p_ever_seek=0.10, +# p_monthly_first=1/36, +# p_monthly_first_anemic=1/12, +# p_monthly_first_pain=1/24, +# p_monthly_repeat=1/6, +# ), + +# max_care_episodes=3, +# p_hysterectomy_at_cap=0.10, +# ) +# self.update_pars(pars, **kwargs) + +# # Orchestrator-level states (same as pool) +# self.define_states( +# ss.BoolState('ever_seeker'), +# ss.BoolState('ever_seeker_assigned'), +# ss.FloatArr('care_episodes', default=0), +# ss.BoolState('gave_up'), +# ss.BoolState('had_hysterectomy'), +# ) + +# self._p_ever_seek = ss.bernoulli(p=0) +# self._p_hysterectomy = ss.bernoulli(p=0) +# self.treatments = {} +# return + +# def init_pre(self, sim): +# super().init_pre(sim) + +# shared = dict( +# year=self.pars.year, +# time_to_assess=self.pars.time_to_assess, +# care_behavior=self.pars.care_behavior, +# ) + +# nsaid = NSAIDTreatment(pars=dict(**shared, +# efficacy=self.pars.nsaid.efficacy, +# adherence=self.pars.nsaid.adherence, +# prob_offer=self.pars.nsaid.prob_offer, +# prob_accept=self.pars.nsaid.prob_accept)) + +# def txa_elig(sim): +# return nsaid.tried_treatment | (nsaid.offered & ~nsaid.on_treatment) + +# txa = TXATreatment(pars=dict(**shared, +# efficacy=self.pars.txa.efficacy, +# adherence=self.pars.txa.adherence, +# prob_offer=self.pars.txa.prob_offer, +# prob_accept=self.pars.txa.prob_accept), +# eligibility=txa_elig) + +# def pill_elig(sim): +# return (txa.tried_treatment | +# (txa.offered & ~txa.on_treatment) | +# (nsaid.tried_treatment & ~nsaid.on_treatment & ~txa.offered)) + +# pill = PillTreatment(pars=dict(**shared, +# efficacy=self.pars.pill.efficacy, +# adherence=self.pars.pill.adherence, +# prob_offer=self.pars.pill.prob_offer, +# prob_accept=self.pars.pill.prob_accept), +# eligibility=pill_elig) + +# def hiud_elig(sim): +# return (pill.tried_treatment | +# (pill.offered & ~pill.on_treatment) | +# (nsaid.tried_treatment & ~nsaid.on_treatment & ~txa.offered) | +# (txa.tried_treatment & ~txa.on_treatment & ~pill.offered)) + +# hiud = hIUDTreatment(pars=dict(**shared, +# efficacy=self.pars.hiud.efficacy, +# adherence=self.pars.hiud.adherence, +# prob_offer=self.pars.hiud.prob_offer, +# prob_accept=self.pars.hiud.prob_accept), +# eligibility=hiud_elig) + +# self.treatments = {'nsaid': nsaid, 'txa': txa, 'pill': pill, 'hiud': hiud} +# for tx in self.treatments.values(): +# tx.init_pre(sim) + +# def init_post(self): +# super().init_post() +# for tx in self.treatments.values(): +# tx.init_post() +# self._prev_any_seeking_uids = set() + +# # Reuse the same Layer 1 and Layer 3 logic as pool +# def _assign_ever_seekers(self): +# ppl = self.sim.people +# new_hmb = (ppl.menstruation.hmb & +# ppl.menstruation.menstruating & +# ~self.ever_seeker_assigned).uids +# if len(new_hmb) == 0: +# return +# self.ever_seeker_assigned[new_hmb] = True +# self._p_ever_seek.set(self.pars.care_behavior.p_ever_seek) +# assigned_true = self._p_ever_seek.filter(new_hmb) +# self.ever_seeker[assigned_true] = True + +# def _get_anyone_seeking_uids(self): +# uids = set() +# all_uids = np.asarray(self.sim.people.uid) +# for tx in self.treatments.values(): +# seeking = np.asarray(tx.seeking_care) +# uids.update(all_uids[seeking]) +# return uids + +# def _enforce_episode_cap(self): +# now_seeking = self._get_anyone_seeking_uids() +# newly_seeking = np.array( +# list(now_seeking - self._prev_any_seeking_uids), dtype=int) +# self._prev_any_seeking_uids = set(now_seeking) + +# if len(newly_seeking) == 0: +# return +# self.care_episodes[newly_seeking] += 1 + +# over_cap = newly_seeking[ +# self.care_episodes[newly_seeking] > self.pars.max_care_episodes] +# if len(over_cap) == 0: +# return + +# self._p_hysterectomy.set(self.pars.p_hysterectomy_at_cap) +# hyst_uids = self._p_hysterectomy.filter(over_cap) +# giveup_uids = np.setdiff1d(over_cap, hyst_uids) + +# self.had_hysterectomy[hyst_uids] = True +# self.gave_up[giveup_uids] = True + +# for tx in self.treatments.values(): +# tx.seeking_care[over_cap] = False +# self._prev_any_seeking_uids -= set(over_cap) +# self.sim.people.menstruation.hmb[hyst_uids] = False + +# def step(self): +# """Cascade step: Layer 1 → seek → cap → offer/assess/adherence/continue.""" +# self._assign_ever_seekers() + +# for tx in self.treatments.values(): +# tx.step_seek() + +# self._enforce_episode_cap() + +# for tx in self.treatments.values(): +# tx.step_treat() + +# def finalize(self): +# super().finalize() +# for tx in self.treatments.values(): +# tx.finalize() + +# @property +# def on_any_treatment(self): +# on_any = self.treatments['nsaid'].on_treatment.copy() +# for name in ['txa', 'pill', 'hiud']: +# on_any |= self.treatments[name].on_treatment +# return on_any + +# @property +# def n_gave_up(self): +# return np.count_nonzero(self.gave_up) + +# @property +# def n_hysterectomy(self): +# return np.count_nonzero(self.had_hysterectomy) + + +# ============================================================================ +# SECTION 6: FACTORY FUNCTIONS +# ============================================================================ + +def make_pool_sim(seed=0, pool_pars=None, counterfactual=False): + """ + Create a simulation with the pool (or counterfactual) intervention. + + Args: + seed: Random seed + pool_pars: Dict of overrides for HMBPool/HMBCounterfactual + counterfactual: If True, use HMBCounterfactual (no hIUD) + """ + import fpsim as fp + from menstruation import Menstruation + from education import Education + from analyzers import track_hmb_anemia + + mens = Menstruation() + edu = Education() + + if counterfactual: + intervention = HMBCounterfactual(pars=pool_pars) + else: + intervention = HMBPool(pars=pool_pars) + + sim = fp.Sim( + start=2020, stop=2030, + n_agents=5000, total_pop=55_000_000, + location='kenya', + education_module=edu, + connectors=[mens], + interventions=[intervention], + analyzers=[track_hmb_anemia()], + rand_seed=seed, verbose=0, + ) + return sim + + +# def make_cascade_sim(seed=0, **cascade_kwargs): +# """Create a simulation with the cascade intervention.""" +# import fpsim as fp +# from menstruation import Menstruation +# from education import Education +# from analyzers import track_hmb_anemia + +# mens = Menstruation() +# edu = Education() +# cascade = HMBCascade(**cascade_kwargs) + +# sim = fp.Sim( +# start=2020, stop=2030, +# n_agents=5000, total_pop=55_000_000, +# location='kenya', +# education_module=edu, +# connectors=[mens], +# interventions=[cascade], +# analyzers=[track_hmb_anemia()], +# rand_seed=seed, verbose=0, +# ) +# return sim \ No newline at end of file diff --git a/menstruation.py b/menstruation.py index a15115b..8757dcf 100644 --- a/menstruation.py +++ b/menstruation.py @@ -45,7 +45,8 @@ def __init__(self, pars=None, name='menstruation', **kwargs): eff_hyst_menopause=ss.normal(-5, 1), # Adjustment for age of menopause if hysterectomy occurs # HMB prediction - p_hmb_prone=ss.bernoulli(p=0.486), # Proportion of menstruating women who experience HMB (sans interventions) + p_hmb_prone=ss.bernoulli(p=0.53), # 0.486 Proportion of menstruating women who experience HMB (sans interventions) which reflects some baseline level of treatment. + # Calibration to the baseline scenario shows that prevalence should be 53% so with the baseline treatment the p becomes 48.6% # Odds ratios to create an age curve (from UW Start) --- hmb_age_OR = { @@ -75,7 +76,8 @@ def __init__(self, pars=None, name='menstruation', **kwargs): ), pain=sc.objdict( # Parameters for menstrual pain base = 0.1, # Baseline probability of menstrual pain - hmb = 1.5, # Effect of HMB on menstrual pain - placeholder: prob of pain is 1/(1+np.exp(-(-np.log(1/0.1 -1)+1.5))) = 0.332 + hmb = 3.36, # Effect of HMB on menstrual pain - UW Start Estimated proportion of women with HMB that have dysmenorrhea: 76.2% (95% CI:59.4–89.6), + # prob of pain is 1/(1+np.exp(-(-np.log(1/0.1 -1)+3.35))) = 0.762 ), ), diff --git a/run_anemia_risk_sensitivity.py b/run_anemia_risk_sensitivity.py index 595bd28..bdd2b06 100644 --- a/run_anemia_risk_sensitivity.py +++ b/run_anemia_risk_sensitivity.py @@ -1,20 +1,21 @@ -# -*- coding: utf-8 -*- - """ -Sensitivity analysis: % reduction in anemia cases under low / mid / high RR -of anemia given HMB. +Sensitivity analysis: anemia reduction by hIUD uptake × RR of anemia given HMB. + +Structure: + 3 panels (low / mid / high hIUD uptake) + 3 RR lines per panel + counterfactual (dashed) per RR -Pooled OR: 2.17 (1.09–4.31) → RR: 1.73 (1.07–2.50) +All intervention sims: + 2020–2025: status quo (10% care-seeking, NSAID/TXA/Pill, no hIUD) + 2026–2030: mid care-seeking (20%) + hIUD at specified uptake level -Interventions compared: - - baseline : no intervention - - cascade : full HMBCascade (NSAID → TXA → Pill → hIUD) +Counterfactual sims: + 2020–2030: status quo throughout (10% care-seeking, no hIUD) -Architecture: new modular HMBCascade (v0.4.0) +RR varies from start of simulation (disease parameter). """ import numpy as np -import pandas as pd import sciris as sc import starsim as ss import fpsim as fp @@ -24,7 +25,7 @@ from menstruation import Menstruation from education import Education -from interventions import HMBCascade +from interventions_pool import HMBPool, HMBCounterfactual from analyzers import track_hmb_anemia @@ -36,13 +37,14 @@ # ── Settings ─────────────────────────────────────────────────────────────────── -P_BASE = 0.215 # baseline P(anemia | no HMB) – held fixed across scenarios -N_SEEDS = 10 # stochastic runs per scenario -START = 2020 -STOP = 2030 -INTV_YEAR = 2026 # year intervention begins - -# RR values derived from pooled OR 2.17 (95% CI: 1.09–4.31) +P_BASE = 0.215 +P_HMB_PRONE = 0.53 +N_SEEDS = 10 +START = 2020 +STOP = 2030 +INTV_YEAR = 2026 + +# ── RR scenarios ── rr_values = { 'low_rr': 1.07, 'mid_rr': 1.73, @@ -54,515 +56,515 @@ 'high_rr': 'High RR (2.50)', } rr_colors = { - 'low_rr': '#2196F3', # blue - 'mid_rr': '#4CAF50', # green - 'high_rr': '#F44336', # red + 'low_rr': '#2196F3', + 'mid_rr': '#4CAF50', + 'high_rr': '#F44336', } +# ── Care-seeking ── +CARE_PRE = sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, +) + +CARE_POST = sc.objdict( + p_ever_seek=0.20, + p_monthly_first=1/12, + p_monthly_first_anemic=1/6, + p_monthly_first_pain=1/12, + p_monthly_repeat=1/6, +) + +# ── hIUD scenarios (from Darcy's table) ── +HIUD_SCENARIOS = { + 'low_hiud': sc.objdict( + label='Low hIUD (10% of seekers)', + tx_weights_post=sc.objdict(nsaid=31.5, txa=15.75, pill=15.75, hiud=10.0), + prob_offer_post=0.73, + ), + 'mid_hiud': sc.objdict( + label='Mid hIUD (25% of seekers)', + tx_weights_post=sc.objdict(nsaid=27.5, txa=13.75, pill=13.75, hiud=25.0), + prob_offer_post=0.80, + ), + 'high_hiud': sc.objdict( + label='High hIUD (40% of seekers)', + tx_weights_post=sc.objdict(nsaid=25.0, txa=12.5, pill=12.5, hiud=40.0), + prob_offer_post=0.90, + ), +} -# ── Helpers ──────────────────────────────────────────────────────────────────── -def rr_to_logistic_coeff(rr, p_base=P_BASE): - """ - Convert a risk ratio to a logistic regression coefficient. - - We want P(anemia | HMB=True) = p_base * rr. - The logistic coefficient is the shift in log-odds needed to move - from the baseline probability to the HMB probability. +# Colors for hIUD scenarios (used in % reduction plots) +HIUD_COLORS = { + 'low_hiud': '#90CAF9', + 'mid_hiud': '#4CAF50', + 'high_hiud': '#F44336', +} - Args: - rr: Risk ratio of anemia given HMB - p_base: Baseline probability of anemia (no HMB) - Returns: - Logistic coefficient for HMB effect on anemia - """ - p_hmb = p_base * rr - # Clip to avoid log(0) or log of values >= 1 - p_hmb = np.clip(p_hmb, 1e-6, 1 - 1e-6) - coeff = (-np.log(1 / p_hmb - 1)) - (-np.log(1 / p_base - 1)) - return coeff +# ── Helpers ──────────────────────────────────────────────────────────────────── +def rr_to_logistic_coeff(rr, p_base=P_BASE): + p_hmb = np.clip(p_base * rr, 1e-6, 1 - 1e-6) + return (-np.log(1 / p_hmb - 1)) - (-np.log(1 / p_base - 1)) def make_menstruation(rr): - """ - Build a Menstruation connector with anemia risk set by the given RR. - - Only the hmb coefficient in hmb_seq.anemic changes across scenarios. - All other Menstruation parameters use defaults. - - Args: - rr: Risk ratio of anemia given HMB - - Returns: - Menstruation connector instance - """ coeff = rr_to_logistic_coeff(rr) - mens_pars = { + return Menstruation(pars={ + 'p_hmb_prone': ss.bernoulli(p=P_HMB_PRONE), 'hmb_seq': sc.objdict( - poor_mh=sc.objdict(base=0.4, hmb=1.0), - anemic =sc.objdict(base=P_BASE, hmb=coeff), # <-- varies by RR - pain =sc.objdict(base=0.1, hmb=1.5), + poor_mh=sc.objdict(base=0.4, hmb=1.0), + anemic=sc.objdict(base=P_BASE, hmb=coeff), + pain=sc.objdict(base=0.1, hmb=3.36), ) - } - return Menstruation(pars=mens_pars) + }) -def make_cascade(): - """ - Build a fresh HMBCascade instance. +def make_counterfactual(): + """Status quo throughout: 10% care, no hIUD, never shifts.""" + return HMBCounterfactual(pars=dict( + care_behavior_pre=CARE_PRE, + care_behavior_post=CARE_PRE, # Same as pre — never changes + )) - Called fresh each time to avoid shared state across simulations. - Returns: - HMBCascade intervention instance +def make_pool_intervention(hiud_scenario): """ - return HMBCascade( - pars=dict( - year=INTV_YEAR, - time_to_assess=ss.months(3), - ) - ) - - -def make_sim(rr, with_intervention=False, seed=0): + 2020–2025: status quo (10% care, no hIUD, 70% receipt) + 2026+: mid care-seeking (20%) + hIUD at specified level """ - Build a simulation with the specified anemia RR. - - Args: - rr: Risk ratio of anemia given HMB - with_intervention: Whether to include HMBCascade intervention - seed: Random seed + scen = HIUD_SCENARIOS[hiud_scenario] + + return HMBPool(pars=dict( + year=2020, + intv_year=INTV_YEAR, + care_behavior_pre=CARE_PRE, + care_behavior_post=CARE_POST, + prob_offer_post=scen.prob_offer_post, + tx_weights_pre=sc.objdict(nsaid=0.50, txa=0.25, pill=0.25, hiud=0.0), + tx_weights_post=scen.tx_weights_post, + nsaid=sc.objdict(efficacy=0.33, adherence=0.80), + txa=sc.objdict(efficacy=0.45, adherence=0.70), + pill=sc.objdict(efficacy=0.59, adherence=0.80), + hiud=sc.objdict(efficacy=0.88, adherence=1.00), + )) + + +def make_sim(rr, scenario, seed=0): + """ + Build simulation. - Returns: - fp.Sim instance ready to run + scenario: 'counterfactual' or one of HIUD_SCENARIOS keys """ mens = make_menstruation(rr) - edu = Education() - hmb_anemia_analyzer = track_hmb_anemia() - - sim_kwargs = dict( - start=START, - stop=STOP, - n_agents=10000, - total_pop=55_000_000, + edu = Education() + + if scenario == 'counterfactual': + intervention = make_counterfactual() + else: + intervention = make_pool_intervention(scenario) + + return fp.Sim( + start=START, stop=STOP, + n_agents=10000, total_pop=55_000_000, location='kenya', education_module=edu, connectors=[mens], - analyzers=[hmb_anemia_analyzer], - rand_seed=seed, - verbose=0, + interventions=[intervention], + analyzers=[track_hmb_anemia()], + rand_seed=seed, verbose=0, ) - if with_intervention: - sim_kwargs['interventions'] = [make_cascade()] - return fp.Sim(**sim_kwargs) +# ── Run ──────────────────────────────────────────────────────────────────────── +def _annualize(monthly_arr, how='sum'): + arr = np.asarray(monthly_arr) + n_years = len(arr) // 12 + arr = arr[:12 * n_years].reshape(n_years, 12) + return arr.sum(axis=1) if how == 'sum' else arr[:, -1] -# ── Run simulations ──────────────────────────────────────────────────────────── def run_sensitivity(force_rerun=True): - """ - Run baseline and intervention simulations for each RR value. - - For each RR x seed combination, runs a matched pair: - - baseline sim (no intervention) - - cascade sim (with HMBCascade) - - Averted cases are computed within each seed before aggregating, - which removes stochastic noise from the comparison. - - Returns: - raw: dict with structure - raw[rr_name]['baseline'] = list of annual anemia arrays (one per seed) - raw[rr_name]['cascade'] = list of annual anemia arrays (one per seed) - raw[rr_name]['averted'] = list of (baseline - cascade) arrays - """ - results_file = OUTFOLDER + 'anemia_sa_rr_raw.obj' + results_file = OUTFOLDER + 'anemia_sa_hiud_rr_cf_raw.obj' if not force_rerun and os.path.exists(results_file): print("Loading saved results...") return sc.loadobj(results_file) - raw = { - rr_name: { - 'baseline': [], 'cascade': [], 'averted': [], - 'baseline_monthly': [], 'cascade_monthly': [], - 'baseline_hmb_monthly': [], 'cascade_hmb_monthly': [], - } - for rr_name in rr_values - } + # Structure: raw[rr_name][scenario] where scenario is 'counterfactual' or hiud name + all_scenarios = ['counterfactual'] + list(HIUD_SCENARIOS.keys()) + raw = {} + for rr_name in rr_values: + raw[rr_name] = {} + for scen in all_scenarios: + raw[rr_name][scen] = { + 'monthly': [], 'hmb_monthly': [], 'annual': [], + } for rr_name, rr in rr_values.items(): - p_hmb = P_BASE * rr - print(f"\n=== {rr_labels[rr_name]} (RR={rr:.2f}, " - f"P(anemia|HMB)={p_hmb:.3f}) ===") + print(f"\n{'='*60}") + print(f" {rr_labels[rr_name]} (RR={rr:.2f})") + print(f"{'='*60}") for seed in range(N_SEEDS): - print(f" seed {seed}...", end=" ", flush=True) - - # Run matched pair with same seed and same RR - sims = [ - make_sim(rr, with_intervention=False, seed=seed), # baseline - make_sim(rr, with_intervention=True, seed=seed), # cascade - ] - msim = ss.MultiSim(sims) - msim.run() - - s_base = msim.sims[0] - s_cascade = msim.sims[1] - - # Extract annual total anemia from track_hmb_anemia analyzer - # n_anemia_total is monthly; sum within each year - base_monthly = s_base.results.track_hmb_anemia['n_anemia_total'] - cascade_monthly = s_cascade.results.track_hmb_anemia['n_anemia_total'] - - raw[rr_name]['baseline_monthly'].append(np.asarray(base_monthly)) - raw[rr_name]['cascade_monthly'].append(np.asarray(cascade_monthly)) - - raw[rr_name]['baseline_hmb_monthly'].append( - np.asarray(s_base.results.track_hmb_anemia['n_anemia_with_hmb'])) - raw[rr_name]['cascade_hmb_monthly'].append( - np.asarray(s_cascade.results.track_hmb_anemia['n_anemia_with_hmb'])) - - base_annual = _annualize(base_monthly) - cascade_annual = _annualize(cascade_monthly) - averted_annual = base_annual - cascade_annual - - raw[rr_name]['baseline'].append(base_annual) - raw[rr_name]['cascade'].append(cascade_annual) - raw[rr_name]['averted'].append(averted_annual) - - del msim - gc.collect() - print("done") + print(f" seed {seed}:", end="", flush=True) - sc.saveobj(results_file, raw) - print(f"\nSaved raw results: {results_file}") - return raw + for scen in all_scenarios: + print(f" {scen}...", end="", flush=True) + sim = make_sim(rr, scen, seed=seed) + sim.run() + monthly = np.asarray(sim.results.track_hmb_anemia['n_anemia_total']) + hmb_monthly = np.asarray(sim.results.track_hmb_anemia['n_anemia_with_hmb']) -def _annualize(monthly_arr, how='sum'): - """ - Convert monthly array to annual by summing (or taking end-of-year value). + raw[rr_name][scen]['monthly'].append(monthly) + raw[rr_name][scen]['hmb_monthly'].append(hmb_monthly) + raw[rr_name][scen]['annual'].append(_annualize(monthly)) - Args: - monthly_arr: Array of monthly values - how: 'sum' for annual totals, 'eoy' for end-of-year snapshot + del sim; gc.collect() + print(" done") - Returns: - Annual array - """ - arr = np.asarray(monthly_arr) - n_years = len(arr) // 12 - arr = arr[:12 * n_years].reshape(n_years, 12) - if how == 'sum': - return arr.sum(axis=1) - elif how == 'eoy': - return arr[:, -1] + sc.saveobj(results_file, raw) + print(f"\nSaved: {results_file}") + return raw -# ── Aggregate statistics ─────────────────────────────────────────────────────── -def compute_stats(raw): +# ── Statistics ───────────────────────────────────────────────────────────────── +def compute_pct_reduction(raw, rr_name, hiud_name): """ - Compute % reduction in anemia cases, aggregated across seeds. + % reduction in annual anemia: intervention vs counterfactual at the same RR. - % reduction = (baseline - cascade) / baseline * 100, computed per seed - then summarised as mean / 2.5th / 97.5th percentile. - - Args: - raw: Output from run_sensitivity() - - Returns: - stats: dict with structure - stats[rr_name] = {'mean', 'lower', 'upper'} arrays over years + Computed per-seed then aggregated, so stochastic noise cancels. """ - stats = {} - for rr_name in rr_values: - base_arr = np.array(raw[rr_name]['baseline']) # (n_seeds, n_years) - averted_arr = np.array(raw[rr_name]['averted']) # (n_seeds, n_years) - - # Compute % reduction per seed before aggregating - pct = np.where(base_arr > 0, averted_arr / base_arr * 100, np.nan) - - stats[rr_name] = { - 'mean': np.nanmean(pct, axis=0), - 'lower': np.nanpercentile(pct, 2.5, axis=0), - 'upper': np.nanpercentile(pct, 97.5, axis=0), - } - return stats - - -def compute_stats_hmb(raw): - stats = {} - for rr_name in rr_values: - base_arr = np.array([_annualize(m) for m in raw[rr_name]['baseline_hmb_monthly']]) - casc_arr = np.array([_annualize(m) for m in raw[rr_name]['cascade_hmb_monthly']]) - averted_arr = base_arr - casc_arr - pct = np.where(base_arr > 0, averted_arr / base_arr * 100, np.nan) - stats[rr_name] = { - 'mean': np.nanmean(pct, axis=0), - 'lower': np.nanpercentile(pct, 2.5, axis=0), - 'upper': np.nanpercentile(pct, 97.5, axis=0), - } - return stats + cf = np.array(raw[rr_name]['counterfactual']['annual']) + intv = np.array(raw[rr_name][hiud_name]['annual']) + averted = cf - intv + pct = np.where(cf > 0, averted / cf * 100, np.nan) + return { + 'mean': np.nanmean(pct, axis=0), + 'lower': np.nanpercentile(pct, 2.5, axis=0), + 'upper': np.nanpercentile(pct, 97.5, axis=0), + } # ── Plots ────────────────────────────────────────────────────────────────────── -def plot_annual_cases(raw, years, intervention_year=INTV_YEAR): - """ - Single panel: annual anemia cases for baseline and cascade intervention, - with RR uncertainty shown as shaded band. - Baseline: mean across seeds using mid RR (most likely estimate). - Intervention: solid line = mean of low/high RR means; - shaded band spans low-RR mean to high-RR mean. +def plot_monthly_panels(raw, years_monthly): + """ + 3 panels (one per hIUD level). Per panel: 3 RR lines (solid) + + 3 counterfactual lines (dashed, same color). """ - fig, ax = plt.subplots(figsize=(7, 4)) - - # Baseline: use mid RR as the representative estimate - base_mid = np.array(raw['mid_rr']['baseline']) # (n_seeds, n_years) - base_mean = base_mid.mean(axis=0) - base_lower = np.percentile(np.array(raw['low_rr']['baseline']), 2.5, axis=0) - base_upper = np.percentile(np.array(raw['high_rr']['baseline']), 97.5, axis=0) - - ax.plot(years, base_mean, color='#6c757d', lw=2.5, label='Baseline (mid RR)') - ax.fill_between(years, base_lower, base_upper, color='#6c757d', alpha=0.15, - label='Baseline RR uncertainty') - - # Cascade intervention: band spans low-RR mean to high-RR mean - casc_low = np.array(raw['low_rr']['cascade']).mean(axis=0) - casc_mid = np.array(raw['mid_rr']['cascade']).mean(axis=0) - casc_high = np.array(raw['high_rr']['cascade']).mean(axis=0) - - ax.plot(years, casc_mid, color='#2ca02c', lw=2.5, - label='HMB Cascade (mid RR)') - ax.fill_between(years, casc_low, casc_high, color='#2ca02c', alpha=0.20, - label='Cascade RR uncertainty') - - ax.axvline(intervention_year, color='k', ls='--', lw=1.5) - ylim = ax.get_ylim() - ax.text(intervention_year - 0.2, ylim[1] * 0.95, - 'Start of\nintervention', ha='right', va='top', - fontsize=9, color='#4d4d4d') - - ax.set_xlabel('Year') - ax.set_ylabel('Annual anemia cases') - ax.set_title('Annual anemia cases: baseline vs HMB cascade\n' - '(shaded band = RR uncertainty 1.07–2.50)') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=9) + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle('Monthly anemia cases by hIUD uptake\n' + 'Solid = intervention, dashed = counterfactual (no hIUD)', + fontsize=13) + + for idx, (hiud_name, scen) in enumerate(HIUD_SCENARIOS.items()): + ax = axes[idx] + + for rr_name in rr_values: + color = rr_colors[rr_name] + + # Counterfactual (dashed) + cf = np.array(raw[rr_name]['counterfactual']['monthly']) + cf_mean = cf.mean(axis=0) + ax.plot(years_monthly, cf_mean, color=color, ls='--', lw=1.2, alpha=0.7) + + # Intervention (solid) + intv = np.array(raw[rr_name][hiud_name]['monthly']) + intv_mean = intv.mean(axis=0) + intv_std = intv.std(axis=0) + ax.plot(years_monthly, intv_mean, color=color, lw=1.5, + label=rr_labels[rr_name]) + ax.fill_between(years_monthly, intv_mean - intv_std, + intv_mean + intv_std, color=color, alpha=0.12) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('Monthly anemia cases') + ax.set_title(scen.label) + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) + sc.SIticks(ax=ax) plt.tight_layout() - outpath = PLOTFOLDER + 'anemia_annual_cases.png' + outpath = PLOTFOLDER + 'anemia_monthly_panels.png' fig.savefig(outpath, dpi=300, bbox_inches='tight') print(f"Saved: {outpath}") - plt.show() - return fig, ax + return fig -def plot_pct_reduction(stats, years, intervention_year=INTV_YEAR): +def plot_annual_panels(raw, years): """ - Single panel: % reduction in anemia cases post-intervention, - with separate lines for low / mid / high RR. - - Shaded bands show stochastic uncertainty (2.5th–97.5th percentile - across seeds) for each RR value. + 3 panels (one per hIUD level). Per panel: 3 RR lines (solid) + + 3 counterfactual lines (dashed). """ - post_mask = years >= intervention_year - - fig, ax = plt.subplots(figsize=(7, 4)) + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle('Annual anemia cases by hIUD uptake\n' + 'Solid = intervention, dashed = counterfactual', + fontsize=13) + + for idx, (hiud_name, scen) in enumerate(HIUD_SCENARIOS.items()): + ax = axes[idx] + + for rr_name in rr_values: + color = rr_colors[rr_name] + + # Counterfactual + cf = np.array(raw[rr_name]['counterfactual']['annual']) + cf_mean = cf.mean(axis=0) + ax.plot(years, cf_mean, color=color, ls='--', lw=1.5, alpha=0.7) + + # Intervention + intv = np.array(raw[rr_name][hiud_name]['annual']) + intv_mean = intv.mean(axis=0) + intv_std = intv.std(axis=0) + ax.errorbar(years, intv_mean, yerr=intv_std, color=color, + lw=2, marker='o', capsize=3, markersize=4, + label=rr_labels[rr_name]) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('Annual anemia cases') + ax.set_title(scen.label) + ax.set_xlim([START - 0.5, STOP + 0.5]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) + sc.SIticks(ax=ax) - for rr_name in rr_values: - s = stats[rr_name] - - # Mask pre-intervention period (set to NaN so lines start at intervention year) - mean = np.where(post_mask, s['mean'], np.nan) - lower = np.where(post_mask, s['lower'], np.nan) - upper = np.where(post_mask, s['upper'], np.nan) - - ax.plot(years, mean, color=rr_colors[rr_name], lw=2.5, - label=rr_labels[rr_name]) - ax.fill_between(years, lower, upper, - color=rr_colors[rr_name], alpha=0.15) - - ax.axvline(intervention_year, color='k', ls='--', lw=1.5) - ylim = ax.get_ylim() - ax.text(intervention_year - 0.2, ylim[1] * 0.95, - 'Start of\nintervention', ha='right', va='top', - fontsize=9, color='#4d4d4d') - - ax.set_xlabel('Year') - ax.set_ylabel('% reduction in annual anemia cases') - ax.set_title('Sensitivity: % reduction in anemia cases\nby RR of anemia given HMB') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=10) + plt.tight_layout() + outpath = PLOTFOLDER + 'anemia_annual_panels.png' + fig.savefig(outpath, dpi=300, bbox_inches='tight') + print(f"Saved: {outpath}") + return fig + + +def plot_hmb_anemia_panels(raw, years_monthly): + """3 panels: monthly anemia among HMB women, with counterfactual.""" + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle('Anemia among HMB women by hIUD uptake\n' + 'Solid = intervention, dashed = counterfactual', + fontsize=13) + + for idx, (hiud_name, scen) in enumerate(HIUD_SCENARIOS.items()): + ax = axes[idx] + + for rr_name in rr_values: + color = rr_colors[rr_name] + + # Counterfactual + cf = np.array(raw[rr_name]['counterfactual']['hmb_monthly']) + cf_mean = cf.mean(axis=0) + ax.plot(years_monthly, cf_mean, color=color, ls='--', lw=1.2, alpha=0.7) + + # Intervention + intv = np.array(raw[rr_name][hiud_name]['hmb_monthly']) + intv_mean = intv.mean(axis=0) + intv_std = intv.std(axis=0) + ax.plot(years_monthly, intv_mean, color=color, lw=1.5, + label=rr_labels[rr_name]) + ax.fill_between(years_monthly, intv_mean - intv_std, + intv_mean + intv_std, color=color, alpha=0.12) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('Monthly anemia cases (HMB women)') + ax.set_title(scen.label) + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) + sc.SIticks(ax=ax) plt.tight_layout() - outpath = PLOTFOLDER + 'anemia_pct_reduction_rr.png' + outpath = PLOTFOLDER + 'anemia_hmb_monthly_panels.png' fig.savefig(outpath, dpi=300, bbox_inches='tight') print(f"Saved: {outpath}") - plt.show() - return fig, ax + return fig -def plot_monthly_cases(raw, years_monthly, intervention_year=INTV_YEAR): +def plot_pct_reduction_by_hiud(raw, years): """ - Single panel: monthly anemia cases for baseline vs cascade, - with RR as separate colored lines. + 3 panels (one per hIUD level): % reduction vs counterfactual, by RR. - Baseline shown as dashed black (mid RR), intervention lines - colored by RR. Shaded bands = mean ± std across seeds. + This is the key result plot. Shows how much anemia is averted + by introducing hIUD, and how that depends on the RR assumption. """ - fig, ax = plt.subplots(figsize=(10, 5)) + post_mask = years >= INTV_YEAR - # Baseline: mid RR as representative - base_mid = np.array(raw['mid_rr']['baseline_monthly']) - base_mean = base_mid.mean(axis=0) - base_std = base_mid.std(axis=0) + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle('% reduction in annual anemia vs counterfactual among HMB women\n' + 'By hIUD uptake level and RR of anemia given HMB', + fontsize=13) - ax.plot(years_monthly, base_mean, color='#6c757d', lw=1.5, ls='--', - label='Baseline (mid RR)') - ax.fill_between(years_monthly, base_mean - base_std, base_mean + base_std, - color='#6c757d', alpha=0.15) + for idx, (hiud_name, scen) in enumerate(HIUD_SCENARIOS.items()): + ax = axes[idx] - # Intervention: one line per RR - for rr_name in rr_values: - casc_arr = np.array(raw[rr_name]['cascade_monthly']) - casc_mean = casc_arr.mean(axis=0) - casc_std = casc_arr.std(axis=0) - - ax.plot(years_monthly, casc_mean, color=rr_colors[rr_name], lw=1.5, - label=f'Cascade ({rr_labels[rr_name]})') - ax.fill_between(years_monthly, casc_mean - casc_std, casc_mean + casc_std, - color=rr_colors[rr_name], alpha=0.15) - - ax.axvline(intervention_year, color='k', ls='--', lw=1.5) - ylim = ax.get_ylim() - ax.text(intervention_year - 0.1, ylim[1] * 0.95, 'Start of\nintervention', - ha='right', va='top', fontsize=9, color='#4d4d4d') - - ax.set_xlabel('Year') - ax.set_ylabel('Monthly anemia cases') - ax.set_title('Monthly anemia cases: baseline vs HMB cascade\n' - 'by RR of anemia given HMB') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=9) - sc.SIticks(ax=ax) + for rr_name in rr_values: + s = compute_pct_reduction(raw, rr_name, hiud_name) + s = compute_pct_reduction_hmb(raw, rr_name, hiud_name) + + mean = np.where(post_mask, s['mean'], np.nan) + lower = np.where(post_mask, s['lower'], np.nan) + upper = np.where(post_mask, s['upper'], np.nan) + + ax.plot(years, mean, color=rr_colors[rr_name], lw=2.5, + label=rr_labels[rr_name]) + ax.fill_between(years, lower, upper, + color=rr_colors[rr_name], alpha=0.15) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('% reduction vs counterfactual') + ax.set_title(scen.label) + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) plt.tight_layout() - outpath = PLOTFOLDER + 'anemia_monthly_cases.png' + outpath = PLOTFOLDER + 'anemia_pct_reduction_panels.png' fig.savefig(outpath, dpi=300, bbox_inches='tight') print(f"Saved: {outpath}") - plt.show() - return fig, ax + return fig -def plot_monthly_hmb_anemia(raw, years_monthly, intervention_year=INTV_YEAR): +def plot_pct_reduction_by_rr(raw, years): """ - Single panel: monthly anemia cases among women WITH HMB, - baseline vs cascade, by RR scenario. + 3 panels (one per RR): % reduction vs counterfactual, by hIUD level. + + Alternative view: for a given RR assumption, how much does + increasing hIUD uptake help? """ - fig, ax = plt.subplots(figsize=(10, 5)) + post_mask = years >= INTV_YEAR - # Baseline: mid RR - base_mid = np.array(raw['mid_rr']['baseline_hmb_monthly']) - base_mean = base_mid.mean(axis=0) - base_std = base_mid.std(axis=0) + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle('% reduction in annual anemia vs counterfactual among HMB women\n' + 'By RR assumption and hIUD uptake level', + fontsize=13) - ax.plot(years_monthly, base_mean, color='#6c757d', lw=1.5, ls='--', - label='Baseline (mid RR)') - ax.fill_between(years_monthly, base_mean - base_std, base_mean + base_std, - color='#6c757d', alpha=0.15) + for idx, rr_name in enumerate(rr_values): + ax = axes[idx] - for rr_name in rr_values: - casc_arr = np.array(raw[rr_name]['cascade_hmb_monthly']) - casc_mean = casc_arr.mean(axis=0) - casc_std = casc_arr.std(axis=0) - - ax.plot(years_monthly, casc_mean, color=rr_colors[rr_name], lw=1.5, - label=f'Cascade ({rr_labels[rr_name]})') - ax.fill_between(years_monthly, casc_mean - casc_std, casc_mean + casc_std, - color=rr_colors[rr_name], alpha=0.15) - - ax.axvline(intervention_year, color='k', ls='--', lw=1.5) - ylim = ax.get_ylim() - ax.text(intervention_year - 0.1, ylim[1] * 0.95, 'Start of\nintervention', - ha='right', va='top', fontsize=9, color='#4d4d4d') - - ax.set_xlabel('Year') - ax.set_ylabel('Monthly anemia cases among women with HMB') - ax.set_title('Anemia among women with HMB: baseline vs cascade\n' - 'by RR of anemia given HMB') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=9) - sc.SIticks(ax=ax) + for hiud_name, scen in HIUD_SCENARIOS.items(): + s = compute_pct_reduction(raw, rr_name, hiud_name) + s = compute_pct_reduction_hmb(raw, rr_name, hiud_name) + + mean = np.where(post_mask, s['mean'], np.nan) + lower = np.where(post_mask, s['lower'], np.nan) + upper = np.where(post_mask, s['upper'], np.nan) + + ax.plot(years, mean, color=HIUD_COLORS[hiud_name], lw=2.5, + label=scen.label) + ax.fill_between(years, lower, upper, + color=HIUD_COLORS[hiud_name], alpha=0.15) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('% reduction vs counterfactual') + ax.set_title(rr_labels[rr_name]) + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) plt.tight_layout() - outpath = PLOTFOLDER + 'anemia_monthly_hmb_cases.png' + outpath = PLOTFOLDER + 'anemia_pct_reduction_by_rr.png' fig.savefig(outpath, dpi=300, bbox_inches='tight') print(f"Saved: {outpath}") - plt.show() - return fig, ax + return fig + + +def compute_pct_reduction_hmb(raw, rr_name, hiud_name): + """% reduction in HMB-specific anemia (where the signal lives).""" + cf = np.array([_annualize(m) for m in raw[rr_name]['counterfactual']['hmb_monthly']]) + intv = np.array([_annualize(m) for m in raw[rr_name][hiud_name]['hmb_monthly']]) + averted = cf - intv + pct = np.where(cf > 0, averted / cf * 100, np.nan) + return { + 'mean': np.nanmean(pct, axis=0), + 'lower': np.nanpercentile(pct, 2.5, axis=0), + 'upper': np.nanpercentile(pct, 97.5, axis=0), + } + # ── Summary table ────────────────────────────────────────────────────────────── -def print_summary(stats, years): - """Print mean % reduction averaged over post-intervention years.""" +def print_summary(raw, years): post_mask = years >= INTV_YEAR - print(f"\n{'─'*60}") - print("Mean % reduction in anemia cases (post-intervention, " - f"{INTV_YEAR}–{STOP})") - print(f"{'─'*60}") - print(f" {'RR scenario':<18} {'Mean %':>8} {'95% CI'}") - print(f"{'─'*60}") + print(f"\n{'═'*90}") + print(f" Anemia SA: % reduction vs counterfactual (no hIUD, 10% care)") + print(f" Intervention: mid care-seeking (20%) + hIUD, post-2026") + print(f"{'═'*90}") + + # End-of-sim absolute cases + print(f"\n End-of-sim annual anemia cases (millions):") + header = f" {'':>18} {'Counterfactual':>14}" + for hiud_name, scen in HIUD_SCENARIOS.items(): + header += f" {scen.label:>28}" + print(header) + print(f" {'─'*120}") for rr_name in rr_values: - s = stats[rr_name] - m = np.nanmean(s['mean'][post_mask]) - lo = np.nanmean(s['lower'][post_mask]) - hi = np.nanmean(s['upper'][post_mask]) - print(f" {rr_labels[rr_name]:<18} {m:>7.1f}% " - f"({lo:.1f}%–{hi:.1f}%)") + cf_arr = np.array(raw[rr_name]['counterfactual']['annual']) + cf_last = cf_arr[:, -1].mean() / 1e6 + line = f" {rr_labels[rr_name]:<18} {cf_last:>10.2f}M" + + for hiud_name in HIUD_SCENARIOS: + arr = np.array(raw[rr_name][hiud_name]['annual']) + m = arr[:, -1].mean() / 1e6 + line += f" {m:>24.2f}M" + print(line) + + # % reduction vs counterfactual + print(f"\n % reduction in anemia vs counterfactual") + print(f" (averaged over post-intervention years {INTV_YEAR}–{STOP})") - print(f"{'─'*60}\n") + for hiud_name, scen in HIUD_SCENARIOS.items(): + print(f"\n {scen.label}") + print(f" {'RR scenario':<18} {'Mean %':>8} {'95% CI'}") + print(f" {'─'*50}") + + for rr_name in rr_values: + s = compute_pct_reduction(raw, rr_name, hiud_name) + s = compute_pct_reduction_hmb(raw, rr_name, hiud_name) + m = np.nanmean(s['mean'][post_mask]) + lo = np.nanmean(s['lower'][post_mask]) + hi = np.nanmean(s['upper'][post_mask]) + print(f" {rr_labels[rr_name]:<18} {m:>7.1f}% ({lo:.1f}%–{hi:.1f}%)") + + print(f"\n{'═'*90}\n") # ── Main ─────────────────────────────────────────────────────────────────────── if __name__ == '__main__': - do_run = True # Set False to load saved results - - # Run (or load) simulations + do_run = True raw = run_sensitivity(force_rerun=do_run) - # Build year axis from first result array - n_years = len(raw['mid_rr']['baseline'][0]) + # Build time axes + n_years = len(raw['mid_rr']['counterfactual']['annual'][0]) years_full = np.arange(START, START + n_years) - - n_months = len(raw['mid_rr']['baseline_monthly'][0]) - years_monthly = np.array([START + m / 12 for m in range(n_months)]) - # Compute % reduction statistics - stats = compute_stats(raw) - # stats = compute_stats_hmb(raw) + n_months = len(raw['mid_rr']['counterfactual']['monthly'][0]) + years_monthly = np.array([START + m / 12 for m in range(n_months)]) # Plots - plot_annual_cases(raw, years_full) - plot_pct_reduction(stats, years_full) - plot_monthly_cases(raw, years_monthly) - plot_monthly_hmb_anemia(raw, years_monthly) - - # Summary table - print_summary(stats, years_full) \ No newline at end of file + plot_monthly_panels(raw, years_monthly) + plot_annual_panels(raw, years_full) + plot_hmb_anemia_panels(raw, years_monthly) + plot_pct_reduction_by_hiud(raw, years_full) + plot_pct_reduction_by_rr(raw, years_full) + + # Summary + print_summary(raw, years_full) + + raw = sc.loadobj('results_anemia_sa/anemia_sa_hiud_rr_cf_raw.obj') + plot_pct_reduction_by_hiud(raw, years_full) + plot_pct_reduction_by_rr(raw, years_full) diff --git a/run_care_hiud_sensitivity.py b/run_care_hiud_sensitivity.py deleted file mode 100644 index 94bb20a..0000000 --- a/run_care_hiud_sensitivity.py +++ /dev/null @@ -1,1140 +0,0 @@ -""" -Sensitivity analysis: care-seeking probability × hIUD uptake target - -3 × 3 grid: - Care-seeking base: 10%, 20%, 35% - hIUD uptake target: ~5% (accept=0.20), ~10% (accept=0.35), ~15% (accept=0.50) NSAID/TXA/Pill acceptance fixed at 50%. - hIUD uptake target: ~5% (accept=0.15), ~10% (accept=0.30), ~15% (accept=0.45) NSAID/TXA/Pill acceptance fixed at 25%. - -Interventions compared: - - baseline : no intervention - - cascade : full HMBCascade (NSAID → TXA → Pill → hIUD) -""" - -import numpy as np -import pandas as pd -import sciris as sc -import starsim as ss -import fpsim as fp -import os -import gc -import matplotlib.pyplot as plt - -from menstruation import Menstruation -from education import Education -from interventions import HMBCascade -from analyzers import (track_care_seeking, track_tx_eff, track_tx_dur, - track_hmb_anemia, track_cascade, track_anemia_duration) - - -# ── Output folders ───────────────────────────────────────────────────────────── -PLOTFOLDER = 'figures_care_hiud_sa/' -OUTFOLDER = 'results_care_hiud_sa/' -for d in [PLOTFOLDER, OUTFOLDER]: - os.makedirs(d, exist_ok=True) - - -# ── Settings ─────────────────────────────────────────────────────────────────── -N_SEEDS = 10 -START = 2020 -STOP = 2030 -INTV_YEAR = 2026 - -# Fixed acceptance for NSAID, TXA, Pill -FIXED_ACCEPT = 0.50 - -# Care-seeking scenarios -# CARE_SCENARIOS = { -# '10%': sc.objdict(base=0.10, anemic=1.43, pain=0.61), -# '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), -# '35%': sc.objdict(base=0.35, anemic=0.32, pain=0.14), -# } - -# anemic and pain coefficients held fixed; only base rate varies -CARE_SCENARIOS = { - '10%': sc.objdict(base=0.10, anemic=0.86, pain=0.37), - '20%': sc.objdict(base=0.20, anemic=0.86, pain=0.37), - '35%': sc.objdict(base=0.35, anemic=0.86, pain=0.37), -} - -# hIUD uptake scenarios (acceptance probabilities from calibration with 50% fixed accept) -HIUD_SCENARIOS = { - '5%': 0.20, - '10%': 0.35, - '15%': 0.50, -} - -# Labels and colors -CARE_LABELS = { - '10%': 'Care 10%', - '20%': 'Care 20%', - '35%': 'Care 35%', -} -CARE_COLORS = { - '10%': '#d62728', - '20%': '#ff7f0e', - '35%': '#2196F3', -} -HIUD_LINESTYLES = { - '5%': '-', - '10%': '--', - '15%': ':', -} - - -# ── Helpers ──────────────────────────────────────────────────────────────────── -def _annualize(monthly_arr, how='sum'): - arr = np.asarray(monthly_arr) - n_years = len(arr) // 12 - arr = arr[:12 * n_years].reshape(n_years, 12) - if how == 'sum': - return arr.sum(axis=1) - elif how == 'eoy': - return arr[:, -1] - - -def scenario_key(care_label, hiud_label): - return f'{care_label}_hiud{hiud_label}' - - -# ── Simulation creation ─────────────────────────────────────────────────────── -def make_sim(care_behavior=None, hiud_accept=0.20, with_intervention=True, seed=0): - mens = Menstruation() - edu = Education() - - analyzers = [ - track_hmb_anemia(), - track_anemia_duration(), - ] - - sim_kwargs = dict( - start=START, - stop=STOP, - n_agents=10000, - total_pop=55_000_000, - location='kenya', - education_module=edu, - connectors=[mens], - analyzers=analyzers, - rand_seed=seed, - verbose=0, - ) - - if with_intervention: - cascade = HMBCascade( - pars=dict( - year=INTV_YEAR, - time_to_assess=ss.months(3), - care_behavior=care_behavior, - nsaid=sc.objdict( - efficacy=0.5, - adherence=0.7, - prob_offer=ss.bernoulli(p=0.9), - prob_accept=ss.bernoulli(p=FIXED_ACCEPT), - ), - txa=sc.objdict( - efficacy=0.6, - adherence=0.6, - prob_offer=ss.bernoulli(p=0.9), - prob_accept=ss.bernoulli(p=FIXED_ACCEPT), - ), - pill=sc.objdict( - efficacy=0.7, - adherence=0.75, - prob_offer=ss.bernoulli(p=0.9), - prob_accept=ss.bernoulli(p=FIXED_ACCEPT), - ), - hiud=sc.objdict( - efficacy=0.8, - adherence=0.85, - prob_offer=ss.bernoulli(p=0.9), - prob_accept=ss.bernoulli(p=hiud_accept), - ), - ) - ) - sim_kwargs['interventions'] = [cascade] - sim_kwargs['analyzers'].extend([ - track_care_seeking(), - track_tx_eff(), - track_tx_dur(), - track_cascade(), - ]) - - return fp.Sim(**sim_kwargs) - - -# ── Run simulations ──────────────────────────────────────────────────────────── -def run_sa(force_rerun=True): - results_file = OUTFOLDER + 'care_hiud_sa_raw.obj' - - if not force_rerun and os.path.exists(results_file): - print("Loading saved results...") - return sc.loadobj(results_file) - - raw = {} - for care_label in CARE_SCENARIOS: - for hiud_label in HIUD_SCENARIOS: - key = scenario_key(care_label, hiud_label) - raw[key] = { - 'baseline': [], 'cascade': [], 'averted': [], - 'baseline_monthly': [], 'cascade_monthly': [], - 'baseline_total_anemia_monthly': [], - 'cascade_total_anemia_monthly': [], - 'care_seeking_prev': [], - 'care_seeking_anemic': [], - 'care_seeking_not_anemic': [], - 'cascade_depth': [], - 'hiud_uptake': [], - 'care_label': care_label, - 'hiud_label': hiud_label, - 'baseline_hmb_prev': [], - 'cascade_hmb_prev': [], - 'baseline_disrupted': [], - 'cascade_disrupted': [], - 'baseline_n_disruptions': [], - 'cascade_n_disruptions': [], - } - - for care_label, care_behavior in CARE_SCENARIOS.items(): - for hiud_label, hiud_accept in HIUD_SCENARIOS.items(): - key = scenario_key(care_label, hiud_label) - print(f"\n=== {CARE_LABELS[care_label]}, hIUD target={hiud_label} " - f"(accept={hiud_accept:.2f}) ===") - - for seed in range(N_SEEDS): - print(f" seed {seed}...", end=" ", flush=True) - - sims = [ - make_sim(with_intervention=False, seed=seed), - make_sim(care_behavior=care_behavior, hiud_accept=hiud_accept, - with_intervention=True, seed=seed), - ] - msim = ss.MultiSim(sims) - msim.run() - - s_base = msim.sims[0] - s_cascade = msim.sims[1] - - # HMB-related anemia - base_monthly = s_base.results.track_hmb_anemia['n_anemia_with_hmb'] - cascade_monthly = s_cascade.results.track_hmb_anemia['n_anemia_with_hmb'] - - raw[key]['baseline_monthly'].append(np.asarray(base_monthly)) - raw[key]['cascade_monthly'].append(np.asarray(cascade_monthly)) - - # Total anemia - raw[key]['baseline_total_anemia_monthly'].append( - np.asarray(s_base.results.track_hmb_anemia['n_anemia_total'])) - raw[key]['cascade_total_anemia_monthly'].append( - np.asarray(s_cascade.results.track_hmb_anemia['n_anemia_total'])) - - # Care-seeking rates - raw[key]['care_seeking_prev'].append( - np.asarray(s_cascade.results.track_care_seeking['care_seeking_prev'])) - raw[key]['care_seeking_anemic'].append( - np.asarray(s_cascade.results.track_care_seeking['care_seeking_anemic'])) - raw[key]['care_seeking_not_anemic'].append( - np.asarray(s_cascade.results.track_care_seeking['care_seeking_not_anemic'])) - - # HMB prevalence - raw[key]['baseline_hmb_prev'].append( - np.asarray(s_base.results.menstruation['hmb_prev'])) - raw[key]['cascade_hmb_prev'].append( - np.asarray(s_cascade.results.menstruation['hmb_prev'])) - - # School disruptions - raw[key]['baseline_disrupted'].append( - np.asarray(s_base.results.edu['prop_disrupted'])) - raw[key]['cascade_disrupted'].append( - np.asarray(s_cascade.results.edu['prop_disrupted'])) - raw[key]['baseline_n_disruptions'].append( - np.asarray(s_base.results.edu['n_disruptions'])) - raw[key]['cascade_n_disruptions'].append( - np.asarray(s_cascade.results.edu['n_disruptions'])) - - # Cascade depth - cascade_intv = s_cascade.interventions.hmb_cascade - menstruating = s_cascade.people.menstruation.menstruating - n_treatments = ( - np.array(cascade_intv.tried_nsaid, dtype=int) + - np.array(cascade_intv.tried_txa, dtype=int) + - np.array(cascade_intv.tried_pill, dtype=int) + - np.array(cascade_intv.tried_hiud, dtype=int) - ) - total = np.count_nonzero(menstruating) - depth_dist = {} - for n in range(5): - count = np.count_nonzero((n_treatments == n) & menstruating) - depth_dist[n] = 100 * count / total if total > 0 else 0 - raw[key]['cascade_depth'].append(depth_dist) - - # hIUD uptake — corrected denominator includes women on treatment - hmb = s_cascade.people.menstruation.hmb - hmb_underlying = (hmb | cascade_intv.on_any_treatment) & menstruating - n_hmb = np.count_nonzero(hmb_underlying) - - ever_offered_nsaid = cascade_intv.treatments['nsaid'].offered & hmb_underlying - n_hmb_seekers = np.count_nonzero(ever_offered_nsaid) - tried_hiud_seekers = cascade_intv.treatments['hiud'].tried_treatment & ever_offered_nsaid - n_hiud_seekers = np.count_nonzero(tried_hiud_seekers) - - tried_hiud_hmb = cascade_intv.treatments['hiud'].tried_treatment & hmb_underlying - n_hiud_hmb = np.count_nonzero(tried_hiud_hmb) - - hiud_uptake = { - 'pct_of_hmb_seekers': 100 * n_hiud_seekers / n_hmb_seekers if n_hmb_seekers > 0 else 0, - 'pct_of_hmb': 100 * n_hiud_hmb / n_hmb if n_hmb > 0 else 0, - 'n_hmb_seekers': n_hmb_seekers, - 'n_hiud_among_seekers': n_hiud_seekers, - 'n_hmb': n_hmb, - 'n_hiud_hmb': n_hiud_hmb, - } - raw[key]['hiud_uptake'].append(hiud_uptake) - - # Annualize - base_annual = _annualize(base_monthly) - cascade_annual = _annualize(cascade_monthly) - averted_annual = base_annual - cascade_annual - - raw[key]['baseline'].append(base_annual) - raw[key]['cascade'].append(cascade_annual) - raw[key]['averted'].append(averted_annual) - - del msim - gc.collect() - print("done") - - sc.saveobj(results_file, raw) - print(f"\nSaved: {results_file}") - return raw - - -# ── Aggregate statistics ─────────────────────────────────────────────────────── -def compute_stats(raw): - stats = {} - for key in raw: - if key.startswith('_'): - continue - base_arr = np.array(raw[key]['baseline']) - averted_arr = np.array(raw[key]['averted']) - pct = np.where(base_arr > 0, averted_arr / base_arr * 100, np.nan) - stats[key] = { - 'mean': np.nanmean(pct, axis=0), - 'std': np.nanstd(pct, axis=0), - } - return stats - - -# ── Plots ────────────────────────────────────────────────────────────────────── -def plot_monthly_cases(raw, years_monthly): - """Monthly HMB-related anemia: one panel per hIUD scenario.""" - n_hiud = len(HIUD_SCENARIOS) - fig, axes = plt.subplots(1, n_hiud, figsize=(7 * n_hiud, 6)) - fig.suptitle('Monthly HMB-related anemia by care-seeking and hIUD uptake', fontsize=14) - - for idx, hiud_label in enumerate(HIUD_SCENARIOS): - ax = axes[idx] - - all_baselines = [] - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - all_baselines.extend(raw[key]['baseline_monthly']) - base_arr = np.array(all_baselines) - base_mean = base_arr.mean(axis=0) - base_std = base_arr.std(axis=0) - - ax.plot(years_monthly, base_mean, color='#6c757d', lw=1.5, ls='--', - label='No intervention') - ax.fill_between(years_monthly, base_mean - base_std, base_mean + base_std, - color='#6c757d', alpha=0.15) - - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - casc_arr = np.array(raw[key]['cascade_monthly']) - casc_mean = casc_arr.mean(axis=0) - casc_std = casc_arr.std(axis=0) - - ax.plot(years_monthly, casc_mean, color=CARE_COLORS[care_label], lw=1.5, - label=CARE_LABELS[care_label]) - ax.fill_between(years_monthly, casc_mean - casc_std, casc_mean + casc_std, - color=CARE_COLORS[care_label], alpha=0.15) - - ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) - ax.set_xlabel('Year') - ax.set_ylabel('Monthly HMB-related anemia cases') - ax.set_title(f'hIUD uptake target: {hiud_label}') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=9) - sc.SIticks(ax=ax) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_monthly_cases.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - - -def plot_impact_comparison(raw, years_monthly): - """Total anemia and HMB-related anemia: one panel per hIUD scenario.""" - n_hiud = len(HIUD_SCENARIOS) - fig, axes = plt.subplots(2, n_hiud, figsize=(7 * n_hiud, 12)) - fig.suptitle('Intervention impact by care-seeking and hIUD uptake', fontsize=14) - - for col, hiud_label in enumerate(HIUD_SCENARIOS): - # Row 0: Total anemia - ax = axes[0, col] - all_base = [] - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - all_base.extend(raw[key]['baseline_total_anemia_monthly']) - base_arr = np.array(all_base) - base_mean = base_arr.mean(axis=0) - base_std = base_arr.std(axis=0) - - ax.plot(years_monthly, base_mean, color='#6c757d', lw=1.5, ls='--', label='No intervention') - ax.fill_between(years_monthly, base_mean - base_std, base_mean + base_std, - color='#6c757d', alpha=0.15) - - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - casc_arr = np.array(raw[key]['cascade_total_anemia_monthly']) - casc_mean = casc_arr.mean(axis=0) - casc_std = casc_arr.std(axis=0) - ax.plot(years_monthly, casc_mean, color=CARE_COLORS[care_label], lw=1.5, - label=CARE_LABELS[care_label]) - ax.fill_between(years_monthly, casc_mean - casc_std, casc_mean + casc_std, - color=CARE_COLORS[care_label], alpha=0.15) - - ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) - ax.set_title(f'Total anemia — hIUD target: {hiud_label}') - ax.set_xlabel('Year') - ax.set_ylabel('Total anemia cases') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=8) - sc.SIticks(ax=ax) - - # Row 1: HMB-related anemia - ax = axes[1, col] - all_base_hmb = [] - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - all_base_hmb.extend(raw[key]['baseline_monthly']) - base_arr = np.array(all_base_hmb) - base_mean = base_arr.mean(axis=0) - base_std = base_arr.std(axis=0) - - ax.plot(years_monthly, base_mean, color='#6c757d', lw=1.5, ls='--', label='No intervention') - ax.fill_between(years_monthly, base_mean - base_std, base_mean + base_std, - color='#6c757d', alpha=0.15) - - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - casc_arr = np.array(raw[key]['cascade_monthly']) - casc_mean = casc_arr.mean(axis=0) - casc_std = casc_arr.std(axis=0) - ax.plot(years_monthly, casc_mean, color=CARE_COLORS[care_label], lw=1.5, - label=CARE_LABELS[care_label]) - ax.fill_between(years_monthly, casc_mean - casc_std, casc_mean + casc_std, - color=CARE_COLORS[care_label], alpha=0.15) - - ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) - ax.set_title(f'HMB-related anemia — hIUD target: {hiud_label}') - ax.set_xlabel('Year') - ax.set_ylabel('HMB-related anemia cases') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=8) - sc.SIticks(ax=ax) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_impact_comparison.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - - -def plot_pct_reduction_timeseries(stats, years): - """% reduction over time: all 9 scenarios on one panel.""" - post_mask = years >= INTV_YEAR - fig, ax = plt.subplots(figsize=(12, 7)) - - for care_label in CARE_SCENARIOS: - for hiud_label in HIUD_SCENARIOS: - key = scenario_key(care_label, hiud_label) - s = stats[key] - - mean = np.where(post_mask, s['mean'], np.nan) - upper = np.where(post_mask, s['mean'] + s['std'], np.nan) - lower = np.where(post_mask, s['mean'] - s['std'], np.nan) - - ax.plot(years, mean, color=CARE_COLORS[care_label], - ls=HIUD_LINESTYLES[hiud_label], lw=2.5, - label=f'{CARE_LABELS[care_label]}, hIUD {hiud_label}') - ax.fill_between(years, lower, upper, - color=CARE_COLORS[care_label], alpha=0.06) - - ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) - ylim = ax.get_ylim() - ax.text(INTV_YEAR - 0.2, ylim[1] * 0.95, 'Start of\nintervention', - ha='right', va='top', fontsize=9, color='#4d4d4d') - - ax.set_xlabel('Year') - ax.set_ylabel('% reduction in HMB-related anemia cases') - ax.set_title('% reduction in HMB-related anemia\nby care-seeking and hIUD uptake target') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=8, ncol=3) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_pct_reduction_timeseries.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - - -def plot_pct_reduction_barchart(raw, stats, years): - """Grouped bar chart: % reduction and absolute averted.""" - post_mask = years >= INTV_YEAR - - fig, axes = plt.subplots(1, 2, figsize=(18, 7)) - fig.suptitle('Anemia reduction vs. no-intervention baseline\n' - '(post-intervention average)', fontsize=14) - - care_list = list(CARE_SCENARIOS.keys()) - hiud_list = list(HIUD_SCENARIOS.keys()) - n_care = len(care_list) - n_hiud = len(hiud_list) - - x = np.arange(n_care) - total_width = 0.75 - bar_width = total_width / n_hiud - - hiud_alphas = {'5%': 0.9, '10%': 0.65, '15%': 0.4} - - # Panel 1: % reduction - ax = axes[0] - for i, hiud_label in enumerate(hiud_list): - means = [] - stds = [] - for care_label in care_list: - key = scenario_key(care_label, hiud_label) - s = stats[key] - means.append(np.nanmean(s['mean'][post_mask])) - stds.append(np.nanmean(s['std'][post_mask])) - - offset = (i - (n_hiud - 1) / 2) * bar_width - bars = ax.bar(x + offset, means, bar_width, yerr=stds, capsize=4, - label=f'hIUD {hiud_label}', - alpha=hiud_alphas[hiud_label], - edgecolor='black', linewidth=0.5, - color=[CARE_COLORS[c] for c in care_list]) - - for bar, val, err in zip(bars, means, stds): - ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() + err + 0.3, - f'{val:.1f}%', ha='center', va='bottom', - fontsize=8, fontweight='bold') - - ax.set_xlabel('Care-seeking scenario') - ax.set_ylabel('% reduction in HMB-related anemia') - ax.set_title('Percentage reduction') - ax.set_xticks(x) - ax.set_xticklabels([CARE_LABELS[c] for c in care_list]) - ax.legend(frameon=False, fontsize=10) - ax.grid(axis='y', alpha=0.3) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - - # Panel 2: Absolute averted - ax = axes[1] - for i, hiud_label in enumerate(hiud_list): - means = [] - stds = [] - for care_label in care_list: - key = scenario_key(care_label, hiud_label) - averted_arr = np.array(raw[key]['averted']) - post_averted = averted_arr[:, post_mask] - means.append(post_averted.mean()) - stds.append(post_averted.std(axis=0).mean()) - - offset = (i - (n_hiud - 1) / 2) * bar_width - bars = ax.bar(x + offset, [v / 1e6 for v in means], bar_width, - yerr=[v / 1e6 for v in stds], capsize=4, - label=f'hIUD {hiud_label}', - alpha=hiud_alphas[hiud_label], - edgecolor='black', linewidth=0.5, - color=[CARE_COLORS[c] for c in care_list]) - - for bar, val, err in zip(bars, means, stds): - ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() + err / 1e6 + 0.01, - f'{val/1e6:.2f}m', ha='center', va='bottom', - fontsize=8, fontweight='bold') - - ax.set_xlabel('Care-seeking scenario') - ax.set_ylabel('Annual averted anemia cases (millions)') - ax.set_title('Absolute reduction') - ax.set_xticks(x) - ax.set_xticklabels([CARE_LABELS[c] for c in care_list]) - ax.legend(frameon=False, fontsize=10) - ax.grid(axis='y', alpha=0.3) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_pct_reduction_barchart.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - - -def plot_hiud_uptake(raw): - """hIUD uptake bar chart: grouped by care-seeking, colored by hIUD scenario.""" - fig, axes = plt.subplots(1, 2, figsize=(18, 6)) - fig.suptitle('hIUD uptake among women with HMB (underlying)', fontsize=14) - - care_list = list(CARE_SCENARIOS.keys()) - hiud_list = list(HIUD_SCENARIOS.keys()) - n_care = len(care_list) - n_hiud = len(hiud_list) - - x = np.arange(n_care) - total_width = 0.75 - bar_width = total_width / n_hiud - - hiud_bar_colors = {'5%': '#4CAF50', '10%': '#FF9800', '15%': '#F44336'} - - # Panel 1: % of HMB seekers - ax = axes[0] - for i, hiud_label in enumerate(hiud_list): - means = [] - stds = [] - for care_label in care_list: - key = scenario_key(care_label, hiud_label) - vals = [u['pct_of_hmb_seekers'] for u in raw[key]['hiud_uptake']] - means.append(np.mean(vals)) - stds.append(np.std(vals)) - - offset = (i - (n_hiud - 1) / 2) * bar_width - bars = ax.bar(x + offset, means, bar_width, yerr=stds, capsize=4, - label=f'hIUD target {hiud_label}', - color=hiud_bar_colors[hiud_label], alpha=0.7, - edgecolor='black', linewidth=0.5) - - for bar, val, err in zip(bars, means, stds): - ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() + err + 0.3, - f'{val:.1f}%', ha='center', va='bottom', - fontsize=8, fontweight='bold') - - ax.set_xlabel('Care-seeking scenario') - ax.set_ylabel('% initiating hIUD') - ax.set_title('% of HMB women seeking care\nwho initiate hIUD') - ax.set_xticks(x) - ax.set_xticklabels([CARE_LABELS[c] for c in care_list]) - ax.legend(frameon=False, fontsize=10) - ax.grid(axis='y', alpha=0.3) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - - # Panel 2: % of all HMB women (underlying) - ax = axes[1] - for i, hiud_label in enumerate(hiud_list): - means = [] - stds = [] - for care_label in care_list: - key = scenario_key(care_label, hiud_label) - vals = [u['pct_of_hmb'] for u in raw[key]['hiud_uptake']] - means.append(np.mean(vals)) - stds.append(np.std(vals)) - - offset = (i - (n_hiud - 1) / 2) * bar_width - bars = ax.bar(x + offset, means, bar_width, yerr=stds, capsize=4, - label=f'hIUD target {hiud_label}', - color=hiud_bar_colors[hiud_label], alpha=0.7, - edgecolor='black', linewidth=0.5) - - for bar, val, err in zip(bars, means, stds): - ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() + err + 0.3, - f'{val:.1f}%', ha='center', va='bottom', - fontsize=8, fontweight='bold') - - ax.set_xlabel('Care-seeking scenario') - ax.set_ylabel('% using hIUD') - ax.set_title('% of women with underlying HMB\nwho ever used hIUD') - ax.set_xticks(x) - ax.set_xticklabels([CARE_LABELS[c] for c in care_list]) - ax.legend(frameon=False, fontsize=10) - ax.grid(axis='y', alpha=0.3) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_hiud_uptake.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - - -def plot_cascade_comparison(raw): - """Cascade depth: one panel per hIUD scenario.""" - n_hiud = len(HIUD_SCENARIOS) - fig, axes = plt.subplots(1, n_hiud, figsize=(7 * n_hiud, 6)) - fig.suptitle('Cascade progression by care-seeking and hIUD uptake', fontsize=14) - - care_list = list(CARE_SCENARIOS.keys()) - n_care = len(care_list) - treatments_tried = np.arange(5) - total_width = 0.7 - bar_width = total_width / n_care - - for col, hiud_label in enumerate(HIUD_SCENARIOS): - ax = axes[col] - - for i, care_label in enumerate(care_list): - key = scenario_key(care_label, hiud_label) - depth_arrays = {n: [] for n in range(5)} - for depth_dist in raw[key]['cascade_depth']: - for n in range(5): - depth_arrays[n].append(depth_dist[n]) - - means = [np.mean(depth_arrays[n]) for n in range(5)] - stds = [np.std(depth_arrays[n]) for n in range(5)] - - offset = (i - (n_care - 1) / 2) * bar_width - ax.bar(treatments_tried + offset, means, bar_width, - yerr=stds, capsize=3, - label=CARE_LABELS[care_label], - color=CARE_COLORS[care_label], alpha=0.7, - edgecolor='black', linewidth=0.5) - - ax.set_xlabel('Number of treatments tried') - ax.set_ylabel('Percentage of menstruating women (%)') - ax.set_title(f'hIUD uptake target: {hiud_label}') - ax.set_xticks(treatments_tried) - ax.set_xticklabels(['0', '1', '2', '3', '4']) - ax.legend(frameon=False, fontsize=10) - ax.grid(axis='y', alpha=0.3) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_cascade_depth.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - -def plot_hmb_prevalence(raw, years_monthly): - """HMB prevalence over time: one panel per hIUD scenario.""" - n_hiud = len(HIUD_SCENARIOS) - fig, axes = plt.subplots(1, n_hiud, figsize=(7 * n_hiud, 6)) - fig.suptitle('HMB prevalence among menstruating women', fontsize=14) - - for idx, hiud_label in enumerate(HIUD_SCENARIOS): - ax = axes[idx] - - # Baseline - all_base = [] - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - all_base.extend(raw[key]['baseline_hmb_prev']) - base_arr = np.array(all_base) * 100 - base_mean = base_arr.mean(axis=0) - base_std = base_arr.std(axis=0) - - ax.plot(years_monthly, base_mean, color='#6c757d', lw=1.5, ls='--', - label='No intervention') - ax.fill_between(years_monthly, base_mean - base_std, base_mean + base_std, - color='#6c757d', alpha=0.15) - - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - casc_arr = np.array(raw[key]['cascade_hmb_prev']) * 100 - casc_mean = casc_arr.mean(axis=0) - casc_std = casc_arr.std(axis=0) - - ax.plot(years_monthly, casc_mean, color=CARE_COLORS[care_label], lw=1.5, - label=CARE_LABELS[care_label]) - ax.fill_between(years_monthly, casc_mean - casc_std, casc_mean + casc_std, - color=CARE_COLORS[care_label], alpha=0.15) - - ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) - ax.set_xlabel('Year') - ax.set_ylabel('HMB prevalence (%)') - ax.set_title(f'hIUD uptake target: {hiud_label}') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=9) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_hmb_prevalence.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - - -def plot_disruption_rate(raw, years_monthly): - """Monthly school disruption rate: one panel per hIUD scenario.""" - n_hiud = len(HIUD_SCENARIOS) - fig, axes = plt.subplots(1, n_hiud, figsize=(7 * n_hiud, 6)) - fig.suptitle('Monthly school disruption rate among in-school menstruating AGYW', fontsize=14) - - for idx, hiud_label in enumerate(HIUD_SCENARIOS): - ax = axes[idx] - - # Baseline - all_base = [] - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - all_base.extend(raw[key]['baseline_disrupted']) - base_arr = np.array(all_base) * 100 - base_mean = base_arr.mean(axis=0) - base_std = base_arr.std(axis=0) - - ax.plot(years_monthly, base_mean, color='#6c757d', lw=1.5, ls='--', - label='No intervention') - ax.fill_between(years_monthly, base_mean - base_std, base_mean + base_std, - color='#6c757d', alpha=0.15) - - for care_label in CARE_SCENARIOS: - key = scenario_key(care_label, hiud_label) - casc_arr = np.array(raw[key]['cascade_disrupted']) * 100 - casc_mean = casc_arr.mean(axis=0) - casc_std = casc_arr.std(axis=0) - - ax.plot(years_monthly, casc_mean, color=CARE_COLORS[care_label], lw=1.5, - label=CARE_LABELS[care_label]) - ax.fill_between(years_monthly, casc_mean - casc_std, casc_mean + casc_std, - color=CARE_COLORS[care_label], alpha=0.15) - - ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) - ax.set_xlabel('Year') - ax.set_ylabel('Disruption rate (%)') - ax.set_title(f'hIUD uptake target: {hiud_label}') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=9) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_disruption_rate.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - - -def compute_disruption_stats(raw): - """ - Compute % reduction in annual disruptions, per seed then aggregated. - - % reduction = (baseline_annual - cascade_annual) / baseline_annual * 100 - """ - stats = {} - for care_label in CARE_SCENARIOS: - for hiud_label in HIUD_SCENARIOS: - key = scenario_key(care_label, hiud_label) - - base_annual = np.array([_annualize(m) for m in raw[key]['baseline_n_disruptions']]) - casc_annual = np.array([_annualize(m) for m in raw[key]['cascade_n_disruptions']]) - averted = base_annual - casc_annual - - pct = np.where(base_annual > 0, averted / base_annual * 100, np.nan) - - stats[key] = { - 'mean': np.nanmean(pct, axis=0), - 'std': np.nanstd(pct, axis=0), - } - return stats - - -def plot_disruption_reduction_timeseries(disruption_stats, years): - """% reduction in annual disruptions over time: all 9 scenarios.""" - post_mask = years >= INTV_YEAR - fig, ax = plt.subplots(figsize=(12, 7)) - - for care_label in CARE_SCENARIOS: - for hiud_label in HIUD_SCENARIOS: - key = scenario_key(care_label, hiud_label) - s = disruption_stats[key] - - mean = np.where(post_mask, s['mean'], np.nan) - upper = np.where(post_mask, s['mean'] + s['std'], np.nan) - lower = np.where(post_mask, s['mean'] - s['std'], np.nan) - - ax.plot(years, mean, color=CARE_COLORS[care_label], - ls=HIUD_LINESTYLES[hiud_label], lw=2.5, - label=f'{CARE_LABELS[care_label]}, hIUD {hiud_label}') - ax.fill_between(years, lower, upper, - color=CARE_COLORS[care_label], alpha=0.06) - - ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) - ylim = ax.get_ylim() - ax.text(INTV_YEAR - 0.2, ylim[1] * 0.95, 'Start of\nintervention', - ha='right', va='top', fontsize=9, color='#4d4d4d') - - ax.set_xlabel('Year') - ax.set_ylabel('% reduction in annual school disruptions') - ax.set_title('% reduction in school disruptions\nby care-seeking and hIUD uptake target') - ax.set_xlim([START, STOP]) - ax.set_ylim(bottom=0) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=8, ncol=3) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_disruption_reduction_timeseries.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - - -def plot_disruption_reduction_barchart(raw, disruption_stats, years): - """Bar chart: % reduction in annual disruptions, post-intervention average.""" - post_mask = years >= INTV_YEAR - - fig, axes = plt.subplots(1, 2, figsize=(18, 7)) - fig.suptitle('School disruption reduction vs. no-intervention baseline\n' - '(post-intervention average)', fontsize=14) - - care_list = list(CARE_SCENARIOS.keys()) - hiud_list = list(HIUD_SCENARIOS.keys()) - n_care = len(care_list) - n_hiud = len(hiud_list) - - x = np.arange(n_care) - total_width = 0.75 - bar_width = total_width / n_hiud - - hiud_alphas = {'5%': 0.9, '10%': 0.65, '15%': 0.4} - - # Panel 1: % reduction - ax = axes[0] - for i, hiud_label in enumerate(hiud_list): - means = [] - stds = [] - for care_label in care_list: - key = scenario_key(care_label, hiud_label) - s = disruption_stats[key] - means.append(np.nanmean(s['mean'][post_mask])) - stds.append(np.nanmean(s['std'][post_mask])) - - offset = (i - (n_hiud - 1) / 2) * bar_width - bars = ax.bar(x + offset, means, bar_width, yerr=stds, capsize=4, - label=f'hIUD {hiud_label}', - alpha=hiud_alphas[hiud_label], - edgecolor='black', linewidth=0.5, - color=[CARE_COLORS[c] for c in care_list]) - - for bar, val, err in zip(bars, means, stds): - ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() + err + 0.3, - f'{val:.1f}%', ha='center', va='bottom', - fontsize=8, fontweight='bold') - - ax.set_xlabel('Care-seeking scenario') - ax.set_ylabel('% reduction in annual school disruptions') - ax.set_title('Percentage reduction') - ax.set_xticks(x) - ax.set_xticklabels([CARE_LABELS[c] for c in care_list]) - ax.legend(frameon=False, fontsize=10) - ax.grid(axis='y', alpha=0.3) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - - # Panel 2: Absolute averted disruptions - ax = axes[1] - for i, hiud_label in enumerate(hiud_list): - means = [] - stds = [] - for care_label in care_list: - key = scenario_key(care_label, hiud_label) - base_annual = np.array([_annualize(m) for m in raw[key]['baseline_n_disruptions']]) - casc_annual = np.array([_annualize(m) for m in raw[key]['cascade_n_disruptions']]) - averted = base_annual - casc_annual - post_averted = averted[:, post_mask] - means.append(post_averted.mean()) - stds.append(post_averted.std(axis=0).mean()) - - offset = (i - (n_hiud - 1) / 2) * bar_width - bars = ax.bar(x + offset, [v / 1e6 for v in means], bar_width, - yerr=[v / 1e6 for v in stds], capsize=4, - label=f'hIUD {hiud_label}', - alpha=hiud_alphas[hiud_label], - edgecolor='black', linewidth=0.5, - color=[CARE_COLORS[c] for c in care_list]) - - for bar, val, err in zip(bars, means, stds): - ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() + err / 1e6 + 0.01, - f'{val/1e6:.2f}m', ha='center', va='bottom', - fontsize=8, fontweight='bold') - - ax.set_xlabel('Care-seeking scenario') - ax.set_ylabel('Annual averted disruptions (millions)') - ax.set_title('Absolute reduction') - ax.set_xticks(x) - ax.set_xticklabels([CARE_LABELS[c] for c in care_list]) - ax.legend(frameon=False, fontsize=10) - ax.grid(axis='y', alpha=0.3) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - - plt.tight_layout() - outpath = PLOTFOLDER + 'sa_disruption_reduction_barchart.png' - fig.savefig(outpath, dpi=300, bbox_inches='tight') - print(f"Saved: {outpath}") - return fig - -# ── Summary tables ───────────────────────────────────────────────────────────── -def print_summary(raw, stats, years): - post_mask = years >= INTV_YEAR - - print(f"\n{'═'*90}") - print(f" SENSITIVITY ANALYSIS: Care-seeking × hIUD uptake (3×3)") - print(f" NSAID/TXA/Pill acceptance = {int(FIXED_ACCEPT*100)}%") - print(f" Post-intervention average ({INTV_YEAR}–{STOP})") - print(f"{'═'*90}") - - # % reduction table - print(f"\n % Reduction in HMB-related anemia:") - print(f" {'':>12}", end="") - for hiud_label in HIUD_SCENARIOS: - print(f" {'hIUD ' + hiud_label:>18}", end="") - print() - print(f" {'─'*66}") - - for care_label in CARE_SCENARIOS: - print(f" {CARE_LABELS[care_label]:>12}", end="") - for hiud_label in HIUD_SCENARIOS: - key = scenario_key(care_label, hiud_label) - s = stats[key] - m = np.nanmean(s['mean'][post_mask]) - std = np.nanmean(s['std'][post_mask]) - print(f" {m:>9.1f}% ± {std:.1f}%", end="") - print() - - # hIUD uptake table - print(f"\n % of women with underlying HMB who ever used hIUD:") - print(f" {'':>12}", end="") - for hiud_label in HIUD_SCENARIOS: - print(f" {'hIUD ' + hiud_label:>18}", end="") - print() - print(f" {'─'*66}") - - for care_label in CARE_SCENARIOS: - print(f" {CARE_LABELS[care_label]:>12}", end="") - for hiud_label in HIUD_SCENARIOS: - key = scenario_key(care_label, hiud_label) - vals = [u['pct_of_hmb'] for u in raw[key]['hiud_uptake']] - print(f" {np.mean(vals):>9.1f}% ± {np.std(vals):.1f}%", end="") - print() - - # hIUD uptake among seekers - print(f"\n % of HMB care-seekers who initiated hIUD:") - print(f" {'':>12}", end="") - for hiud_label in HIUD_SCENARIOS: - print(f" {'hIUD ' + hiud_label:>18}", end="") - print() - print(f" {'─'*66}") - - for care_label in CARE_SCENARIOS: - print(f" {CARE_LABELS[care_label]:>12}", end="") - for hiud_label in HIUD_SCENARIOS: - key = scenario_key(care_label, hiud_label) - vals = [u['pct_of_hmb_seekers'] for u in raw[key]['hiud_uptake']] - print(f" {np.mean(vals):>9.1f}% ± {np.std(vals):.1f}%", end="") - print() - - # Absolute averted table - print(f"\n Annual averted HMB-related anemia cases (millions):") - print(f" {'':>12}", end="") - for hiud_label in HIUD_SCENARIOS: - print(f" {'hIUD ' + hiud_label:>18}", end="") - print() - print(f" {'─'*66}") - - for care_label in CARE_SCENARIOS: - print(f" {CARE_LABELS[care_label]:>12}", end="") - for hiud_label in HIUD_SCENARIOS: - key = scenario_key(care_label, hiud_label) - averted_arr = np.array(raw[key]['averted']) - post_averted = averted_arr[:, post_mask] - m = post_averted.mean() / 1e6 - std = post_averted.std(axis=0).mean() / 1e6 - print(f" {m:>9.2f}m ± {std:.2f}m", end="") - print() - - print(f"{'═'*90}\n") - - -def print_disruption_summary(disruption_stats, years): - """Print disruption reduction summary table.""" - post_mask = years >= INTV_YEAR - - print(f"\n % Reduction in annual school disruptions:") - print(f" {'':>12}", end="") - for hiud_label in HIUD_SCENARIOS: - print(f" {'hIUD ' + hiud_label:>18}", end="") - print() - print(f" {'─'*66}") - - for care_label in CARE_SCENARIOS: - print(f" {CARE_LABELS[care_label]:>12}", end="") - for hiud_label in HIUD_SCENARIOS: - key = scenario_key(care_label, hiud_label) - s = disruption_stats[key] - m = np.nanmean(s['mean'][post_mask]) - std = np.nanmean(s['std'][post_mask]) - print(f" {m:>9.1f}% ± {std:.1f}%", end="") - print() - - -# ── Main ─────────────────────────────────────────────────────────────────────── -if __name__ == '__main__': - - do_run = True - - raw = run_sa(force_rerun=do_run) - - # Build axes - first_key = list(raw.keys())[0] - n_years = len(raw[first_key]['baseline'][0]) - years = np.arange(START, START + n_years) - - n_months = len(raw[first_key]['baseline_monthly'][0]) - years_monthly = np.array([START + m / 12 for m in range(n_months)]) - - stats = compute_stats(raw) - - # Compute disruption stats - disruption_stats = compute_disruption_stats(raw) - - # Plots - plot_monthly_cases(raw, years_monthly) - plot_impact_comparison(raw, years_monthly) - plot_pct_reduction_timeseries(stats, years) - plot_pct_reduction_barchart(raw, stats, years) - plot_hiud_uptake(raw) - plot_cascade_comparison(raw) - plot_hmb_prevalence(raw, years_monthly) - plot_disruption_rate(raw, years_monthly) - plot_disruption_reduction_timeseries(disruption_stats, years) - plot_disruption_reduction_barchart(raw, disruption_stats, years) - - # Summary - print_summary(raw, stats, years) - print_disruption_summary(disruption_stats, years) \ No newline at end of file diff --git a/run_cascade.py b/run_cascade.py index 04ac7a2..5aa4649 100644 --- a/run_cascade.py +++ b/run_cascade.py @@ -21,7 +21,7 @@ # Simulation creation functions # ============================================================================ -def make_base_sim(seed=0): +def make_base_sim(seed=10): """ Create baseline simulation without intervention @@ -35,7 +35,7 @@ def make_base_sim(seed=0): sim = fp.Sim( start=2020, stop=2030, - n_agents=5000, + n_agents=10000, total_pop=55_000_000, # Kenya's population for scaling location='kenya', education_module=edu, @@ -47,7 +47,7 @@ def make_base_sim(seed=0): return sim -def make_intervention_sim(seed=0): +def make_intervention_sim(seed=10): """ Create simulation with HMB care pathway intervention diff --git a/run_scenarios.py b/run_scenarios.py new file mode 100644 index 0000000..686e55e --- /dev/null +++ b/run_scenarios.py @@ -0,0 +1,649 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Apr 24 15:18:22 2026 + +@author: navidehno +""" + +""" +Scenario comparison: Darcy's 3×3 intervention packages vs status quo. + +Care-seeking levels (rows): + Base: 10% ever-seek, 1/36 monthly first + Mid: 20% ever-seek, 1/12 monthly first + High: 35% ever-seek, 1/6 monthly first + +hIUD uptake levels (columns): + Low: 10% of seekers get hIUD, 73% offered treatment + Mid: 25% of seekers get hIUD, 80% offered treatment + High: 40% of seekers get hIUD, 90% offered treatment + +tx_weights = % of those OFFERED treatment (sums to ~100). +prob_offer = global gate for % who get any treatment. +Together: prob_offer × tx_weight_i = % of seekers getting treatment i. + +All use mid RR (1.73) throughout. +Pre-2026: all sims run status quo (10% care, 70% offer, no hIUD). +Post-2026: each sim switches to its scenario. + +Outcomes: HMB prevalence and HMB-related anemia. +""" + +import numpy as np +import sciris as sc +import starsim as ss +import fpsim as fp +import os +import gc +import matplotlib.pyplot as plt + +from menstruation import Menstruation +from education import Education +from interventions_pool import HMBPool, HMBCounterfactual +from analyzers import track_hmb_anemia + + +# ── Output folders ───────────────────────────────────────────────────────────── +PLOTFOLDER = 'figures_scenarios/' +OUTFOLDER = 'results_scenarios/' +for d in [PLOTFOLDER, OUTFOLDER]: + os.makedirs(d, exist_ok=True) + + +# ── Settings ─────────────────────────────────────────────────────────────────── +P_BASE = 0.215 +P_HMB_PRONE = 0.53 +RR_MID = 1.73 +N_SEEDS = 10 +START = 2020 +STOP = 2030 +INTV_YEAR = 2026 + + +# ── Care-seeking levels ──────────────────────────────────────────────────────── + +CARE_PRE = sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, +) + +CARE_LEVELS = { + 'base': sc.objdict( + label='Base care (10%)', + care=sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + ), + 'mid': sc.objdict( + label='Mid care (20%)', + care=sc.objdict( + p_ever_seek=0.20, + p_monthly_first=1/12, + p_monthly_first_anemic=1/6, + p_monthly_first_pain=1/12, + p_monthly_repeat=1/6, + ), + ), + 'high': sc.objdict( + label='High care (35%)', + care=sc.objdict( + p_ever_seek=0.35, + p_monthly_first=1/6, + p_monthly_first_anemic=1/6, + p_monthly_first_pain=1/6, + p_monthly_repeat=1/6, + ), + ), +} + +# ── hIUD uptake levels ───────────────────────────────────────────────────────── +# tx_weights: % of those OFFERED treatment (sums to ~100) +# Derived from Darcy's table: raw % of seekers / prob_offer + +HIUD_LEVELS = { + 'low_hiud': sc.objdict( + label='Low hIUD (10%)', + prob_offer=0.73, + tx_weights=sc.objdict(nsaid=43.2, txa=21.6, pill=21.6, hiud=13.7), + ), + 'mid_hiud': sc.objdict( + label='Mid hIUD (25%)', + prob_offer=0.80, + tx_weights=sc.objdict(nsaid=34.4, txa=17.2, pill=17.2, hiud=31.3), + ), + 'high_hiud': sc.objdict( + label='High hIUD (40%)', + prob_offer=0.90, + tx_weights=sc.objdict(nsaid=27.8, txa=13.9, pill=13.9, hiud=44.4), + ), +} + +# ── Colors ───────────────────────────────────────────────────────────────────── +# 3×3 color matrix: rows=care, columns=hIUD, light→dark within each row +CARE_COLORS = { + 'base': {'low_hiud': '#BBDEFB', 'mid_hiud': '#64B5F6', 'high_hiud': '#1565C0'}, + 'mid': {'low_hiud': '#C8E6C9', 'mid_hiud': '#66BB6A', 'high_hiud': '#2E7D32'}, + 'high': {'low_hiud': '#FFCDD2', 'mid_hiud': '#EF5350', 'high_hiud': '#B71C1C'}, +} + +# For plots paneled by hIUD level, color by care level +CARE_LINE_COLORS = {'base': '#2196F3', 'mid': '#4CAF50', 'high': '#F44336'} + +# For plots paneled by care level, color by hIUD level +HIUD_LINE_COLORS = {'low_hiud': '#90CAF9', 'mid_hiud': '#4CAF50', 'high_hiud': '#F44336'} + + +# ── Helpers ──────────────────────────────────────────────────────────────────── +def rr_to_logistic_coeff(rr, p_base=P_BASE): + p_hmb = np.clip(p_base * rr, 1e-6, 1 - 1e-6) + return (-np.log(1 / p_hmb - 1)) - (-np.log(1 / p_base - 1)) + + +def make_menstruation(): + coeff = rr_to_logistic_coeff(RR_MID) + return Menstruation(pars={ + 'p_hmb_prone': ss.bernoulli(p=P_HMB_PRONE), + 'hmb_seq': sc.objdict( + poor_mh=sc.objdict(base=0.4, hmb=1.0), + anemic=sc.objdict(base=P_BASE, hmb=coeff), + pain=sc.objdict(base=0.1, hmb=3.36), + ) + }) + + +def make_intervention(care_name, hiud_name): + """ + Build intervention for a care × hIUD combination. + Pre-2026: status quo. Post-2026: scenario-specific. + """ + care_info = CARE_LEVELS[care_name] + hiud_info = HIUD_LEVELS[hiud_name] + + return HMBPool(pars=dict( + year=2020, + intv_year=INTV_YEAR, + + care_behavior_pre=CARE_PRE, + care_behavior_post=care_info.care, + + prob_offer_pre=0.70, + prob_offer_post=hiud_info.prob_offer, + + tx_weights_pre=sc.objdict(nsaid=50.0, txa=25.0, pill=25.0, hiud=0.0), + tx_weights_post=hiud_info.tx_weights, + + nsaid=sc.objdict(efficacy=0.33, adherence=0.80), + txa=sc.objdict(efficacy=0.45, adherence=0.70), + pill=sc.objdict(efficacy=0.59, adherence=0.80), + hiud=sc.objdict(efficacy=0.88, adherence=1.00), + )) + + +def make_counterfactual(): + return HMBCounterfactual(pars=dict( + care_behavior_pre=CARE_PRE, + care_behavior_post=CARE_PRE, + )) + + +def make_sim(scenario, seed=0): + """ + scenario: 'counterfactual' or tuple (care_name, hiud_name) + """ + mens = make_menstruation() + edu = Education() + + if scenario == 'counterfactual': + intervention = make_counterfactual() + else: + care_name, hiud_name = scenario + intervention = make_intervention(care_name, hiud_name) + + return fp.Sim( + start=START, stop=STOP, + n_agents=10000, total_pop=55_000_000, + location='kenya', + education_module=edu, + connectors=[mens], + interventions=[intervention], + analyzers=[track_hmb_anemia()], + rand_seed=seed, verbose=0, + ) + + +# ── Run ──────────────────────────────────────────────────────────────────────── +def _annualize(monthly_arr, how='sum'): + arr = np.asarray(monthly_arr) + n_y = len(arr) // 12 + arr = arr[:12 * n_y].reshape(n_y, 12) + return arr.sum(axis=1) if how == 'sum' else arr[:, -1] + + +def _annualize_mean(monthly_arr): + arr = np.asarray(monthly_arr) + n_y = len(arr) // 12 + arr = arr[:12 * n_y].reshape(n_y, 12) + return arr.mean(axis=1) + + +def run_scenarios(force_rerun=True): + results_file = OUTFOLDER + 'scenario_3x3_raw.obj' + + if not force_rerun and os.path.exists(results_file): + print("Loading saved results...") + return sc.loadobj(results_file) + + # Build scenario list: counterfactual + 3×3 + all_scenarios = ['counterfactual'] + for care_name in CARE_LEVELS: + for hiud_name in HIUD_LEVELS: + all_scenarios.append((care_name, hiud_name)) + + raw = {} + for scen in all_scenarios: + key = scen if scen == 'counterfactual' else f'{scen[0]}_{scen[1]}' + raw[key] = { + 'hmb_prev_monthly': [], + 'anemia_hmb_monthly': [], + 'anemia_total_monthly': [], + 'hmb_prev_annual': [], + 'anemia_hmb_annual': [], + 'anemia_total_annual': [], + } + + for scen in all_scenarios: + key = scen if scen == 'counterfactual' else f'{scen[0]}_{scen[1]}' + if scen == 'counterfactual': + label = 'Status quo (10% care, no hIUD)' + else: + label = f'{CARE_LEVELS[scen[0]].label} + {HIUD_LEVELS[scen[1]].label}' + + print(f"\n{'='*60}") + print(f" {label}") + print(f"{'='*60}") + + for seed in range(N_SEEDS): + print(f" seed {seed}...", end=" ", flush=True) + sim = make_sim(scen, seed=seed) + sim.run() + + hmb_prev = np.asarray(sim.results.menstruation['hmb_prev']) + anemia_hmb = np.asarray(sim.results.track_hmb_anemia['n_anemia_with_hmb']) + anemia_total = np.asarray(sim.results.track_hmb_anemia['n_anemia_total']) + + raw[key]['hmb_prev_monthly'].append(hmb_prev) + raw[key]['anemia_hmb_monthly'].append(anemia_hmb) + raw[key]['anemia_total_monthly'].append(anemia_total) + raw[key]['hmb_prev_annual'].append(_annualize_mean(hmb_prev)) + raw[key]['anemia_hmb_annual'].append(_annualize(anemia_hmb)) + raw[key]['anemia_total_annual'].append(_annualize(anemia_total)) + + del sim; gc.collect() + print("done") + + sc.saveobj(results_file, raw) + print(f"\nSaved: {results_file}") + return raw + + +# ── Statistics ───────────────────────────────────────────────────────────────── +def compute_pct_reduction(raw, scen_key, outcome='anemia_hmb_annual'): + cf = np.array(raw['counterfactual'][outcome]) + intv = np.array(raw[scen_key][outcome]) + averted = cf - intv + pct = np.where(cf > 0, averted / cf * 100, np.nan) + return { + 'mean': np.nanmean(pct, axis=0), + 'lower': np.nanpercentile(pct, 2.5, axis=0), + 'upper': np.nanpercentile(pct, 97.5, axis=0), + } + + +def compute_abs_reduction(raw, scen_key, outcome='anemia_hmb_annual'): + cf = np.array(raw['counterfactual'][outcome]) + intv = np.array(raw[scen_key][outcome]) + diff = cf - intv + return { + 'mean': np.mean(diff, axis=0), + 'lower': np.percentile(diff, 2.5, axis=0), + 'upper': np.percentile(diff, 97.5, axis=0), + } + + +# ── Plots ────────────────────────────────────────────────────────────────────── + +def plot_panels_by_care(raw, years_monthly, outcome, ylabel, title_prefix, filename): + """3 panels (base/mid/high care). Per panel: counterfactual + 3 hIUD lines.""" + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle(f'{title_prefix}\n' + f'Status quo before {INTV_YEAR}, intervention after. Mid RR ({RR_MID})', + fontsize=13) + + for idx, (care_name, care_info) in enumerate(CARE_LEVELS.items()): + ax = axes[idx] + + cf_arr = np.array(raw['counterfactual'][outcome]) + ax.plot(years_monthly, cf_arr.mean(axis=0), color='#6c757d', ls='--', lw=2, + label='Status quo') + + for hiud_name, hiud_info in HIUD_LEVELS.items(): + key = f'{care_name}_{hiud_name}' + arr = np.array(raw[key][outcome]) + mean = arr.mean(axis=0) + std = arr.std(axis=0) + ax.plot(years_monthly, mean, color=CARE_COLORS[care_name][hiud_name], + lw=1.8, label=hiud_info.label) + ax.fill_between(years_monthly, mean - std, mean + std, + color=CARE_COLORS[care_name][hiud_name], alpha=0.15) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel(ylabel) + ax.set_title(care_info.label) + ax.set_xlim([START, STOP]) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=8) + if 'anemia' in outcome: + ax.set_ylim(bottom=0) + sc.SIticks(ax=ax) + + plt.tight_layout() + fig.savefig(PLOTFOLDER + filename, dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}{filename}") + return fig + + +def plot_pct_reduction_by_care(raw, years, outcome, ylabel, title_prefix, filename): + """3 panels (base/mid/high care). Per panel: 3 hIUD lines showing % reduction.""" + post_mask = years >= INTV_YEAR + + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle(f'{title_prefix}\n' + f'% reduction vs status quo. Mid RR ({RR_MID})', fontsize=13) + + for idx, (care_name, care_info) in enumerate(CARE_LEVELS.items()): + ax = axes[idx] + + for hiud_name, hiud_info in HIUD_LEVELS.items(): + key = f'{care_name}_{hiud_name}' + s = compute_pct_reduction(raw, key, outcome=outcome) + mean = np.where(post_mask, s['mean'], np.nan) + lower = np.where(post_mask, s['lower'], np.nan) + upper = np.where(post_mask, s['upper'], np.nan) + + ax.plot(years, mean, color=HIUD_LINE_COLORS[hiud_name], lw=2.5, + label=hiud_info.label) + ax.fill_between(years, lower, upper, + color=HIUD_LINE_COLORS[hiud_name], alpha=0.15) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel(ylabel) + ax.set_title(care_info.label) + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) + + plt.tight_layout() + fig.savefig(PLOTFOLDER + filename, dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}{filename}") + return fig + + +def plot_pct_reduction_by_hiud(raw, years, outcome, ylabel, title_prefix, filename): + """3 panels (low/mid/high hIUD). Per panel: 3 care lines.""" + post_mask = years >= INTV_YEAR + + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle(f'{title_prefix}\n' + f'% reduction vs status quo. Mid RR ({RR_MID})', fontsize=13) + + for idx, (hiud_name, hiud_info) in enumerate(HIUD_LEVELS.items()): + ax = axes[idx] + + for care_name, care_info in CARE_LEVELS.items(): + key = f'{care_name}_{hiud_name}' + s = compute_pct_reduction(raw, key, outcome=outcome) + mean = np.where(post_mask, s['mean'], np.nan) + lower = np.where(post_mask, s['lower'], np.nan) + upper = np.where(post_mask, s['upper'], np.nan) + + ax.plot(years, mean, color=CARE_LINE_COLORS[care_name], lw=2.5, + label=care_info.label) + ax.fill_between(years, lower, upper, + color=CARE_LINE_COLORS[care_name], alpha=0.15) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel(ylabel) + ax.set_title(hiud_info.label) + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) + + plt.tight_layout() + fig.savefig(PLOTFOLDER + filename, dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}{filename}") + return fig + + +def plot_heatmap(raw, years, outcome, title, filename, fmt='.1f'): + """3×3 heatmap: % reduction at end of sim.""" + post_mask = years >= INTV_YEAR + care_names = list(CARE_LEVELS.keys()) + hiud_names = list(HIUD_LEVELS.keys()) + + matrix = np.zeros((len(care_names), len(hiud_names))) + for i, care_name in enumerate(care_names): + for j, hiud_name in enumerate(hiud_names): + key = f'{care_name}_{hiud_name}' + s = compute_pct_reduction(raw, key, outcome=outcome) + matrix[i, j] = s['mean'][-1] + + fig, ax = plt.subplots(figsize=(8, 5)) + im = ax.imshow(matrix, cmap='YlOrRd', aspect='auto') + + ax.set_xticks(range(len(hiud_names))) + ax.set_xticklabels([HIUD_LEVELS[h].label for h in hiud_names], fontsize=10) + ax.set_yticks(range(len(care_names))) + ax.set_yticklabels([CARE_LEVELS[c].label for c in care_names], fontsize=10) + + for i in range(len(care_names)): + for j in range(len(hiud_names)): + ax.text(j, i, f'{matrix[i, j]:{fmt}}%', + ha='center', va='center', fontsize=12, fontweight='bold', + color='white' if matrix[i, j] > matrix.max() * 0.5 else 'black') + + ax.set_xlabel('hIUD uptake level') + ax.set_ylabel('Care-seeking level') + ax.set_title(f'{title}\n% reduction vs status quo at {STOP} (mid RR = {RR_MID})') + plt.colorbar(im, ax=ax, label='% reduction', shrink=0.8) + plt.tight_layout() + fig.savefig(PLOTFOLDER + filename, dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}{filename}") + return fig + + +def plot_combined_absolute(raw, years_monthly): + """Side-by-side: HMB prevalence and HMB-related anemia, all 10 lines.""" + fig, axes = plt.subplots(1, 2, figsize=(18, 6)) + fig.suptitle(f'HMB prevalence and HMB-related anemia by scenario\n' + f'Mid RR ({RR_MID})', fontsize=14) + + panels = [ + ('hmb_prev_monthly', 'HMB prevalence', 'HMB prevalence'), + ('anemia_hmb_monthly', 'HMB-related anemia', 'Monthly anemia cases (HMB women)'), + ] + + for idx, (outcome, title, ylabel) in enumerate(panels): + ax = axes[idx] + + cf = np.array(raw['counterfactual'][outcome]) + ax.plot(years_monthly, cf.mean(axis=0), color='#6c757d', ls='--', lw=2.5, + label='Status quo') + + for care_name in CARE_LEVELS: + for hiud_name in HIUD_LEVELS: + key = f'{care_name}_{hiud_name}' + arr = np.array(raw[key][outcome]) + label = f'{CARE_LEVELS[care_name].label} + {HIUD_LEVELS[hiud_name].label}' + ax.plot(years_monthly, arr.mean(axis=0), + color=CARE_COLORS[care_name][hiud_name], lw=1.2, + label=label) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year'); ax.set_ylabel(ylabel); ax.set_title(title) + ax.set_xlim([START, STOP]) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + if idx == 1: + ax.set_ylim(bottom=0); sc.SIticks(ax=ax) + ax.legend(frameon=False, fontsize=6, ncol=2) + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'absolute_combined_all.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}absolute_combined_all.png") + return fig + + +# ── Summary table ────────────────────────────────────────────────────────────── +def print_summary(raw, years): + post_mask = years >= INTV_YEAR + + print(f"\n{'═'*100}") + print(f" Scenario comparison vs status quo (mid RR = {RR_MID})") + print(f" Post-intervention period: {INTV_YEAR}–{STOP}") + print(f"{'═'*100}") + + # End-of-sim values + print(f"\n End-of-sim values:") + print(f" {'Scenario':<45} {'HMB prev':>10} {'HMB anemia (M)':>16}") + print(f" {'─'*75}") + + # Counterfactual + hmb = np.array(raw['counterfactual']['hmb_prev_annual'])[:, -1] + ane = np.array(raw['counterfactual']['anemia_hmb_annual'])[:, -1] + print(f" {'Status quo (10% care, no hIUD)':<45} " + f"{hmb.mean():.3f}±{hmb.std():.3f} " + f"{ane.mean()/1e6:>8.2f}±{ane.std()/1e6:.2f}M") + + for care_name, care_info in CARE_LEVELS.items(): + for hiud_name, hiud_info in HIUD_LEVELS.items(): + key = f'{care_name}_{hiud_name}' + label = f'{care_info.label} + {hiud_info.label}' + hmb = np.array(raw[key]['hmb_prev_annual'])[:, -1] + ane = np.array(raw[key]['anemia_hmb_annual'])[:, -1] + print(f" {label:<45} " + f"{hmb.mean():.3f}±{hmb.std():.3f} " + f"{ane.mean()/1e6:>8.2f}±{ane.std()/1e6:.2f}M") + + # % reduction matrices + for outcome, title in [('hmb_prev_annual', 'HMB prevalence'), + ('anemia_hmb_annual', 'HMB-related anemia')]: + print(f"\n % reduction in {title} vs status quo (mean over {INTV_YEAR}–{STOP}):") + print(f" {'':>20}", end="") + for hiud_name, hiud_info in HIUD_LEVELS.items(): + print(f" {hiud_info.label:>22}", end="") + print() + print(f" {'─'*90}") + + for care_name, care_info in CARE_LEVELS.items(): + print(f" {care_info.label:<20}", end="") + for hiud_name in HIUD_LEVELS: + key = f'{care_name}_{hiud_name}' + s = compute_pct_reduction(raw, key, outcome=outcome) + m = np.nanmean(s['mean'][post_mask]) + lo = np.nanmean(s['lower'][post_mask]) + hi = np.nanmean(s['upper'][post_mask]) + print(f" {m:>5.1f}% ({lo:.1f}–{hi:.1f})", end="") + print() + + # Absolute averted at end of sim + print(f"\n Averted HMB-anemia cases at end of sim (millions):") + print(f" {'':>20}", end="") + for hiud_name, hiud_info in HIUD_LEVELS.items(): + print(f" {hiud_info.label:>22}", end="") + print() + print(f" {'─'*90}") + + for care_name, care_info in CARE_LEVELS.items(): + print(f" {care_info.label:<20}", end="") + for hiud_name in HIUD_LEVELS: + key = f'{care_name}_{hiud_name}' + s = compute_abs_reduction(raw, key, 'anemia_hmb_annual') + m = s['mean'][-1] / 1e6 + print(f" {m:>18.2f}M", end="") + print() + + print(f"\n{'═'*100}\n") + + +# ── Main ─────────────────────────────────────────────────────────────────────── +if __name__ == '__main__': + + do_run = True + raw = run_scenarios(force_rerun=do_run) + + # Time axes + n_years = len(raw['counterfactual']['hmb_prev_annual'][0]) + years_full = np.arange(START, START + n_years) + + n_months = len(raw['counterfactual']['hmb_prev_monthly'][0]) + years_monthly = np.array([START + m / 12 for m in range(n_months)]) + + # ── Absolute plots by care level ── + plot_panels_by_care(raw, years_monthly, + outcome='hmb_prev_monthly', ylabel='HMB prevalence', + title_prefix='HMB prevalence by scenario', + filename='hmb_prev_by_care.png') + + plot_panels_by_care(raw, years_monthly, + outcome='anemia_hmb_monthly', ylabel='Monthly anemia cases (HMB women)', + title_prefix='HMB-related anemia by scenario', + filename='anemia_hmb_by_care.png') + + # ── % reduction by care level ── + plot_pct_reduction_by_care(raw, years_full, + outcome='hmb_prev_annual', ylabel='% reduction in HMB prevalence', + title_prefix='HMB prevalence reduction', + filename='pct_reduction_hmb_by_care.png') + + plot_pct_reduction_by_care(raw, years_full, + outcome='anemia_hmb_annual', ylabel='% reduction in HMB-related anemia', + title_prefix='HMB-related anemia reduction', + filename='pct_reduction_anemia_by_care.png') + + # ── % reduction by hIUD level ── + plot_pct_reduction_by_hiud(raw, years_full, + outcome='hmb_prev_annual', ylabel='% reduction in HMB prevalence', + title_prefix='HMB prevalence reduction', + filename='pct_reduction_hmb_by_hiud.png') + + plot_pct_reduction_by_hiud(raw, years_full, + outcome='anemia_hmb_annual', ylabel='% reduction in HMB-related anemia', + title_prefix='HMB-related anemia reduction', + filename='pct_reduction_anemia_by_hiud.png') + + # ── Combined absolute ── + plot_combined_absolute(raw, years_monthly) + + # ── Heatmaps ── + plot_heatmap(raw, years_full, + outcome='hmb_prev_annual', title='HMB prevalence', + filename='heatmap_hmb_prev.png') + + plot_heatmap(raw, years_full, + outcome='anemia_hmb_annual', title='HMB-related anemia', + filename='heatmap_anemia_hmb.png') + + # ── Summary ── + print_summary(raw, years_full) \ No newline at end of file diff --git a/stats_interventions2.py b/stats_interventions2.py new file mode 100644 index 0000000..7a98ca8 --- /dev/null +++ b/stats_interventions2.py @@ -0,0 +1,1019 @@ +""" +Treatment usage and care-seeking analysis for HMB pool model. + +Metrics tracked: + - Monthly % of HMB women seeking care (by condition: anemia/pain/base) + - Monthly % on each treatment type + - Person-months of care-seeking and treatment + - Distribution of care visits per woman + - Distribution of treatment duration per woman + - Gave-up and hysterectomy counts + +Uses same 3×3 scenario grid as run_scenarios.py: + Care: base (10%), mid (20%), high (35%) + hIUD: low (10%), mid (25%), high (40%) +""" + +import numpy as np +import sciris as sc +import starsim as ss +import fpsim as fp +import os +import gc +import matplotlib.pyplot as plt +from collections import defaultdict + +from menstruation import Menstruation +from education import Education +from interventions_pool import HMBPool, HMBCounterfactual +from analyzers import track_hmb_anemia + + +# ── Output folders ───────────────────────────────────────────────────────────── +PLOTFOLDER = 'figures_treatment_usage/' +OUTFOLDER = 'results_treatment_usage/' +for d in [PLOTFOLDER, OUTFOLDER]: + os.makedirs(d, exist_ok=True) + + +# ── Settings ─────────────────────────────────────────────────────────────────── +P_BASE = 0.215 +P_HMB_PRONE = 0.53 +RR_MID = 1.73 +N_SEEDS = 10 +START = 2020 +STOP = 2030 +INTV_YEAR = 2026 + + +# ── Care-seeking levels (same as run_scenarios.py) ───────────────────────────── + +CARE_PRE = sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, +) + +CARE_LEVELS = { + 'base': sc.objdict( + label='Base care (10%)', + care=sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, + ), + ), + 'mid': sc.objdict( + label='Mid care (20%)', + care=sc.objdict( + p_ever_seek=0.20, + p_monthly_first=1/12, + p_monthly_first_anemic=1/6, + p_monthly_first_pain=1/12, + p_monthly_repeat=1/6, + ), + ), + 'high': sc.objdict( + label='High care (35%)', + care=sc.objdict( + p_ever_seek=0.35, + p_monthly_first=1/6, + p_monthly_first_anemic=1/6, + p_monthly_first_pain=1/6, + p_monthly_repeat=1/6, + ), + ), +} + +# ── hIUD uptake levels ───────────────────────────────────────────────────────── +# tx_weights: % of those OFFERED treatment (sums to ~100) + +HIUD_LEVELS = { + 'low_hiud': sc.objdict( + label='Low hIUD (10%)', + prob_offer=0.73, + tx_weights=sc.objdict(nsaid=43.2, txa=21.6, pill=21.6, hiud=13.7), + ), + 'mid_hiud': sc.objdict( + label='Mid hIUD (25%)', + prob_offer=0.80, + tx_weights=sc.objdict(nsaid=34.4, txa=17.2, pill=17.2, hiud=31.3), + ), + 'high_hiud': sc.objdict( + label='High hIUD (40%)', + prob_offer=0.90, + tx_weights=sc.objdict(nsaid=27.8, txa=13.9, pill=13.9, hiud=44.4), + ), +} + +# ── Colors ───────────────────────────────────────────────────────────────────── +CARE_COLORS = {'base': '#2196F3', 'mid': '#4CAF50', 'high': '#F44336'} +CARE_LABELS = {'base': 'Care 10%', 'mid': 'Care 20%', 'high': 'Care 35%'} +TX_COLORS = { + 'nsaid': '#1f77b4', 'txa': '#ff7f0e', + 'pill': '#2ca02c', 'hiud': '#d62728', +} +TX_LABELS = { + 'nsaid': 'NSAID', 'txa': 'TXA', + 'pill': 'Pill', 'hiud': 'hIUD', +} + + +# ============================================================================ +# Analyzer: detailed treatment usage tracking for pool model +# ============================================================================ + +class TrackTreatmentUsage(ss.Analyzer): + """ + Tracks care-seeking and treatment usage statistics for the pool model. + + Works with HMBPool, HMBCounterfactual, or any orchestrator that has + a .treatments dict and .on_any_treatment property. + """ + + def __init__(self, **kwargs): + super().__init__(name='track_treatment_usage') + self._care_visit_counts = None + self._prev_seeking_uids = None + self._months_seeking = None + self._months_on_nsaid = None + self._months_on_txa = None + self._months_on_pill = None + self._months_on_hiud = None + self._prev_on = None + + def _get_intervention(self): + for name in ['hmb_pool', 'hmb_counterfactual']: + if hasattr(self.sim.interventions, name): + return getattr(self.sim.interventions, name) + return None + + def init_results(self): + super().init_results() + results = [] + for key in [ + 'n_hmb_underlying', 'n_menstruating', + 'n_seeking_care', 'n_seeking_care_anemic', + 'n_seeking_care_pain', 'n_seeking_care_base', + 'n_on_nsaid', 'n_on_txa', 'n_on_pill', 'n_on_hiud', + 'n_on_any', 'n_ever_sought', + 'pct_seeking_care', 'pct_on_any', + 'pct_on_nsaid', 'pct_on_txa', 'pct_on_pill', 'pct_on_hiud', + 'pct_seeking_anemic', 'pct_seeking_pain', 'pct_seeking_base', + 'pct_seekers_w_anemia', 'pct_seekers_w_pain', 'pct_seekers_w_base', + 'pm_seeking', + 'pm_on_nsaid', 'pm_on_txa', 'pm_on_pill', 'pm_on_hiud', + 'pm_on_any', + 'n_gave_up', 'n_hysterectomy', 'pct_hysterectomy_all_women', + 'n_ever_seekers', + ]: + results.append(ss.Result(key, scale=False, label=key)) + self.define_results(*results) + + def step(self): + ti = self.ti + sim = self.sim + ppl = sim.people + scale = sim.pars.total_pop / sim.pars.n_agents + + intervention = self._get_intervention() + if intervention is None: + return + + mens = ppl.menstruation + + # ── First-call initialization ── + if self._care_visit_counts is None: + self._care_visit_counts = defaultdict(int) + self._prev_seeking_uids = set() + self._months_seeking = defaultdict(int) + self._months_on_nsaid = defaultdict(int) + self._months_on_txa = defaultdict(int) + self._months_on_pill = defaultdict(int) + self._months_on_hiud = defaultdict(int) + self._prev_on = {tx: set() for tx in ['nsaid', 'txa', 'pill', 'hiud']} + + # ── Underlying HMB pool ── + on_any = intervention.on_any_treatment + hmb_underlying = (mens.hmb | on_any) & mens.menstruating + n_hmb = np.count_nonzero(hmb_underlying) + n_mens = np.count_nonzero(mens.menstruating) + + self.results['n_hmb_underlying'][ti] = n_hmb * scale + self.results['n_menstruating'][ti] = n_mens * scale + + uids = np.asarray(ppl.uid) + + # ── Care-seeking ── + seeking = np.zeros(len(ppl), dtype=bool) + for tx in intervention.treatments.values(): + seeking |= np.asarray(tx.seeking_care) + + current_seeking_uids = set(uids[seeking]) + newly_seeking_uids = current_seeking_uids - self._prev_seeking_uids + for uid in newly_seeking_uids: + self._care_visit_counts[uid] += 1 + self._prev_seeking_uids = current_seeking_uids + + for uid in current_seeking_uids: + self._months_seeking[uid] += 1 + + n_seeking = np.count_nonzero(seeking) + seeking_anemic = seeking & np.asarray(mens.anemic) + seeking_pain = seeking & np.asarray(mens.pain) + seeking_base = seeking & ~np.asarray(mens.anemic) & ~np.asarray(mens.pain) + + self.results['n_seeking_care'][ti] = n_seeking * scale + self.results['n_seeking_care_anemic'][ti] = np.count_nonzero(seeking_anemic) * scale + self.results['n_seeking_care_pain'][ti] = np.count_nonzero(seeking_pain) * scale + self.results['n_seeking_care_base'][ti] = np.count_nonzero(seeking_base) * scale + self.results['pm_seeking'][ti] = n_seeking * scale + + # ── Treatment counts and person-months ── + tx_month_maps = { + 'nsaid': self._months_on_nsaid, + 'txa': self._months_on_txa, + 'pill': self._months_on_pill, + 'hiud': self._months_on_hiud, + } + + for tx_name in ['nsaid', 'txa', 'pill', 'hiud']: + on_tx = np.asarray(intervention.treatments[tx_name].on_treatment) + on_tx_uids = set(uids[on_tx]) + n_on = len(on_tx_uids) + self.results[f'n_on_{tx_name}'][ti] = n_on * scale + self.results[f'pm_on_{tx_name}'][ti] = n_on * scale + + for uid in on_tx_uids: + tx_month_maps[tx_name][uid] += 1 + + self._prev_on[tx_name] = on_tx_uids + + n_on_any_total = np.count_nonzero(on_any) + n_on_any_hmb = np.count_nonzero(on_any & hmb_underlying) + self.results['n_on_any'][ti] = n_on_any_total * scale + self.results['pm_on_any'][ti] = n_on_any_total * scale + + # ── Ever sought care ── + ever_sought = np.zeros(len(ppl), dtype=bool) + for tx in intervention.treatments.values(): + ever_sought |= np.asarray(tx.ever_sought_care) + self.results['n_ever_sought'][ti] = np.count_nonzero(ever_sought & hmb_underlying) * scale + + # ── Ever-seekers (Layer 1) ── + self.results['n_ever_seekers'][ti] = np.count_nonzero(intervention.ever_seeker) * scale + + # ── Percentages ── + if n_hmb > 0: + self.results['pct_seeking_care'][ti] = 100 * n_seeking / n_hmb + self.results['pct_on_any'][ti] = 100 * n_on_any_hmb / n_hmb + for tx_name in ['nsaid', 'txa', 'pill', 'hiud']: + n_on = np.count_nonzero(intervention.treatments[tx_name].on_treatment & hmb_underlying) + self.results[f'pct_on_{tx_name}'][ti] = 100 * n_on / n_hmb + self.results['pct_seeking_anemic'][ti] = 100 * np.count_nonzero(seeking_anemic) / n_hmb + self.results['pct_seeking_pain'][ti] = 100 * np.count_nonzero(seeking_pain) / n_hmb + self.results['pct_seeking_base'][ti] = 100 * np.count_nonzero(seeking_base) / n_hmb + + if n_seeking > 0: + self.results['pct_seekers_w_anemia'][ti] = 100 * np.count_nonzero(seeking_anemic) / n_seeking + self.results['pct_seekers_w_pain'][ti] = 100 * np.count_nonzero(seeking_pain) / n_seeking + self.results['pct_seekers_w_base'][ti] = 100 * np.count_nonzero(seeking_base) / n_seeking + + # ── Gave up / hysterectomy ── + self.results['n_gave_up'][ti] = np.count_nonzero(intervention.gave_up) * scale + self.results['n_hysterectomy'][ti] = np.count_nonzero(intervention.had_hysterectomy) * scale + + n_all_women = np.count_nonzero(ppl.female) + self.results['pct_hysterectomy_all_women'][ti] = ( + 100 * np.count_nonzero(intervention.had_hysterectomy) / max(n_all_women, 1) + ) + + def finalize(self): + super().finalize() + + intervention = self._get_intervention() + if intervention is None: + return + + mens = self.sim.people.menstruation + hmb_underlying = (mens.hmb | intervention.on_any_treatment) & mens.menstruating + hmb_uids = np.asarray(hmb_underlying.uids) + + # ── Care visit distribution ── + if self._care_visit_counts is not None and len(hmb_uids) > 0: + visits = np.array([self._care_visit_counts.get(uid, 0) for uid in hmb_uids], dtype=int) + self.results['care_visit_distribution'] = visits + self.results['care_visit_max'] = int(visits.max()) + self.results['care_visit_mean'] = float(visits.mean()) + else: + self.results['care_visit_distribution'] = np.array([]) + self.results['care_visit_max'] = 0 + self.results['care_visit_mean'] = 0.0 + + # ── Per-person months seeking care ── + if self._months_seeking is not None and len(hmb_uids) > 0: + ms = np.array([self._months_seeking.get(uid, 0) for uid in hmb_uids], dtype=int) + self.results['months_seeking_distribution'] = ms + else: + self.results['months_seeking_distribution'] = np.array([]) + + # ── Per-person months on each treatment ── + tx_month_maps = { + 'nsaid': self._months_on_nsaid, + 'txa': self._months_on_txa, + 'pill': self._months_on_pill, + 'hiud': self._months_on_hiud, + } + all_uids = np.asarray(self.sim.people.uid) + + for tx_name, month_map in tx_month_maps.items(): + tried = np.asarray(intervention.treatments[tx_name].tried_treatment) + tried_uids = all_uids[tried] + if len(tried_uids) > 0: + dur = np.array([month_map.get(uid, 0) for uid in tried_uids], dtype=int) + self.results[f'months_on_{tx_name}_distribution'] = dur + else: + self.results[f'months_on_{tx_name}_distribution'] = np.array([]) + + # ── Combined months on any treatment per HMB woman ── + if len(hmb_uids) > 0: + combined = np.zeros(len(hmb_uids), dtype=int) + for tx_name, month_map in tx_month_maps.items(): + for i, uid in enumerate(hmb_uids): + combined[i] += month_map.get(uid, 0) + self.results['months_on_any_distribution'] = combined + else: + self.results['months_on_any_distribution'] = np.array([]) + + # ── Episode count from orchestrator ── + eps = np.asarray(intervention.care_episodes[hmb_uids]) if len(hmb_uids) > 0 else np.array([]) + self.results['care_episodes_distribution'] = eps + + +# ============================================================================ +# Helpers +# ============================================================================ + +def rr_to_logistic_coeff(rr, p_base=P_BASE): + p_hmb = np.clip(p_base * rr, 1e-6, 1 - 1e-6) + return (-np.log(1 / p_hmb - 1)) - (-np.log(1 / p_base - 1)) + + +def make_menstruation(): + coeff = rr_to_logistic_coeff(RR_MID) + return Menstruation(pars={ + 'p_hmb_prone': ss.bernoulli(p=P_HMB_PRONE), + 'hmb_seq': sc.objdict( + poor_mh=sc.objdict(base=0.4, hmb=1.0), + anemic=sc.objdict(base=P_BASE, hmb=coeff), + pain=sc.objdict(base=0.1, hmb=3.36), + ) + }) + + +def make_intervention(care_name, hiud_name): + care_info = CARE_LEVELS[care_name] + hiud_info = HIUD_LEVELS[hiud_name] + + return HMBPool(pars=dict( + year=2020, + intv_year=INTV_YEAR, + care_behavior_pre=CARE_PRE, + care_behavior_post=care_info.care, + prob_offer_pre=0.70, + prob_offer_post=hiud_info.prob_offer, + tx_weights_pre=sc.objdict(nsaid=50.0, txa=25.0, pill=25.0, hiud=0.0), + tx_weights_post=hiud_info.tx_weights, + nsaid=sc.objdict(efficacy=0.33, adherence=0.80), + txa=sc.objdict(efficacy=0.45, adherence=0.70), + pill=sc.objdict(efficacy=0.59, adherence=0.80), + hiud=sc.objdict(efficacy=0.88, adherence=1.00), + )) + + +def make_counterfactual(): + return HMBCounterfactual(pars=dict( + care_behavior_pre=CARE_PRE, + care_behavior_post=CARE_PRE, + )) + + +def make_sim(scenario, seed=0): + """ + scenario: 'counterfactual' or tuple (care_name, hiud_name) + """ + mens = make_menstruation() + edu = Education() + + if scenario == 'counterfactual': + intervention = make_counterfactual() + else: + care_name, hiud_name = scenario + intervention = make_intervention(care_name, hiud_name) + + return fp.Sim( + start=START, stop=STOP, + n_agents=10000, total_pop=55_000_000, + location='kenya', + education_module=edu, + connectors=[mens], + interventions=[intervention], + analyzers=[track_hmb_anemia(), TrackTreatmentUsage()], + rand_seed=seed, verbose=0, + ) + + +def scenario_key(care_name, hiud_name): + return f'{care_name}_{hiud_name}' + + +# ============================================================================ +# Run simulations +# ============================================================================ + +def run_analysis(force_rerun=True): + results_file = OUTFOLDER + 'treatment_usage_pool_raw.obj' + + if not force_rerun and os.path.exists(results_file): + print("Loading saved results...") + return sc.loadobj(results_file) + + # Build scenario list + all_scenarios = ['counterfactual'] + for care_name in CARE_LEVELS: + for hiud_name in HIUD_LEVELS: + all_scenarios.append((care_name, hiud_name)) + + raw = {} + for scen in all_scenarios: + key = scen if scen == 'counterfactual' else scenario_key(scen[0], scen[1]) + raw[key] = { + # Monthly time series + 'pct_seeking_care': [], 'pct_seeking_anemic': [], + 'pct_seeking_pain': [], 'pct_seeking_base': [], + 'pct_seekers_w_anemia': [], 'pct_seekers_w_pain': [], + 'pct_seekers_w_base': [], + 'pct_on_any': [], 'pct_on_nsaid': [], 'pct_on_txa': [], + 'pct_on_pill': [], 'pct_on_hiud': [], + 'n_ever_sought': [], 'n_ever_seekers': [], + 'n_gave_up': [], 'n_hysterectomy': [], + 'pct_hysterectomy_all_women': [], + # Person-month time series + 'pm_seeking': [], + 'pm_on_nsaid': [], 'pm_on_txa': [], + 'pm_on_pill': [], 'pm_on_hiud': [], + 'pm_on_any': [], + # Snapshot distributions + 'care_visit_distributions': [], + 'care_visit_means': [], 'care_visit_maxes': [], + 'care_episodes_distributions': [], + 'months_seeking_distributions': [], + 'months_on_nsaid_distributions': [], + 'months_on_txa_distributions': [], + 'months_on_pill_distributions': [], + 'months_on_hiud_distributions': [], + 'months_on_any_distributions': [], + # Anemia + 'hmb_anemia_monthly': [], + } + + for scen in all_scenarios: + key = scen if scen == 'counterfactual' else scenario_key(scen[0], scen[1]) + if scen == 'counterfactual': + label = 'Status quo (10% care, no hIUD)' + else: + label = f'{CARE_LEVELS[scen[0]].label} + {HIUD_LEVELS[scen[1]].label}' + + print(f"\n{'='*60}") + print(f" {label}") + print(f"{'='*60}") + + for seed in range(N_SEEDS): + print(f" seed {seed}...", end=" ", flush=True) + sim = make_sim(scen, seed=seed) + sim.run() + + tu = sim.results.track_treatment_usage + + # Time series + for metric in [ + 'pct_seeking_care', 'pct_seeking_anemic', + 'pct_seeking_pain', 'pct_seeking_base', + 'pct_on_any', 'pct_on_nsaid', 'pct_on_txa', + 'pct_on_pill', 'pct_on_hiud', 'n_ever_sought', + 'n_ever_seekers', 'n_gave_up', 'n_hysterectomy', + 'pct_hysterectomy_all_women', + 'pct_seekers_w_anemia', 'pct_seekers_w_pain', 'pct_seekers_w_base', + 'pm_seeking', 'pm_on_nsaid', 'pm_on_txa', + 'pm_on_pill', 'pm_on_hiud', 'pm_on_any', + ]: + raw[key][metric].append(np.asarray(tu[metric])) + + # Snapshot distributions + raw[key]['care_visit_distributions'].append(np.asarray(tu['care_visit_distribution'])) + raw[key]['care_visit_means'].append(tu['care_visit_mean']) + raw[key]['care_visit_maxes'].append(tu['care_visit_max']) + raw[key]['care_episodes_distributions'].append(np.asarray(tu['care_episodes_distribution'])) + raw[key]['months_seeking_distributions'].append(np.asarray(tu['months_seeking_distribution'])) + for tx_name in ['nsaid', 'txa', 'pill', 'hiud']: + raw[key][f'months_on_{tx_name}_distributions'].append( + np.asarray(tu[f'months_on_{tx_name}_distribution'])) + raw[key]['months_on_any_distributions'].append(np.asarray(tu['months_on_any_distribution'])) + + # Anemia + raw[key]['hmb_anemia_monthly'].append( + np.asarray(sim.results.track_hmb_anemia['n_anemia_with_hmb'])) + + del sim; gc.collect() + print("done") + + sc.saveobj(results_file, raw) + print(f"\nSaved: {results_file}") + return raw + + +# ============================================================================ +# Plotting +# ============================================================================ + +def plot_care_seeking_breakdown(raw, years_monthly): + """Stacked area: % of HMB women seeking care by condition. 3 panels by care level.""" + fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True) + fig.suptitle('% of HMB women seeking care each month (by condition)\n' + f'Mid hIUD (25%), mid RR ({RR_MID})', fontsize=14) + hiud_name = 'mid_hiud' + + for idx, (care_name, care_info) in enumerate(CARE_LEVELS.items()): + ax = axes[idx] + key = scenario_key(care_name, hiud_name) + + pct_base = np.array(raw[key]['pct_seeking_base']).mean(axis=0) + pct_pain = np.array(raw[key]['pct_seeking_pain']).mean(axis=0) + pct_anemic = np.array(raw[key]['pct_seeking_anemic']).mean(axis=0) + + ax.fill_between(years_monthly, 0, pct_base, + color='#90CAF9', alpha=0.8, label='Base (no anemia/pain)') + ax.fill_between(years_monthly, pct_base, pct_base + pct_pain, + color='#FFB74D', alpha=0.8, label='With pain') + ax.fill_between(years_monthly, pct_base + pct_pain, + pct_base + pct_pain + pct_anemic, + color='#EF5350', alpha=0.8, label='With anemia') + + pct_total = np.array(raw[key]['pct_seeking_care']).mean(axis=0) + ax.plot(years_monthly, pct_total, color='black', lw=1.5, label='Total') + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('% of HMB women seeking care') + ax.set_title(care_info.label) + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=8, loc='upper left') + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'care_seeking_breakdown.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}care_seeking_breakdown.png") + return fig + + +def plot_care_seeker_composition(raw, years_monthly): + """Among care seekers: % with anemia / pain / neither.""" + fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True) + fig.suptitle('Composition of care seekers: % with anemia, pain, or neither\n' + f'Mid hIUD (25%), mid RR ({RR_MID})', fontsize=14) + hiud_name = 'mid_hiud' + + for idx, (care_name, care_info) in enumerate(CARE_LEVELS.items()): + ax = axes[idx] + key = scenario_key(care_name, hiud_name) + + pct_base = np.array(raw[key]['pct_seekers_w_base']).mean(axis=0) + pct_pain = np.array(raw[key]['pct_seekers_w_pain']).mean(axis=0) + pct_anemia = np.array(raw[key]['pct_seekers_w_anemia']).mean(axis=0) + + ax.fill_between(years_monthly, 0, pct_base, + color='#90CAF9', alpha=0.8, label='Neither') + ax.fill_between(years_monthly, pct_base, pct_base + pct_pain, + color='#FFB74D', alpha=0.8, label='With pain') + ax.fill_between(years_monthly, pct_base + pct_pain, + pct_base + pct_pain + pct_anemia, + color='#EF5350', alpha=0.8, label='With anemia') + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('% of care seekers') + ax.set_title(care_info.label) + ax.set_xlim([START, STOP]); ax.set_ylim(0, 100) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=8, loc='lower right') + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'care_seeker_composition.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}care_seeker_composition.png") + return fig + + +def plot_treatment_distribution(raw, years_monthly): + """Stacked area: % of HMB women on each treatment. 3 panels by care level.""" + fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True) + fig.suptitle(f'% of HMB women on each treatment\nMid hIUD (25%), mid RR ({RR_MID})', fontsize=14) + hiud_name = 'mid_hiud' + + for idx, (care_name, care_info) in enumerate(CARE_LEVELS.items()): + ax = axes[idx] + key = scenario_key(care_name, hiud_name) + + pct_nsaid = np.array(raw[key]['pct_on_nsaid']).mean(axis=0) + pct_txa = np.array(raw[key]['pct_on_txa']).mean(axis=0) + pct_pill = np.array(raw[key]['pct_on_pill']).mean(axis=0) + pct_hiud = np.array(raw[key]['pct_on_hiud']).mean(axis=0) + + ax.fill_between(years_monthly, 0, pct_nsaid, color=TX_COLORS['nsaid'], alpha=0.8, label='NSAID') + ax.fill_between(years_monthly, pct_nsaid, pct_nsaid + pct_txa, color=TX_COLORS['txa'], alpha=0.8, label='TXA') + ax.fill_between(years_monthly, pct_nsaid + pct_txa, pct_nsaid + pct_txa + pct_pill, + color=TX_COLORS['pill'], alpha=0.8, label='Pill') + ax.fill_between(years_monthly, pct_nsaid + pct_txa + pct_pill, + pct_nsaid + pct_txa + pct_pill + pct_hiud, + color=TX_COLORS['hiud'], alpha=0.8, label='hIUD') + + pct_any = np.array(raw[key]['pct_on_any']).mean(axis=0) + ax.plot(years_monthly, pct_any, color='black', lw=1.5, label='Any treatment') + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('% of HMB women on treatment') + ax.set_title(care_info.label) + ax.set_xlim([START, STOP]); ax.set_ylim(0, 30) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=8, loc='upper left') + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'treatment_distribution.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}treatment_distribution.png") + return fig + + +def plot_treatment_by_hiud(raw, years_monthly): + """Treatment distribution by hIUD level. Mid care (20%).""" + care_name = 'mid' + fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True) + fig.suptitle(f'Treatment distribution by hIUD uptake\n' + f'{CARE_LEVELS[care_name].label}, mid RR ({RR_MID})', fontsize=14) + + for idx, (hiud_name, hiud_info) in enumerate(HIUD_LEVELS.items()): + ax = axes[idx] + key = scenario_key(care_name, hiud_name) + + pct_nsaid = np.array(raw[key]['pct_on_nsaid']).mean(axis=0) + pct_txa = np.array(raw[key]['pct_on_txa']).mean(axis=0) + pct_pill = np.array(raw[key]['pct_on_pill']).mean(axis=0) + pct_hiud = np.array(raw[key]['pct_on_hiud']).mean(axis=0) + + ax.fill_between(years_monthly, 0, pct_nsaid, color=TX_COLORS['nsaid'], alpha=0.8, label='NSAID') + ax.fill_between(years_monthly, pct_nsaid, pct_nsaid + pct_txa, color=TX_COLORS['txa'], alpha=0.8, label='TXA') + ax.fill_between(years_monthly, pct_nsaid + pct_txa, pct_nsaid + pct_txa + pct_pill, + color=TX_COLORS['pill'], alpha=0.8, label='Pill') + ax.fill_between(years_monthly, pct_nsaid + pct_txa + pct_pill, + pct_nsaid + pct_txa + pct_pill + pct_hiud, + color=TX_COLORS['hiud'], alpha=0.8, label='hIUD') + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('% of HMB women on treatment') + ax.set_title(hiud_info.label) + ax.set_xlim([START, STOP]); ax.set_ylim(0, 30) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=8, loc='upper left') + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'treatment_by_hiud.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}treatment_by_hiud.png") + return fig + + +def plot_care_visit_distribution(raw): + """Histogram of care visits per HMB woman. 3 panels by care level.""" + fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=False) + fig.suptitle(f'Care-seeking visits per HMB woman\nMid hIUD (25%), mid RR ({RR_MID})', fontsize=14) + hiud_name = 'mid_hiud' + + for idx, (care_name, care_info) in enumerate(CARE_LEVELS.items()): + ax = axes[idx] + key = scenario_key(care_name, hiud_name) + all_visits = np.concatenate(raw[key]['care_visit_distributions']) + if len(all_visits) == 0: + continue + + max_visits = int(all_visits.max()) + bins = np.arange(0, min(max_visits + 2, 15)) - 0.5 + ax.hist(all_visits, bins=bins, color=CARE_COLORS[care_name], + alpha=0.7, edgecolor='black', linewidth=0.5, density=True) + + mean_v = all_visits.mean() + pct_0 = 100 * np.count_nonzero(all_visits == 0) / len(all_visits) + pct_3p = 100 * np.count_nonzero(all_visits >= 3) / len(all_visits) + + stats_text = (f'Mean: {mean_v:.1f}\nNever sought: {pct_0:.0f}%\n' + f'3+ visits: {pct_3p:.1f}%\nMax: {max_visits}') + ax.text(0.95, 0.95, stats_text, transform=ax.transAxes, ha='right', va='top', + fontsize=9, bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) + + ax.set_xlabel('Number of care-seeking visits') + if idx == 0: + ax.set_ylabel('Density') + ax.set_title(care_info.label) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'care_visit_distribution.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}care_visit_distribution.png") + return fig + + +def plot_gave_up_hysterectomy(raw, years_monthly): + """Gave up and hysterectomy counts over time. 3 panels by care level.""" + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle(f'Cumulative gave-up and hysterectomy\nMid hIUD (25%), mid RR ({RR_MID})', fontsize=14) + hiud_name = 'mid_hiud' + + for idx, (care_name, care_info) in enumerate(CARE_LEVELS.items()): + ax = axes[idx] + key = scenario_key(care_name, hiud_name) + + gave_up = np.array(raw[key]['n_gave_up']).mean(axis=0) + hyst = np.array(raw[key]['n_hysterectomy']).mean(axis=0) + + ax.plot(years_monthly, gave_up / 1e6, color='#E57373', lw=2, label='Gave up') + ax.plot(years_monthly, hyst / 1e6, color='#7B1FA2', lw=2, label='Hysterectomy') + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('Cumulative count (millions)') + ax.set_title(care_info.label) + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) + sc.SIticks(ax=ax) + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'gave_up_hysterectomy.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}gave_up_hysterectomy.png") + return fig + + +def plot_cumulative_person_months(raw, years_monthly): + """Cumulative person-months on each treatment. 3 panels by care level.""" + fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True) + fig.suptitle(f'Cumulative person-months of treatment\nMid hIUD (25%), mid RR ({RR_MID})', fontsize=14) + hiud_name = 'mid_hiud' + + for idx, (care_name, care_info) in enumerate(CARE_LEVELS.items()): + ax = axes[idx] + key = scenario_key(care_name, hiud_name) + + cum_seek = np.cumsum(np.array(raw[key]['pm_seeking']).mean(axis=0)) / 1e6 + cum_nsaid = np.cumsum(np.array(raw[key]['pm_on_nsaid']).mean(axis=0)) / 1e6 + cum_txa = np.cumsum(np.array(raw[key]['pm_on_txa']).mean(axis=0)) / 1e6 + cum_pill = np.cumsum(np.array(raw[key]['pm_on_pill']).mean(axis=0)) / 1e6 + cum_hiud = np.cumsum(np.array(raw[key]['pm_on_hiud']).mean(axis=0)) / 1e6 + + ax.plot(years_monthly, cum_seek, color='grey', lw=2, ls='--', label='Seeking care') + ax.fill_between(years_monthly, 0, cum_nsaid, color=TX_COLORS['nsaid'], alpha=0.7, label='NSAID') + ax.fill_between(years_monthly, cum_nsaid, cum_nsaid + cum_txa, color=TX_COLORS['txa'], alpha=0.7, label='TXA') + ax.fill_between(years_monthly, cum_nsaid + cum_txa, cum_nsaid + cum_txa + cum_pill, + color=TX_COLORS['pill'], alpha=0.7, label='Pill') + ax.fill_between(years_monthly, cum_nsaid + cum_txa + cum_pill, + cum_nsaid + cum_txa + cum_pill + cum_hiud, + color=TX_COLORS['hiud'], alpha=0.7, label='hIUD') + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel('Cumulative person-months (millions)') + ax.set_title(care_info.label) + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=8, loc='upper left') + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'cumulative_person_months.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}cumulative_person_months.png") + return fig + + +def plot_treatment_duration_distributions(raw): + """Months on treatment per woman who tried it. 3 rows × 4 columns.""" + hiud_name = 'mid_hiud' + care_list = list(CARE_LEVELS.keys()) + tx_list = ['nsaid', 'txa', 'pill', 'hiud'] + + fig, axes = plt.subplots(len(care_list), len(tx_list), figsize=(20, 12), sharey='row') + fig.suptitle(f'Months on treatment per woman who tried it\n' + f'Mid hIUD (25%), mid RR ({RR_MID})', fontsize=14) + + for row, care_name in enumerate(care_list): + key = scenario_key(care_name, hiud_name) + for col, tx_name in enumerate(tx_list): + ax = axes[row, col] + all_dur = np.concatenate(raw[key][f'months_on_{tx_name}_distributions']) + + if len(all_dur) == 0: + ax.text(0.5, 0.5, 'No data', ha='center', va='center', + transform=ax.transAxes, fontsize=12) + continue + + used = all_dur[all_dur > 0] + if len(used) == 0: + ax.text(0.5, 0.5, 'No usage', ha='center', va='center', + transform=ax.transAxes, fontsize=12) + continue + + max_dur = int(used.max()) + bins = np.arange(0.5, min(max_dur + 1.5, 50), 1) + ax.hist(used, bins=bins, color=TX_COLORS[tx_name], + alpha=0.7, edgecolor='black', linewidth=0.5, density=True) + + stats_text = (f'Mean: {used.mean():.1f} mo\n' + f'Median: {np.median(used):.0f} mo\n' + f'N tried: {len(all_dur)}') + ax.text(0.95, 0.95, stats_text, transform=ax.transAxes, + ha='right', va='top', fontsize=8, + bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) + + ax.set_xlabel('Months on treatment') + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + if row == 0: + ax.set_title(TX_LABELS[tx_name], fontsize=12) + if col == 0: + ax.set_ylabel(f'{CARE_LEVELS[care_name].label}\nDensity') + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'treatment_duration_distributions.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}treatment_duration_distributions.png") + return fig + + +# ============================================================================ +# Summary tables +# ============================================================================ + +def print_treatment_summary(raw): + hiud_name = 'mid_hiud' + print(f"\n{'═'*100}") + print(f" TREATMENT USAGE SUMMARY (end-of-sim, mid hIUD 25%)") + print(f"{'═'*100}") + + print(f"\n {'Scenario':<20} {'On any':>8} {'NSAID':>8} {'TXA':>8} " + f"{'Pill':>8} {'hIUD':>8} {'Seeking':>8}") + print(f" {'─'*80}") + + for key_label, key_val in [('Status quo', 'counterfactual')] + \ + [(CARE_LEVELS[c].label, scenario_key(c, hiud_name)) for c in CARE_LEVELS]: + pct_any = np.array(raw[key_val]['pct_on_any'])[:, -1].mean() + pct_nsaid = np.array(raw[key_val]['pct_on_nsaid'])[:, -1].mean() + pct_txa = np.array(raw[key_val]['pct_on_txa'])[:, -1].mean() + pct_pill = np.array(raw[key_val]['pct_on_pill'])[:, -1].mean() + pct_hiud = np.array(raw[key_val]['pct_on_hiud'])[:, -1].mean() + pct_seek = np.array(raw[key_val]['pct_seeking_care'])[:, -1].mean() + print(f" {key_label:<20} {pct_any:>7.1f}% {pct_nsaid:>7.1f}% " + f"{pct_txa:>7.1f}% {pct_pill:>7.1f}% {pct_hiud:>7.1f}% {pct_seek:>7.1f}%") + + # Full 3×3 grid + print(f"\n FULL 3×3: % of HMB women on any treatment (end-of-sim)") + print(f" {'':>20}", end="") + for hiud_name, hiud_info in HIUD_LEVELS.items(): + print(f" {hiud_info.label:>22}", end="") + print() + print(f" {'─'*90}") + + for care_name, care_info in CARE_LEVELS.items(): + print(f" {care_info.label:<20}", end="") + for hiud_name in HIUD_LEVELS: + key = scenario_key(care_name, hiud_name) + pct = np.array(raw[key]['pct_on_any'])[:, -1] + print(f" {pct.mean():>7.1f}±{pct.std():.1f}%", end="") + print() + + print(f"\n{'═'*100}") + + +def print_gave_up_summary(raw): + hiud_name = 'mid_hiud' + print(f"\n{'═'*80}") + print(f" GAVE UP / HYSTERECTOMY (end-of-sim, mid hIUD 25%)") + print(f"{'═'*80}") + + print(f"\n {'Scenario':<20} {'Gave up (M)':>14} {'Hysterectomy (M)':>18} {'Hyst % women':>14}") + print(f" {'─'*70}") + + for key_label, key_val in [('Status quo', 'counterfactual')] + \ + [(CARE_LEVELS[c].label, scenario_key(c, hiud_name)) for c in CARE_LEVELS]: + gu = np.array(raw[key_val]['n_gave_up'])[:, -1].mean() / 1e6 + hy = np.array(raw[key_val]['n_hysterectomy'])[:, -1].mean() / 1e6 + pct_h = np.array(raw[key_val]['pct_hysterectomy_all_women'])[:, -1].mean() + print(f" {key_label:<20} {gu:>10.2f}M {hy:>14.2f}M {pct_h:>12.2f}%") + + print(f"\n{'═'*80}") + + +def print_person_month_summary(raw): + hiud_name = 'mid_hiud' + print(f"\n{'═'*100}") + print(f" PERSON-MONTH SUMMARY (mid hIUD 25%)") + print(f"{'═'*100}") + + print(f"\n Cumulative person-months (millions, full sim {START}–{STOP}):") + print(f" {'Scenario':<20} {'Seeking':>10} {'NSAID':>10} {'TXA':>10} " + f"{'Pill':>10} {'hIUD':>10} {'Any Tx':>10}") + print(f" {'─'*85}") + + for key_label, key_val in [('Status quo', 'counterfactual')] + \ + [(CARE_LEVELS[c].label, scenario_key(c, hiud_name)) for c in CARE_LEVELS]: + cum_seek = np.array(raw[key_val]['pm_seeking']).sum(axis=1).mean() / 1e6 + cum_nsaid = np.array(raw[key_val]['pm_on_nsaid']).sum(axis=1).mean() / 1e6 + cum_txa = np.array(raw[key_val]['pm_on_txa']).sum(axis=1).mean() / 1e6 + cum_pill = np.array(raw[key_val]['pm_on_pill']).sum(axis=1).mean() / 1e6 + cum_hiud = np.array(raw[key_val]['pm_on_hiud']).sum(axis=1).mean() / 1e6 + cum_any = np.array(raw[key_val]['pm_on_any']).sum(axis=1).mean() / 1e6 + print(f" {key_label:<20} {cum_seek:>9.1f}M {cum_nsaid:>9.1f}M " + f"{cum_txa:>9.1f}M {cum_pill:>9.1f}M {cum_hiud:>9.1f}M {cum_any:>9.1f}M") + + # Per-person treatment duration + print(f"\n Mean months on treatment per woman who tried it:") + print(f" {'Scenario':<20} {'NSAID':>10} {'TXA':>10} {'Pill':>10} {'hIUD':>10}") + print(f" {'─'*65}") + + for key_label, key_val in [('Status quo', 'counterfactual')] + \ + [(CARE_LEVELS[c].label, scenario_key(c, hiud_name)) for c in CARE_LEVELS]: + means = [] + for tx_name in ['nsaid', 'txa', 'pill', 'hiud']: + all_dur = np.concatenate(raw[key_val][f'months_on_{tx_name}_distributions']) + used = all_dur[all_dur > 0] if len(all_dur) > 0 else np.array([0]) + means.append(used.mean() if len(used) > 0 else 0) + print(f" {key_label:<20} {means[0]:>9.1f}m {means[1]:>9.1f}m " + f"{means[2]:>9.1f}m {means[3]:>9.1f}m") + + print(f"\n{'═'*100}") + + +def print_care_visit_summary(raw): + hiud_name = 'mid_hiud' + print(f"\n{'═'*80}") + print(f" CARE VISIT DISTRIBUTION (mid hIUD 25%)") + print(f"{'═'*80}") + + print(f"\n {'Scenario':<20} {'Mean':>6} {'Median':>8} {'Max':>6} " + f"{'Never':>8} {'3+':>8}") + print(f" {'─'*65}") + + for key_label, key_val in [('Status quo', 'counterfactual')] + \ + [(CARE_LEVELS[c].label, scenario_key(c, hiud_name)) for c in CARE_LEVELS]: + all_visits = np.concatenate(raw[key_val]['care_visit_distributions']) + if len(all_visits) == 0: + continue + pct_0 = 100 * np.count_nonzero(all_visits == 0) / len(all_visits) + pct_3p = 100 * np.count_nonzero(all_visits >= 3) / len(all_visits) + print(f" {key_label:<20} {all_visits.mean():>5.1f} {np.median(all_visits):>7.0f} " + f"{all_visits.max():>5.0f} {pct_0:>7.1f}% {pct_3p:>7.1f}%") + + print(f"\n{'═'*80}") + + +# ============================================================================ +# Main +# ============================================================================ + +if __name__ == '__main__': + + do_run = True + raw = run_analysis(force_rerun=do_run) + + # Time axes + first_key = list(raw.keys())[0] + n_months = len(raw[first_key]['pct_seeking_care'][0]) + years_monthly = np.array([START + m / 12 for m in range(n_months)]) + + # Plots + plot_care_seeking_breakdown(raw, years_monthly) + plot_care_seeker_composition(raw, years_monthly) + plot_treatment_distribution(raw, years_monthly) + plot_treatment_by_hiud(raw, years_monthly) + plot_care_visit_distribution(raw) + plot_gave_up_hysterectomy(raw, years_monthly) + plot_cumulative_person_months(raw, years_monthly) + plot_treatment_duration_distributions(raw) + + # Summary tables + print_treatment_summary(raw) + print_gave_up_summary(raw) + print_person_month_summary(raw) + print_care_visit_summary(raw) \ No newline at end of file From d0144fe99a89e7aa6daef1e7cb8c4b0b3c283d39 Mon Sep 17 00:00:00 2001 From: nnoori-IDM <42287387+nnoori-IDM@users.noreply.github.com> Date: Tue, 28 Apr 2026 09:39:31 -0700 Subject: [PATCH 3/5] Sensitivity analysis of episode cap --- SA_episode_cap.py | 398 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 398 insertions(+) create mode 100644 SA_episode_cap.py diff --git a/SA_episode_cap.py b/SA_episode_cap.py new file mode 100644 index 0000000..756cf3e --- /dev/null +++ b/SA_episode_cap.py @@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Apr 28 09:38:33 2026 + +@author: navidehno +""" + +""" +Sensitivity analysis: max care-seeking episodes. +Baseline: 3 episodes (current). Alternative: 6 episodes. + +Scenario: Mid care (20%) + Mid hIUD (25%) vs status quo. +Mid RR (1.73) throughout. + +Question: Does allowing more care-seeking attempts before give-up/hysterectomy +meaningfully increase treatment coverage and HMB/anemia reduction? +""" + +import numpy as np +import sciris as sc +import starsim as ss +import fpsim as fp +import os +import gc +import matplotlib.pyplot as plt + +from menstruation import Menstruation +from education import Education +from interventions_pool import HMBPool, HMBCounterfactual +from analyzers import track_hmb_anemia + + +# ── Output folders ───────────────────────────────────────────────────────────── +PLOTFOLDER = 'figures_episode_sa/' +OUTFOLDER = 'results_episode_sa/' +for d in [PLOTFOLDER, OUTFOLDER]: + os.makedirs(d, exist_ok=True) + + +# ── Settings ─────────────────────────────────────────────────────────────────── +P_BASE = 0.215 +P_HMB_PRONE = 0.53 +RR_MID = 1.73 +N_SEEDS = 10 +START = 2020 +STOP = 2030 +INTV_YEAR = 2026 + +EPISODE_CAPS = [3, 6] + +CARE_PRE = sc.objdict( + p_ever_seek=0.10, + p_monthly_first=1/36, + p_monthly_first_anemic=1/12, + p_monthly_first_pain=1/24, + p_monthly_repeat=1/6, +) + +CARE_POST = sc.objdict( + p_ever_seek=0.20, + p_monthly_first=1/12, + p_monthly_first_anemic=1/6, + p_monthly_first_pain=1/12, + p_monthly_repeat=1/6, +) + +HIUD_MID = sc.objdict( + prob_offer=0.80, + tx_weights=sc.objdict(nsaid=34.4, txa=17.2, pill=17.2, hiud=31.3), +) + +# Colors +CAP_COLORS = {3: '#2196F3', 6: '#F44336'} +CAP_LABELS = {3: 'Max 3 episodes', 6: 'Max 6 episodes'} + + +# ── Helpers ──────────────────────────────────────────────────────────────────── +def rr_to_logistic_coeff(rr, p_base=P_BASE): + p_hmb = np.clip(p_base * rr, 1e-6, 1 - 1e-6) + return (-np.log(1 / p_hmb - 1)) - (-np.log(1 / p_base - 1)) + + +def make_menstruation(): + coeff = rr_to_logistic_coeff(RR_MID) + return Menstruation(pars={ + 'p_hmb_prone': ss.bernoulli(p=P_HMB_PRONE), + 'hmb_seq': sc.objdict( + poor_mh=sc.objdict(base=0.4, hmb=1.0), + anemic=sc.objdict(base=P_BASE, hmb=coeff), + pain=sc.objdict(base=0.1, hmb=3.36), + ) + }) + + +def make_intervention(max_episodes): + return HMBPool(pars=dict( + year=2020, + intv_year=INTV_YEAR, + care_behavior_pre=CARE_PRE, + care_behavior_post=CARE_POST, + prob_offer_pre=0.70, + prob_offer_post=HIUD_MID.prob_offer, + tx_weights_pre=sc.objdict(nsaid=50.0, txa=25.0, pill=25.0, hiud=0.0), + tx_weights_post=HIUD_MID.tx_weights, + nsaid=sc.objdict(efficacy=0.33, adherence=0.80), + txa=sc.objdict(efficacy=0.45, adherence=0.70), + pill=sc.objdict(efficacy=0.59, adherence=0.80), + hiud=sc.objdict(efficacy=0.88, adherence=1.00), + max_care_episodes=max_episodes, + )) + + +def make_counterfactual(max_episodes): + return HMBCounterfactual(pars=dict( + care_behavior_pre=CARE_PRE, + care_behavior_post=CARE_PRE, + max_care_episodes=max_episodes, + )) + + +def make_sim(scenario, max_episodes, seed=0): + mens = make_menstruation() + edu = Education() + + if scenario == 'counterfactual': + intervention = make_counterfactual(max_episodes) + else: + intervention = make_intervention(max_episodes) + + return fp.Sim( + start=START, stop=STOP, + n_agents=10000, total_pop=55_000_000, + location='kenya', + education_module=edu, + connectors=[mens], + interventions=[intervention], + analyzers=[track_hmb_anemia()], + rand_seed=seed, verbose=0, + ) + + +# ── Run ──────────────────────────────────────────────────────────────────────── +def _annualize(monthly_arr, how='sum'): + arr = np.asarray(monthly_arr) + n_y = len(arr) // 12 + arr = arr[:12 * n_y].reshape(n_y, 12) + return arr.sum(axis=1) if how == 'sum' else arr[:, -1] + + +def _annualize_mean(monthly_arr): + arr = np.asarray(monthly_arr) + n_y = len(arr) // 12 + arr = arr[:12 * n_y].reshape(n_y, 12) + return arr.mean(axis=1) + + +def run_sa(force_rerun=True): + results_file = OUTFOLDER + 'episode_cap_sa_raw.obj' + + if not force_rerun and os.path.exists(results_file): + print("Loading saved results...") + return sc.loadobj(results_file) + + raw = {} + for cap in EPISODE_CAPS: + for scen in ['counterfactual', 'intervention']: + key = f'{scen}_cap{cap}' + raw[key] = { + 'hmb_prev_monthly': [], + 'anemia_hmb_monthly': [], + 'anemia_total_monthly': [], + 'hmb_prev_annual': [], + 'anemia_hmb_annual': [], + 'anemia_total_annual': [], + } + + for cap in EPISODE_CAPS: + for scen in ['counterfactual', 'intervention']: + key = f'{scen}_cap{cap}' + label = f'{"Status quo" if scen == "counterfactual" else "Mid care + Mid hIUD"}, max {cap} episodes' + + print(f"\n{'='*60}") + print(f" {label}") + print(f"{'='*60}") + + for seed in range(N_SEEDS): + print(f" seed {seed}...", end=" ", flush=True) + sim = make_sim(scen, cap, seed=seed) + sim.run() + + hmb_prev = np.asarray(sim.results.menstruation['hmb_prev']) + anemia_hmb = np.asarray(sim.results.track_hmb_anemia['n_anemia_with_hmb']) + anemia_total = np.asarray(sim.results.track_hmb_anemia['n_anemia_total']) + + raw[key]['hmb_prev_monthly'].append(hmb_prev) + raw[key]['anemia_hmb_monthly'].append(anemia_hmb) + raw[key]['anemia_total_monthly'].append(anemia_total) + raw[key]['hmb_prev_annual'].append(_annualize_mean(hmb_prev)) + raw[key]['anemia_hmb_annual'].append(_annualize(anemia_hmb)) + raw[key]['anemia_total_annual'].append(_annualize(anemia_total)) + + del sim; gc.collect() + print("done") + + sc.saveobj(results_file, raw) + print(f"\nSaved: {results_file}") + return raw + + +# ── Statistics ───────────────────────────────────────────────────────────────── +def compute_pct_reduction(raw, cap, outcome='anemia_hmb_annual'): + cf = np.array(raw[f'counterfactual_cap{cap}'][outcome]) + intv = np.array(raw[f'intervention_cap{cap}'][outcome]) + averted = cf - intv + pct = np.where(cf > 0, averted / cf * 100, np.nan) + return { + 'mean': np.nanmean(pct, axis=0), + 'lower': np.nanpercentile(pct, 2.5, axis=0), + 'upper': np.nanpercentile(pct, 97.5, axis=0), + } + + +# ── Plots ────────────────────────────────────────────────────────────────────── + +def plot_hmb_prevalence(raw, years_monthly): + """HMB prevalence: counterfactual + intervention for each cap.""" + fig, axes = plt.subplots(1, 2, figsize=(16, 5)) + fig.suptitle(f'HMB prevalence by max care-seeking episodes\n' + f'Mid care (20%) + Mid hIUD (25%), mid RR ({RR_MID})', fontsize=13) + + # Left: absolute + ax = axes[0] + for cap in EPISODE_CAPS: + cf = np.array(raw[f'counterfactual_cap{cap}']['hmb_prev_monthly']) + intv = np.array(raw[f'intervention_cap{cap}']['hmb_prev_monthly']) + + ax.plot(years_monthly, cf.mean(axis=0), color=CAP_COLORS[cap], ls='--', lw=1.5, + label=f'Status quo (max {cap})') + ax.plot(years_monthly, intv.mean(axis=0), color=CAP_COLORS[cap], lw=2, + label=f'Intervention (max {cap})') + ax.fill_between(years_monthly, intv.mean(axis=0) - intv.std(axis=0), + intv.mean(axis=0) + intv.std(axis=0), + color=CAP_COLORS[cap], alpha=0.12) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year'); ax.set_ylabel('HMB prevalence') + ax.set_title('Absolute') + ax.set_xlim([START, STOP]) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=8) + + # Right: % reduction + ax = axes[1] + n_years = len(raw[f'counterfactual_cap3']['hmb_prev_annual'][0]) + years_full = np.arange(START, START + n_years) + post_mask = years_full >= INTV_YEAR + + for cap in EPISODE_CAPS: + s = compute_pct_reduction(raw, cap, outcome='hmb_prev_annual') + mean = np.where(post_mask, s['mean'], np.nan) + lower = np.where(post_mask, s['lower'], np.nan) + upper = np.where(post_mask, s['upper'], np.nan) + + ax.plot(years_full, mean, color=CAP_COLORS[cap], lw=2.5, label=CAP_LABELS[cap]) + ax.fill_between(years_full, lower, upper, color=CAP_COLORS[cap], alpha=0.15) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year'); ax.set_ylabel('% reduction vs status quo') + ax.set_title('% reduction in HMB prevalence') + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'hmb_prev_episode_sa.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}hmb_prev_episode_sa.png") + return fig + + +def plot_anemia(raw, years_monthly): + """Anemia among HMB women: counterfactual + intervention for each cap.""" + fig, axes = plt.subplots(1, 2, figsize=(16, 5)) + fig.suptitle(f'HMB-related anemia by max care-seeking episodes\n' + f'Mid care (20%) + Mid hIUD (25%), mid RR ({RR_MID})', fontsize=13) + + # Left: absolute + ax = axes[0] + for cap in EPISODE_CAPS: + cf = np.array(raw[f'counterfactual_cap{cap}']['anemia_hmb_monthly']) + intv = np.array(raw[f'intervention_cap{cap}']['anemia_hmb_monthly']) + + ax.plot(years_monthly, cf.mean(axis=0), color=CAP_COLORS[cap], ls='--', lw=1.5, + label=f'Status quo (max {cap})') + ax.plot(years_monthly, intv.mean(axis=0), color=CAP_COLORS[cap], lw=2, + label=f'Intervention (max {cap})') + ax.fill_between(years_monthly, intv.mean(axis=0) - intv.std(axis=0), + intv.mean(axis=0) + intv.std(axis=0), + color=CAP_COLORS[cap], alpha=0.12) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year'); ax.set_ylabel('Monthly anemia cases (HMB women)') + ax.set_title('Absolute') + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=8) + sc.SIticks(ax=ax) + + # Right: % reduction + ax = axes[1] + n_years = len(raw[f'counterfactual_cap3']['anemia_hmb_annual'][0]) + years_full = np.arange(START, START + n_years) + post_mask = years_full >= INTV_YEAR + + for cap in EPISODE_CAPS: + s = compute_pct_reduction(raw, cap, outcome='anemia_hmb_annual') + mean = np.where(post_mask, s['mean'], np.nan) + lower = np.where(post_mask, s['lower'], np.nan) + upper = np.where(post_mask, s['upper'], np.nan) + + ax.plot(years_full, mean, color=CAP_COLORS[cap], lw=2.5, label=CAP_LABELS[cap]) + ax.fill_between(years_full, lower, upper, color=CAP_COLORS[cap], alpha=0.15) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year'); ax.set_ylabel('% reduction vs status quo') + ax.set_title('% reduction in HMB-related anemia') + ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=9) + + plt.tight_layout() + fig.savefig(PLOTFOLDER + 'anemia_episode_sa.png', dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}anemia_episode_sa.png") + return fig + + +# ── Summary ──────────────────────────────────────────────────────────────────── +def print_summary(raw): + n_years = len(raw['counterfactual_cap3']['hmb_prev_annual'][0]) + years = np.arange(START, START + n_years) + post_mask = years >= INTV_YEAR + + print(f"\n{'═'*90}") + print(f" Episode Cap Sensitivity Analysis") + print(f" Mid care (20%) + Mid hIUD (25%), mid RR ({RR_MID})") + print(f"{'═'*90}") + + print(f"\n End-of-sim values:") + print(f" {'Scenario':<35} {'HMB prev':>10} {'HMB anemia (M)':>16}") + print(f" {'─'*65}") + + for cap in EPISODE_CAPS: + for scen, label in [('counterfactual', 'Status quo'), + ('intervention', 'Mid care + Mid hIUD')]: + key = f'{scen}_cap{cap}' + hmb = np.array(raw[key]['hmb_prev_annual'])[:, -1] + ane = np.array(raw[key]['anemia_hmb_annual'])[:, -1] + print(f" {f'{label} (max {cap})':<35} " + f"{hmb.mean():.4f}±{hmb.std():.4f} " + f"{ane.mean()/1e6:>8.2f}±{ane.std()/1e6:.2f}M") + + print(f"\n % reduction vs status quo at {STOP}:") + print(f" {'Episode cap':<20} {'HMB prev':>12} {'HMB anemia':>12}") + print(f" {'─'*50}") + + for cap in EPISODE_CAPS: + hmb_s = compute_pct_reduction(raw, cap, outcome='hmb_prev_annual') + ane_s = compute_pct_reduction(raw, cap, outcome='anemia_hmb_annual') + print(f" {f'Max {cap} episodes':<20} {hmb_s['mean'][-1]:>10.1f}% {ane_s['mean'][-1]:>10.1f}%") + + # Difference + hmb_3 = compute_pct_reduction(raw, 3, 'hmb_prev_annual')['mean'][-1] + hmb_6 = compute_pct_reduction(raw, 6, 'hmb_prev_annual')['mean'][-1] + ane_3 = compute_pct_reduction(raw, 3, 'anemia_hmb_annual')['mean'][-1] + ane_6 = compute_pct_reduction(raw, 6, 'anemia_hmb_annual')['mean'][-1] + + print(f"\n Additional gain from 6 vs 3 episodes:") + print(f" HMB prevalence: {hmb_6 - hmb_3:+.1f} percentage points") + print(f" HMB anemia: {ane_6 - ane_3:+.1f} percentage points") + + print(f"\n{'═'*90}\n") + + +# ── Main ─────────────────────────────────────────────────────────────────────── +if __name__ == '__main__': + + do_run = True + raw = run_sa(force_rerun=do_run) + + # Time axes + n_months = len(raw['counterfactual_cap3']['hmb_prev_monthly'][0]) + years_monthly = np.array([START + m / 12 for m in range(n_months)]) + + # Plots + plot_hmb_prevalence(raw, years_monthly) + plot_anemia(raw, years_monthly) + + # Summary + print_summary(raw) \ No newline at end of file From f6dea2937e09f08d0b0442001cd980864ba6e26a Mon Sep 17 00:00:00 2001 From: nnoori-IDM <42287387+nnoori-IDM@users.noreply.github.com> Date: Thu, 30 Apr 2026 08:32:21 -0700 Subject: [PATCH 4/5] run the analysis longer and modify the plots --- SA_episode_cap.py | 84 +++++++++++++++-------- calibrate_p_hmb.py | 2 +- run_anemia_risk_sensitivity.py | 36 +++++++--- run_scenarios.py | 121 ++++++++++++++++++++++++++++----- stats_interventions2.py | 37 +++++++--- 5 files changed, 213 insertions(+), 67 deletions(-) diff --git a/SA_episode_cap.py b/SA_episode_cap.py index 756cf3e..4ba4c72 100644 --- a/SA_episode_cap.py +++ b/SA_episode_cap.py @@ -42,9 +42,10 @@ P_HMB_PRONE = 0.53 RR_MID = 1.73 N_SEEDS = 10 -START = 2020 -STOP = 2030 +START = 2017 +STOP = 2035 INTV_YEAR = 2026 +PLOT_START = 2020 EPISODE_CAPS = [3, 6] @@ -80,6 +81,13 @@ def rr_to_logistic_coeff(rr, p_base=P_BASE): return (-np.log(1 / p_hmb - 1)) - (-np.log(1 / p_base - 1)) +def _mask_pre_intv(years, values): + """Set values before INTV_YEAR to NaN so only post-intervention is plotted.""" + masked = np.array(values, dtype=float).copy() + masked[years < INTV_YEAR] = np.nan + return masked + + def make_menstruation(): coeff = rr_to_logistic_coeff(RR_MID) return Menstruation(pars={ @@ -94,7 +102,7 @@ def make_menstruation(): def make_intervention(max_episodes): return HMBPool(pars=dict( - year=2020, + year=2017, intv_year=INTV_YEAR, care_behavior_pre=CARE_PRE, care_behavior_post=CARE_POST, @@ -112,6 +120,7 @@ def make_intervention(max_episodes): def make_counterfactual(max_episodes): return HMBCounterfactual(pars=dict( + year=2017, care_behavior_pre=CARE_PRE, care_behavior_post=CARE_PRE, max_care_episodes=max_episodes, @@ -223,35 +232,44 @@ def compute_pct_reduction(raw, cap, outcome='anemia_hmb_annual'): # ── Plots ────────────────────────────────────────────────────────────────────── def plot_hmb_prevalence(raw, years_monthly): - """HMB prevalence: counterfactual + intervention for each cap.""" + """HMB prevalence: counterfactual + intervention for each cap. + Only status quo shown before INTV_YEAR.""" fig, axes = plt.subplots(1, 2, figsize=(16, 5)) fig.suptitle(f'HMB prevalence by max care-seeking episodes\n' f'Mid care (20%) + Mid hIUD (25%), mid RR ({RR_MID})', fontsize=13) # Left: absolute ax = axes[0] + + # Plot ONE status quo line (they overlap pre-2026; use cap=3 as reference) + cf_ref = np.array(raw['counterfactual_cap3']['hmb_prev_monthly']) + ax.plot(years_monthly, cf_ref.mean(axis=0), color='#6c757d', ls='--', lw=2, + label='Status quo') + + # Intervention lines: post-intervention only for cap in EPISODE_CAPS: - cf = np.array(raw[f'counterfactual_cap{cap}']['hmb_prev_monthly']) intv = np.array(raw[f'intervention_cap{cap}']['hmb_prev_monthly']) - - ax.plot(years_monthly, cf.mean(axis=0), color=CAP_COLORS[cap], ls='--', lw=1.5, - label=f'Status quo (max {cap})') - ax.plot(years_monthly, intv.mean(axis=0), color=CAP_COLORS[cap], lw=2, - label=f'Intervention (max {cap})') - ax.fill_between(years_monthly, intv.mean(axis=0) - intv.std(axis=0), - intv.mean(axis=0) + intv.std(axis=0), + mean = intv.mean(axis=0) + std = intv.std(axis=0) + mean_masked = _mask_pre_intv(years_monthly, mean) + std_masked = _mask_pre_intv(years_monthly, std) + + ax.plot(years_monthly, mean_masked, color=CAP_COLORS[cap], lw=2, + label=f'Intervention ({CAP_LABELS[cap]})') + ax.fill_between(years_monthly, mean_masked - std_masked, + mean_masked + std_masked, color=CAP_COLORS[cap], alpha=0.12) ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) ax.set_xlabel('Year'); ax.set_ylabel('HMB prevalence') ax.set_title('Absolute') - ax.set_xlim([START, STOP]) + ax.set_xlim([PLOT_START, STOP]) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=8) # Right: % reduction ax = axes[1] - n_years = len(raw[f'counterfactual_cap3']['hmb_prev_annual'][0]) + n_years = len(raw['counterfactual_cap3']['hmb_prev_annual'][0]) years_full = np.arange(START, START + n_years) post_mask = years_full >= INTV_YEAR @@ -267,7 +285,7 @@ def plot_hmb_prevalence(raw, years_monthly): ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) ax.set_xlabel('Year'); ax.set_ylabel('% reduction vs status quo') ax.set_title('% reduction in HMB prevalence') - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) @@ -278,36 +296,45 @@ def plot_hmb_prevalence(raw, years_monthly): def plot_anemia(raw, years_monthly): - """Anemia among HMB women: counterfactual + intervention for each cap.""" + """Anemia among HMB women: counterfactual + intervention for each cap. + Only status quo shown before INTV_YEAR.""" fig, axes = plt.subplots(1, 2, figsize=(16, 5)) fig.suptitle(f'HMB-related anemia by max care-seeking episodes\n' f'Mid care (20%) + Mid hIUD (25%), mid RR ({RR_MID})', fontsize=13) # Left: absolute ax = axes[0] + + # Plot ONE status quo line + cf_ref = np.array(raw['counterfactual_cap3']['anemia_hmb_monthly']) + ax.plot(years_monthly, cf_ref.mean(axis=0), color='#6c757d', ls='--', lw=2, + label='Status quo') + + # Intervention lines: post-intervention only for cap in EPISODE_CAPS: - cf = np.array(raw[f'counterfactual_cap{cap}']['anemia_hmb_monthly']) intv = np.array(raw[f'intervention_cap{cap}']['anemia_hmb_monthly']) - - ax.plot(years_monthly, cf.mean(axis=0), color=CAP_COLORS[cap], ls='--', lw=1.5, - label=f'Status quo (max {cap})') - ax.plot(years_monthly, intv.mean(axis=0), color=CAP_COLORS[cap], lw=2, - label=f'Intervention (max {cap})') - ax.fill_between(years_monthly, intv.mean(axis=0) - intv.std(axis=0), - intv.mean(axis=0) + intv.std(axis=0), + mean = intv.mean(axis=0) + std = intv.std(axis=0) + mean_masked = _mask_pre_intv(years_monthly, mean) + std_masked = _mask_pre_intv(years_monthly, std) + + ax.plot(years_monthly, mean_masked, color=CAP_COLORS[cap], lw=2, + label=f'Intervention ({CAP_LABELS[cap]})') + ax.fill_between(years_monthly, mean_masked - std_masked, + mean_masked + std_masked, color=CAP_COLORS[cap], alpha=0.12) ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) ax.set_xlabel('Year'); ax.set_ylabel('Monthly anemia cases (HMB women)') ax.set_title('Absolute') - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=8) sc.SIticks(ax=ax) # Right: % reduction ax = axes[1] - n_years = len(raw[f'counterfactual_cap3']['anemia_hmb_annual'][0]) + n_years = len(raw['counterfactual_cap3']['anemia_hmb_annual'][0]) years_full = np.arange(START, START + n_years) post_mask = years_full >= INTV_YEAR @@ -323,7 +350,7 @@ def plot_anemia(raw, years_monthly): ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) ax.set_xlabel('Year'); ax.set_ylabel('% reduction vs status quo') ax.set_title('% reduction in HMB-related anemia') - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) @@ -395,4 +422,5 @@ def print_summary(raw): plot_anemia(raw, years_monthly) # Summary - print_summary(raw) \ No newline at end of file + print_summary(raw) + raw = sc.loadobj('results_episode_sa/episode_cap_sa_raw.obj') \ No newline at end of file diff --git a/calibrate_p_hmb.py b/calibrate_p_hmb.py index 3ac819d..a18fe39 100644 --- a/calibrate_p_hmb.py +++ b/calibrate_p_hmb.py @@ -22,7 +22,7 @@ from analyzers import track_hmb_anemia # ── Settings ── -START = 2020 +START = 2017 STOP = 2030 N_SEEDS = 5 TARGET_PREV = 0.48 diff --git a/run_anemia_risk_sensitivity.py b/run_anemia_risk_sensitivity.py index bdd2b06..8eb7e61 100644 --- a/run_anemia_risk_sensitivity.py +++ b/run_anemia_risk_sensitivity.py @@ -40,9 +40,10 @@ P_BASE = 0.215 P_HMB_PRONE = 0.53 N_SEEDS = 10 -START = 2020 -STOP = 2030 +START = 2017 +STOP = 2035 INTV_YEAR = 2026 +PLOT_START = 2020 # ── RR scenarios ── rr_values = { @@ -124,10 +125,23 @@ def make_menstruation(rr): def make_counterfactual(): - """Status quo throughout: 10% care, no hIUD, never shifts.""" - return HMBCounterfactual(pars=dict( + return HMBPool(pars=dict( + year=2017, + intv_year=INTV_YEAR, + care_behavior_pre=CARE_PRE, - care_behavior_post=CARE_PRE, # Same as pre — never changes + care_behavior_post=CARE_PRE, # no change at 2026 + + prob_offer_pre=0.70, + prob_offer_post=0.70, # stays the same + + tx_weights_pre=sc.objdict(nsaid=50.0, txa=25.0, pill=25.0, hiud=0.0), + tx_weights_post=sc.objdict(nsaid=50.0, txa=25.0, pill=25.0, hiud=0.0), + + nsaid=sc.objdict(efficacy=0.33, adherence=0.80), + txa=sc.objdict(efficacy=0.45, adherence=0.70), + pill=sc.objdict(efficacy=0.59, adherence=0.80), + hiud=sc.objdict(efficacy=0.88, adherence=1.00), )) @@ -139,7 +153,7 @@ def make_pool_intervention(hiud_scenario): scen = HIUD_SCENARIOS[hiud_scenario] return HMBPool(pars=dict( - year=2020, + year=2017, intv_year=INTV_YEAR, care_behavior_pre=CARE_PRE, care_behavior_post=CARE_POST, @@ -287,7 +301,7 @@ def plot_monthly_panels(raw, years_monthly): if idx == 0: ax.set_ylabel('Monthly anemia cases') ax.set_title(scen.label) - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) sc.SIticks(ax=ax) @@ -333,7 +347,7 @@ def plot_annual_panels(raw, years): if idx == 0: ax.set_ylabel('Annual anemia cases') ax.set_title(scen.label) - ax.set_xlim([START - 0.5, STOP + 0.5]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START - 0.5, STOP + 0.5]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) sc.SIticks(ax=ax) @@ -377,7 +391,7 @@ def plot_hmb_anemia_panels(raw, years_monthly): if idx == 0: ax.set_ylabel('Monthly anemia cases (HMB women)') ax.set_title(scen.label) - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) sc.SIticks(ax=ax) @@ -424,7 +438,7 @@ def plot_pct_reduction_by_hiud(raw, years): if idx == 0: ax.set_ylabel('% reduction vs counterfactual') ax.set_title(scen.label) - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) @@ -470,7 +484,7 @@ def plot_pct_reduction_by_rr(raw, years): if idx == 0: ax.set_ylabel('% reduction vs counterfactual') ax.set_title(rr_labels[rr_name]) - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) diff --git a/run_scenarios.py b/run_scenarios.py index 686e55e..92fd93e 100644 --- a/run_scenarios.py +++ b/run_scenarios.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Created on Fri Apr 24 15:18:22 2026 +Created on Wed Apr 29 15:45:02 2026 @author: navidehno """ @@ -55,10 +55,10 @@ P_HMB_PRONE = 0.53 RR_MID = 1.73 N_SEEDS = 10 -START = 2020 -STOP = 2030 +START = 2017 +STOP = 2035 INTV_YEAR = 2026 - +PLOT_START = 2020 # ── Care-seeking levels ──────────────────────────────────────────────────────── @@ -146,6 +146,13 @@ def rr_to_logistic_coeff(rr, p_base=P_BASE): return (-np.log(1 / p_hmb - 1)) - (-np.log(1 / p_base - 1)) +def _mask_pre_intv(years, values): + """Set values before INTV_YEAR to NaN so only post-intervention is plotted.""" + masked = np.array(values, dtype=float).copy() + masked[years < INTV_YEAR] = np.nan + return masked + + def make_menstruation(): coeff = rr_to_logistic_coeff(RR_MID) return Menstruation(pars={ @@ -167,7 +174,7 @@ def make_intervention(care_name, hiud_name): hiud_info = HIUD_LEVELS[hiud_name] return HMBPool(pars=dict( - year=2020, + year=2017, intv_year=INTV_YEAR, care_behavior_pre=CARE_PRE, @@ -187,9 +194,23 @@ def make_intervention(care_name, hiud_name): def make_counterfactual(): - return HMBCounterfactual(pars=dict( + return HMBPool(pars=dict( + year=2017, + intv_year=INTV_YEAR, + care_behavior_pre=CARE_PRE, - care_behavior_post=CARE_PRE, + care_behavior_post=CARE_PRE, # no change at 2026 + + prob_offer_pre=0.70, + prob_offer_post=0.70, # stays the same + + tx_weights_pre=sc.objdict(nsaid=50.0, txa=25.0, pill=25.0, hiud=0.0), + tx_weights_post=sc.objdict(nsaid=50.0, txa=25.0, pill=25.0, hiud=0.0), + + nsaid=sc.objdict(efficacy=0.33, adherence=0.80), + txa=sc.objdict(efficacy=0.45, adherence=0.70), + pill=sc.objdict(efficacy=0.59, adherence=0.80), + hiud=sc.objdict(efficacy=0.88, adherence=1.00), )) @@ -320,7 +341,8 @@ def compute_abs_reduction(raw, scen_key, outcome='anemia_hmb_annual'): # ── Plots ────────────────────────────────────────────────────────────────────── def plot_panels_by_care(raw, years_monthly, outcome, ylabel, title_prefix, filename): - """3 panels (base/mid/high care). Per panel: counterfactual + 3 hIUD lines.""" + """3 panels (base/mid/high care). Per panel: counterfactual + 3 hIUD lines. + Scenario lines are masked before INTV_YEAR so only status quo shows pre-intervention.""" fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) fig.suptitle(f'{title_prefix}\n' f'Status quo before {INTV_YEAR}, intervention after. Mid RR ({RR_MID})', @@ -329,18 +351,22 @@ def plot_panels_by_care(raw, years_monthly, outcome, ylabel, title_prefix, filen for idx, (care_name, care_info) in enumerate(CARE_LEVELS.items()): ax = axes[idx] + # Status quo: full line cf_arr = np.array(raw['counterfactual'][outcome]) ax.plot(years_monthly, cf_arr.mean(axis=0), color='#6c757d', ls='--', lw=2, label='Status quo') + # Scenario lines: post-intervention only for hiud_name, hiud_info in HIUD_LEVELS.items(): key = f'{care_name}_{hiud_name}' arr = np.array(raw[key][outcome]) mean = arr.mean(axis=0) std = arr.std(axis=0) - ax.plot(years_monthly, mean, color=CARE_COLORS[care_name][hiud_name], + mean_masked = _mask_pre_intv(years_monthly, mean) + std_masked = _mask_pre_intv(years_monthly, std) + ax.plot(years_monthly, mean_masked, color=CARE_COLORS[care_name][hiud_name], lw=1.8, label=hiud_info.label) - ax.fill_between(years_monthly, mean - std, mean + std, + ax.fill_between(years_monthly, mean_masked - std_masked, mean_masked + std_masked, color=CARE_COLORS[care_name][hiud_name], alpha=0.15) ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) @@ -348,7 +374,54 @@ def plot_panels_by_care(raw, years_monthly, outcome, ylabel, title_prefix, filen if idx == 0: ax.set_ylabel(ylabel) ax.set_title(care_info.label) - ax.set_xlim([START, STOP]) + ax.set_xlim([PLOT_START, STOP]) + ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) + ax.legend(frameon=False, fontsize=8) + if 'anemia' in outcome: + ax.set_ylim(bottom=0) + sc.SIticks(ax=ax) + + plt.tight_layout() + fig.savefig(PLOTFOLDER + filename, dpi=300, bbox_inches='tight') + print(f"Saved: {PLOTFOLDER}{filename}") + return fig + + +def plot_panels_by_hiud(raw, years_monthly, outcome, ylabel, title_prefix, filename): + """3 panels (low/mid/high hIUD). Per panel: counterfactual + 3 care lines. + Scenario lines are masked before INTV_YEAR so only status quo shows pre-intervention.""" + fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) + fig.suptitle(f'{title_prefix}\n' + f'Status quo before {INTV_YEAR}, intervention after. Mid RR ({RR_MID})', + fontsize=13) + + for idx, (hiud_name, hiud_info) in enumerate(HIUD_LEVELS.items()): + ax = axes[idx] + + # Status quo: full line + cf_arr = np.array(raw['counterfactual'][outcome]) + ax.plot(years_monthly, cf_arr.mean(axis=0), color='#6c757d', ls='--', lw=2, + label='Status quo') + + # Scenario lines: post-intervention only + for care_name, care_info in CARE_LEVELS.items(): + key = f'{care_name}_{hiud_name}' + arr = np.array(raw[key][outcome]) + mean = arr.mean(axis=0) + std = arr.std(axis=0) + mean_masked = _mask_pre_intv(years_monthly, mean) + std_masked = _mask_pre_intv(years_monthly, std) + ax.plot(years_monthly, mean_masked, color=CARE_LINE_COLORS[care_name], + lw=1.8, label=care_info.label) + ax.fill_between(years_monthly, mean_masked - std_masked, mean_masked + std_masked, + color=CARE_LINE_COLORS[care_name], alpha=0.15) + + ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) + ax.set_xlabel('Year') + if idx == 0: + ax.set_ylabel(ylabel) + ax.set_title(hiud_info.label) + ax.set_xlim([PLOT_START, STOP]) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=8) if 'anemia' in outcome: @@ -389,7 +462,7 @@ def plot_pct_reduction_by_care(raw, years, outcome, ylabel, title_prefix, filena if idx == 0: ax.set_ylabel(ylabel) ax.set_title(care_info.label) - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) @@ -427,7 +500,7 @@ def plot_pct_reduction_by_hiud(raw, years, outcome, ylabel, title_prefix, filena if idx == 0: ax.set_ylabel(ylabel) ax.set_title(hiud_info.label) - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) @@ -475,7 +548,8 @@ def plot_heatmap(raw, years, outcome, title, filename, fmt='.1f'): def plot_combined_absolute(raw, years_monthly): - """Side-by-side: HMB prevalence and HMB-related anemia, all 10 lines.""" + """Side-by-side: HMB prevalence and HMB-related anemia, all 10 lines. + Scenario lines are masked before INTV_YEAR.""" fig, axes = plt.subplots(1, 2, figsize=(18, 6)) fig.suptitle(f'HMB prevalence and HMB-related anemia by scenario\n' f'Mid RR ({RR_MID})', fontsize=14) @@ -488,22 +562,26 @@ def plot_combined_absolute(raw, years_monthly): for idx, (outcome, title, ylabel) in enumerate(panels): ax = axes[idx] + # Status quo: full line cf = np.array(raw['counterfactual'][outcome]) ax.plot(years_monthly, cf.mean(axis=0), color='#6c757d', ls='--', lw=2.5, label='Status quo') + # Scenario lines: post-intervention only for care_name in CARE_LEVELS: for hiud_name in HIUD_LEVELS: key = f'{care_name}_{hiud_name}' arr = np.array(raw[key][outcome]) + mean = arr.mean(axis=0) + mean_masked = _mask_pre_intv(years_monthly, mean) label = f'{CARE_LEVELS[care_name].label} + {HIUD_LEVELS[hiud_name].label}' - ax.plot(years_monthly, arr.mean(axis=0), + ax.plot(years_monthly, mean_masked, color=CARE_COLORS[care_name][hiud_name], lw=1.2, label=label) ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) ax.set_xlabel('Year'); ax.set_ylabel(ylabel); ax.set_title(title) - ax.set_xlim([START, STOP]) + ax.set_xlim([PLOT_START, STOP]) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) if idx == 1: ax.set_ylim(bottom=0); sc.SIticks(ax=ax) @@ -611,6 +689,17 @@ def print_summary(raw, years): title_prefix='HMB-related anemia by scenario', filename='anemia_hmb_by_care.png') + # ── Absolute plots by hIUD level ── + plot_panels_by_hiud(raw, years_monthly, + outcome='hmb_prev_monthly', ylabel='HMB prevalence', + title_prefix='HMB prevalence by scenario', + filename='hmb_prev_by_hiud.png') + + plot_panels_by_hiud(raw, years_monthly, + outcome='anemia_hmb_monthly', ylabel='Monthly anemia cases (HMB women)', + title_prefix='HMB-related anemia by scenario', + filename='anemia_hmb_by_hiud.png') + # ── % reduction by care level ── plot_pct_reduction_by_care(raw, years_full, outcome='hmb_prev_annual', ylabel='% reduction in HMB prevalence', diff --git a/stats_interventions2.py b/stats_interventions2.py index 7a98ca8..b4dad24 100644 --- a/stats_interventions2.py +++ b/stats_interventions2.py @@ -41,9 +41,10 @@ P_HMB_PRONE = 0.53 RR_MID = 1.73 N_SEEDS = 10 -START = 2020 -STOP = 2030 +START = 2017 +STOP = 2035 INTV_YEAR = 2026 +PLOT_START = 2020 # ── Care-seeking levels (same as run_scenarios.py) ───────────────────────────── @@ -380,7 +381,7 @@ def make_intervention(care_name, hiud_name): hiud_info = HIUD_LEVELS[hiud_name] return HMBPool(pars=dict( - year=2020, + year=2017, intv_year=INTV_YEAR, care_behavior_pre=CARE_PRE, care_behavior_post=care_info.care, @@ -396,9 +397,23 @@ def make_intervention(care_name, hiud_name): def make_counterfactual(): - return HMBCounterfactual(pars=dict( + return HMBPool(pars=dict( + year=2017, + intv_year=INTV_YEAR, + care_behavior_pre=CARE_PRE, - care_behavior_post=CARE_PRE, + care_behavior_post=CARE_PRE, # no change at 2026 + + prob_offer_pre=0.70, + prob_offer_post=0.70, # stays the same + + tx_weights_pre=sc.objdict(nsaid=50.0, txa=25.0, pill=25.0, hiud=0.0), + tx_weights_post=sc.objdict(nsaid=50.0, txa=25.0, pill=25.0, hiud=0.0), + + nsaid=sc.objdict(efficacy=0.33, adherence=0.80), + txa=sc.objdict(efficacy=0.45, adherence=0.70), + pill=sc.objdict(efficacy=0.59, adherence=0.80), + hiud=sc.objdict(efficacy=0.88, adherence=1.00), )) @@ -571,7 +586,7 @@ def plot_care_seeking_breakdown(raw, years_monthly): if idx == 0: ax.set_ylabel('% of HMB women seeking care') ax.set_title(care_info.label) - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=8, loc='upper left') @@ -609,7 +624,7 @@ def plot_care_seeker_composition(raw, years_monthly): if idx == 0: ax.set_ylabel('% of care seekers') ax.set_title(care_info.label) - ax.set_xlim([START, STOP]); ax.set_ylim(0, 100) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(0, 100) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=8, loc='lower right') @@ -650,7 +665,7 @@ def plot_treatment_distribution(raw, years_monthly): if idx == 0: ax.set_ylabel('% of HMB women on treatment') ax.set_title(care_info.label) - ax.set_xlim([START, STOP]); ax.set_ylim(0, 30) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(0, 30) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=8, loc='upper left') @@ -689,7 +704,7 @@ def plot_treatment_by_hiud(raw, years_monthly): if idx == 0: ax.set_ylabel('% of HMB women on treatment') ax.set_title(hiud_info.label) - ax.set_xlim([START, STOP]); ax.set_ylim(0, 30) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(0, 30) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=8, loc='upper left') @@ -759,7 +774,7 @@ def plot_gave_up_hysterectomy(raw, years_monthly): if idx == 0: ax.set_ylabel('Cumulative count (millions)') ax.set_title(care_info.label) - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=9) sc.SIticks(ax=ax) @@ -800,7 +815,7 @@ def plot_cumulative_person_months(raw, years_monthly): if idx == 0: ax.set_ylabel('Cumulative person-months (millions)') ax.set_title(care_info.label) - ax.set_xlim([START, STOP]); ax.set_ylim(bottom=0) + ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) ax.legend(frameon=False, fontsize=8, loc='upper left') From ee3a6ede8b3fbf4513ba715126df2e88cf777396 Mon Sep 17 00:00:00 2001 From: nnoori-IDM <42287387+nnoori-IDM@users.noreply.github.com> Date: Thu, 30 Apr 2026 10:33:03 -0700 Subject: [PATCH 5/5] run the analysis longer and update the plots --- run_anemia_risk_sensitivity.py | 130 +++++++++++++++++---------------- 1 file changed, 67 insertions(+), 63 deletions(-) diff --git a/run_anemia_risk_sensitivity.py b/run_anemia_risk_sensitivity.py index 8eb7e61..ef37fd6 100644 --- a/run_anemia_risk_sensitivity.py +++ b/run_anemia_risk_sensitivity.py @@ -6,11 +6,11 @@ 3 RR lines per panel + counterfactual (dashed) per RR All intervention sims: - 2020–2025: status quo (10% care-seeking, NSAID/TXA/Pill, no hIUD) - 2026–2030: mid care-seeking (20%) + hIUD at specified uptake level + 2017–2025: status quo (10% care-seeking, NSAID/TXA/Pill, no hIUD) + 2026–2035: mid care-seeking (20%) + hIUD at specified uptake level Counterfactual sims: - 2020–2030: status quo throughout (10% care-seeking, no hIUD) + 2017–2035: status quo throughout (10% care-seeking, no hIUD) RR varies from start of simulation (disease parameter). """ @@ -112,6 +112,13 @@ def rr_to_logistic_coeff(rr, p_base=P_BASE): return (-np.log(1 / p_hmb - 1)) - (-np.log(1 / p_base - 1)) +def _mask_pre_intv(years, values): + """Set values before INTV_YEAR to NaN so only post-intervention is plotted.""" + masked = np.array(values, dtype=float).copy() + masked[years < INTV_YEAR] = np.nan + return masked + + def make_menstruation(rr): coeff = rr_to_logistic_coeff(rr) return Menstruation(pars={ @@ -147,7 +154,7 @@ def make_counterfactual(): def make_pool_intervention(hiud_scenario): """ - 2020–2025: status quo (10% care, no hIUD, 70% receipt) + 2017–2025: status quo (10% care, no hIUD, 70% receipt) 2026+: mid care-seeking (20%) + hIUD at specified level """ scen = HIUD_SCENARIOS[hiud_scenario] @@ -157,7 +164,7 @@ def make_pool_intervention(hiud_scenario): intv_year=INTV_YEAR, care_behavior_pre=CARE_PRE, care_behavior_post=CARE_POST, - prob_offer_post=scen.prob_offer_post, + prob_offer_post=scen.prob_offer_post, tx_weights_pre=sc.objdict(nsaid=0.50, txa=0.25, pill=0.25, hiud=0.0), tx_weights_post=scen.tx_weights_post, nsaid=sc.objdict(efficacy=0.33, adherence=0.80), @@ -250,8 +257,6 @@ def run_sensitivity(force_rerun=True): def compute_pct_reduction(raw, rr_name, hiud_name): """ % reduction in annual anemia: intervention vs counterfactual at the same RR. - - Computed per-seed then aggregated, so stochastic noise cancels. """ cf = np.array(raw[rr_name]['counterfactual']['annual']) intv = np.array(raw[rr_name][hiud_name]['annual']) @@ -264,16 +269,29 @@ def compute_pct_reduction(raw, rr_name, hiud_name): } +def compute_pct_reduction_hmb(raw, rr_name, hiud_name): + """% reduction in HMB-specific anemia (where the signal lives).""" + cf = np.array([_annualize(m) for m in raw[rr_name]['counterfactual']['hmb_monthly']]) + intv = np.array([_annualize(m) for m in raw[rr_name][hiud_name]['hmb_monthly']]) + averted = cf - intv + pct = np.where(cf > 0, averted / cf * 100, np.nan) + return { + 'mean': np.nanmean(pct, axis=0), + 'lower': np.nanpercentile(pct, 2.5, axis=0), + 'upper': np.nanpercentile(pct, 97.5, axis=0), + } + + # ── Plots ────────────────────────────────────────────────────────────────────── def plot_monthly_panels(raw, years_monthly): """ - 3 panels (one per hIUD level). Per panel: 3 RR lines (solid) - + 3 counterfactual lines (dashed, same color). + 3 panels (one per hIUD level). Per panel: 3 RR counterfactual lines (dashed, full) + + 3 RR intervention lines (solid, post-2026 only). """ fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) fig.suptitle('Monthly anemia cases by hIUD uptake\n' - 'Solid = intervention, dashed = counterfactual (no hIUD)', + 'Solid = intervention (post-2026), dashed = status quo', fontsize=13) for idx, (hiud_name, scen) in enumerate(HIUD_SCENARIOS.items()): @@ -282,19 +300,22 @@ def plot_monthly_panels(raw, years_monthly): for rr_name in rr_values: color = rr_colors[rr_name] - # Counterfactual (dashed) + # Counterfactual (dashed, full timeline) cf = np.array(raw[rr_name]['counterfactual']['monthly']) cf_mean = cf.mean(axis=0) - ax.plot(years_monthly, cf_mean, color=color, ls='--', lw=1.2, alpha=0.7) + ax.plot(years_monthly, cf_mean, color=color, ls='--', lw=1.2, alpha=0.7, + label=f'{rr_labels[rr_name]} (status quo)') - # Intervention (solid) + # Intervention (solid, post-2026 only) intv = np.array(raw[rr_name][hiud_name]['monthly']) intv_mean = intv.mean(axis=0) intv_std = intv.std(axis=0) - ax.plot(years_monthly, intv_mean, color=color, lw=1.5, - label=rr_labels[rr_name]) - ax.fill_between(years_monthly, intv_mean - intv_std, - intv_mean + intv_std, color=color, alpha=0.12) + mean_masked = _mask_pre_intv(years_monthly, intv_mean) + std_masked = _mask_pre_intv(years_monthly, intv_std) + ax.plot(years_monthly, mean_masked, color=color, lw=1.5, + label=f'{rr_labels[rr_name]} (intv)') + ax.fill_between(years_monthly, mean_masked - std_masked, + mean_masked + std_masked, color=color, alpha=0.12) ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) ax.set_xlabel('Year') @@ -303,7 +324,7 @@ def plot_monthly_panels(raw, years_monthly): ax.set_title(scen.label) ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=9) + ax.legend(frameon=False, fontsize=7) sc.SIticks(ax=ax) plt.tight_layout() @@ -315,32 +336,38 @@ def plot_monthly_panels(raw, years_monthly): def plot_annual_panels(raw, years): """ - 3 panels (one per hIUD level). Per panel: 3 RR lines (solid) - + 3 counterfactual lines (dashed). + 3 panels (one per hIUD level). Per panel: 3 RR counterfactual lines (dashed, full) + + 3 RR intervention lines (solid, post-2026 only). """ fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) fig.suptitle('Annual anemia cases by hIUD uptake\n' - 'Solid = intervention, dashed = counterfactual', + 'Solid = intervention (post-2026), dashed = status quo', fontsize=13) + post_mask = years >= INTV_YEAR + for idx, (hiud_name, scen) in enumerate(HIUD_SCENARIOS.items()): ax = axes[idx] for rr_name in rr_values: color = rr_colors[rr_name] - # Counterfactual + # Counterfactual (dashed, full timeline) cf = np.array(raw[rr_name]['counterfactual']['annual']) cf_mean = cf.mean(axis=0) - ax.plot(years, cf_mean, color=color, ls='--', lw=1.5, alpha=0.7) + ax.plot(years, cf_mean, color=color, ls='--', lw=1.5, alpha=0.7, + label=f'{rr_labels[rr_name]} (status quo)') - # Intervention + # Intervention (solid, post-2026 only) intv = np.array(raw[rr_name][hiud_name]['annual']) intv_mean = intv.mean(axis=0) intv_std = intv.std(axis=0) - ax.errorbar(years, intv_mean, yerr=intv_std, color=color, + mean_masked = np.where(post_mask, intv_mean, np.nan) + std_masked = np.where(post_mask, intv_std, np.nan) + ax.errorbar(years[post_mask], intv_mean[post_mask], + yerr=intv_std[post_mask], color=color, lw=2, marker='o', capsize=3, markersize=4, - label=rr_labels[rr_name]) + label=f'{rr_labels[rr_name]} (intv)') ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) ax.set_xlabel('Year') @@ -349,7 +376,7 @@ def plot_annual_panels(raw, years): ax.set_title(scen.label) ax.set_xlim([PLOT_START - 0.5, STOP + 0.5]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=9) + ax.legend(frameon=False, fontsize=7) sc.SIticks(ax=ax) plt.tight_layout() @@ -363,7 +390,7 @@ def plot_hmb_anemia_panels(raw, years_monthly): """3 panels: monthly anemia among HMB women, with counterfactual.""" fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True) fig.suptitle('Anemia among HMB women by hIUD uptake\n' - 'Solid = intervention, dashed = counterfactual', + 'Solid = intervention (post-2026), dashed = status quo', fontsize=13) for idx, (hiud_name, scen) in enumerate(HIUD_SCENARIOS.items()): @@ -372,19 +399,22 @@ def plot_hmb_anemia_panels(raw, years_monthly): for rr_name in rr_values: color = rr_colors[rr_name] - # Counterfactual + # Counterfactual (dashed, full timeline) cf = np.array(raw[rr_name]['counterfactual']['hmb_monthly']) cf_mean = cf.mean(axis=0) - ax.plot(years_monthly, cf_mean, color=color, ls='--', lw=1.2, alpha=0.7) + ax.plot(years_monthly, cf_mean, color=color, ls='--', lw=1.2, alpha=0.7, + label=f'{rr_labels[rr_name]} (status quo)') - # Intervention + # Intervention (solid, post-2026 only) intv = np.array(raw[rr_name][hiud_name]['hmb_monthly']) intv_mean = intv.mean(axis=0) intv_std = intv.std(axis=0) - ax.plot(years_monthly, intv_mean, color=color, lw=1.5, - label=rr_labels[rr_name]) - ax.fill_between(years_monthly, intv_mean - intv_std, - intv_mean + intv_std, color=color, alpha=0.12) + mean_masked = _mask_pre_intv(years_monthly, intv_mean) + std_masked = _mask_pre_intv(years_monthly, intv_std) + ax.plot(years_monthly, mean_masked, color=color, lw=1.5, + label=f'{rr_labels[rr_name]} (intv)') + ax.fill_between(years_monthly, mean_masked - std_masked, + mean_masked + std_masked, color=color, alpha=0.12) ax.axvline(INTV_YEAR, color='k', ls='--', lw=1.5) ax.set_xlabel('Year') @@ -393,7 +423,7 @@ def plot_hmb_anemia_panels(raw, years_monthly): ax.set_title(scen.label) ax.set_xlim([PLOT_START, STOP]); ax.set_ylim(bottom=0) ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) - ax.legend(frameon=False, fontsize=9) + ax.legend(frameon=False, fontsize=7) sc.SIticks(ax=ax) plt.tight_layout() @@ -406,9 +436,6 @@ def plot_hmb_anemia_panels(raw, years_monthly): def plot_pct_reduction_by_hiud(raw, years): """ 3 panels (one per hIUD level): % reduction vs counterfactual, by RR. - - This is the key result plot. Shows how much anemia is averted - by introducing hIUD, and how that depends on the RR assumption. """ post_mask = years >= INTV_YEAR @@ -421,7 +448,6 @@ def plot_pct_reduction_by_hiud(raw, years): ax = axes[idx] for rr_name in rr_values: - s = compute_pct_reduction(raw, rr_name, hiud_name) s = compute_pct_reduction_hmb(raw, rr_name, hiud_name) mean = np.where(post_mask, s['mean'], np.nan) @@ -452,9 +478,6 @@ def plot_pct_reduction_by_hiud(raw, years): def plot_pct_reduction_by_rr(raw, years): """ 3 panels (one per RR): % reduction vs counterfactual, by hIUD level. - - Alternative view: for a given RR assumption, how much does - increasing hIUD uptake help? """ post_mask = years >= INTV_YEAR @@ -467,7 +490,6 @@ def plot_pct_reduction_by_rr(raw, years): ax = axes[idx] for hiud_name, scen in HIUD_SCENARIOS.items(): - s = compute_pct_reduction(raw, rr_name, hiud_name) s = compute_pct_reduction_hmb(raw, rr_name, hiud_name) mean = np.where(post_mask, s['mean'], np.nan) @@ -495,19 +517,6 @@ def plot_pct_reduction_by_rr(raw, years): return fig -def compute_pct_reduction_hmb(raw, rr_name, hiud_name): - """% reduction in HMB-specific anemia (where the signal lives).""" - cf = np.array([_annualize(m) for m in raw[rr_name]['counterfactual']['hmb_monthly']]) - intv = np.array([_annualize(m) for m in raw[rr_name][hiud_name]['hmb_monthly']]) - averted = cf - intv - pct = np.where(cf > 0, averted / cf * 100, np.nan) - return { - 'mean': np.nanmean(pct, axis=0), - 'lower': np.nanpercentile(pct, 2.5, axis=0), - 'upper': np.nanpercentile(pct, 97.5, axis=0), - } - - # ── Summary table ────────────────────────────────────────────────────────────── def print_summary(raw, years): post_mask = years >= INTV_YEAR @@ -546,7 +555,6 @@ def print_summary(raw, years): print(f" {'─'*50}") for rr_name in rr_values: - s = compute_pct_reduction(raw, rr_name, hiud_name) s = compute_pct_reduction_hmb(raw, rr_name, hiud_name) m = np.nanmean(s['mean'][post_mask]) lo = np.nanmean(s['lower'][post_mask]) @@ -577,8 +585,4 @@ def print_summary(raw, years): plot_pct_reduction_by_rr(raw, years_full) # Summary - print_summary(raw, years_full) - - raw = sc.loadobj('results_anemia_sa/anemia_sa_hiud_rr_cf_raw.obj') - plot_pct_reduction_by_hiud(raw, years_full) - plot_pct_reduction_by_rr(raw, years_full) + print_summary(raw, years_full) \ No newline at end of file