diff --git a/resources/healthsystem/ResourceFile_HealthSystem_parameters.csv b/resources/healthsystem/ResourceFile_HealthSystem_parameters.csv index c6bd6414e7..35ca5e32d1 100644 --- a/resources/healthsystem/ResourceFile_HealthSystem_parameters.csv +++ b/resources/healthsystem/ResourceFile_HealthSystem_parameters.csv @@ -26,3 +26,5 @@ year_use_funded_or_actual_staffing_switch,2100 cons_override_treatment_ids,[] cons_override_treatment_ids_prob_avail,1.0 clinic_configuration_name,Default +year_HR_scaling_by_district_and_officer_type,2100 +HR_scaling_by_district_and_officer_type_mode,default diff --git a/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_district_and_officer_type/default.csv b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_district_and_officer_type/default.csv new file mode 100644 index 0000000000..c10498edb8 --- /dev/null +++ b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_district_and_officer_type/default.csv @@ -0,0 +1,38 @@ +District,Clinical,Nursing_and_Midwifery,Pharmacy,DCSA,Dental,Laboratory,Mental,Nutrition,Radiography +Balaka,1,1,1,1,1,1,1,1,1 +Blantyre,1,1,1,1,1,1,1,1,1 +Blantyre City,1,1,1,1,1,1,1,1,1 +Referral Hospital_Central,1,1,1,1,1,1,1,1,1 +Referral Hospital_Northern,1,1,1,1,1,1,1,1,1 +Referral Hospital_Southern,1,1,1,1,1,1,1,1,1 +Chikwawa,1,1,1,1,1,1,1,1,1 +Chiradzulu,1,1,1,1,1,1,1,1,1 +Chitipa,1,1,1,1,1,1,1,1,1 +Dedza,1,1,1,1,1,1,1,1,1 +Dowa,1,1,1,1,1,1,1,1,1 +Headquarter,1,1,1,1,1,1,1,1,1 +Karonga,1,1,1,1,1,1,1,1,1 +Kasungu,1,1,1,1,1,1,1,1,1 +Likoma,1,1,1,1,1,1,1,1,1 +Lilongwe,1,1,1,1,1,1,1,1,1 +Lilongwe City,1,1,1,1,1,1,1,1,1 +Machinga,1,1,1,1,1,1,1,1,1 +Mangochi,1,1,1,1,1,1,1,1,1 +Mchinji,1,1,1,1,1,1,1,1,1 +Mulanje,1,1,1,1,1,1,1,1,1 +Mwanza,1,1,1,1,1,1,1,1,1 +Mzimba,1,1,1,1,1,1,1,1,1 +Mzuzu City,1,1,1,1,1,1,1,1,1 +Neno,1,1,1,1,1,1,1,1,1 +Nkhata Bay,1,1,1,1,1,1,1,1,1 +Nkhotakota,1,1,1,1,1,1,1,1,1 
+Nsanje,1,1,1,1,1,1,1,1,1 +Ntcheu,1,1,1,1,1,1,1,1,1 +Ntchisi,1,1,1,1,1,1,1,1,1 +Phalombe,1,1,1,1,1,1,1,1,1 +Rumphi,1,1,1,1,1,1,1,1,1 +Salima,1,1,1,1,1,1,1,1,1 +Thyolo,1,1,1,1,1,1,1,1,1 +Zomba,1,1,1,1,1,1,1,1,1 +Zomba City,1,1,1,1,1,1,1,1,1 +Zomba Mental Hospital,1,1,1,1,1,1,1,1,1 diff --git a/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_district_and_officer_type/establishment_by_district_and_CNP.csv b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_district_and_officer_type/establishment_by_district_and_CNP.csv new file mode 100644 index 0000000000..cf937973ce --- /dev/null +++ b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_district_and_officer_type/establishment_by_district_and_CNP.csv @@ -0,0 +1,38 @@ +District,Clinical,Nursing_and_Midwifery,Pharmacy,DCSA,Dental,Laboratory,Mental,Nutrition,Radiography +Balaka,1.174526999,1.169316612,2.052879982,1,1,1,1,1,1 +Blantyre,0.759811383,0.89730864,1.578474944,1,1,1,1,1,1 +Blantyre City,0.759811383,0.89730864,1.578474944,1,1,1,1,1,1 +Referral Hospital_Central,2.486954763,1.355424224,1.412864458,1,1,1,1,1,1 +Referral Hospital_Northern,2.197010075,1.070124208,0.897056799,1,1,1,1,1,1 +Referral Hospital_Southern,2.177651732,1.402012229,0.710875199,1,1,1,1,1,1 +Chikwawa,1.571206225,1.98442501,1.947217042,1,1,1,1,1,1 +Chiradzulu,1.508760236,1.900692263,2.173637628,1,1,1,1,1,1 +Chitipa,1.145249933,1.936513522,1.690607044,1,1,1,1,1,1 +Dedza,1.23965271,1.543188619,2.264205862,1,1,1,1,1,1 +Dowa,1.435421075,1.737431437,1.792297693,1,1,1,1,1,1 +Headquarter,1.183424931,3.62272938,1.449091752,1,1,1,1,1,1 +Karonga,1.335881459,1.872527654,3.043092679,1,1,1,1,1,1 +Kasungu,1.482635542,1.937984358,2.557220739,1,1,1,1,1,1 +Likoma,2.233502608,3.155265589,3.042577533,1,1,1,1,1,1 +Lilongwe,1.380087383,1.167518839,1.709475426,1,1,1,1,1,1 +Lilongwe City,1.380087383,1.167518839,1.709475426,1,1,1,1,1,1 
+Machinga,1.584340315,1.853985036,1.96662452,1,1,1,1,1,1 +Mangochi,1.212312054,1.260791694,2.940803849,1,1,1,1,1,1 +Mchinji,1.319220321,1.722723439,1.716029706,1,1,1,1,1,1 +Mulanje,1.442124965,1.488502197,3.550274792,1,1,1,1,1,1 +Mwanza,2.032262823,1.646695173,1.368586655,1,1,1,1,1,1 +Mzimba,1.664191309,1.462212301,2.258878319,1,1,1,1,1,1 +Mzuzu City,1.664191309,1.462212302,2.258878319,1,1,1,1,1,1 +Neno,1.90022409,2.133827415,2.264205862,1,1,1,1,1,1 +Nkhata Bay,1.64711602,2.260579269,2.781346584,1,1,1,1,1,1 +Nkhotakota,1.798270487,1.964020582,2.225390905,1,1,1,1,1,1 +Nsanje,1.560560348,2.028728453,2.475531743,1,1,1,1,1,1 +Ntcheu,1.436599582,1.544519745,2.753274328,1,1,1,1,1,1 +Ntchisi,1.511195684,1.885047321,2.041902014,1,1,1,1,1,1 +Phalombe,1.81136469,1.561895301,1.783497541,1,1,1,1,1,1 +Rumphi,1.680946432,1.643251104,1.494375869,1,1,1,1,1,1 +Salima,1.156345943,1.449091752,2.371241048,1,1,1,1,1,1 +Thyolo,1.283305831,1.360371849,2.083069393,1,1,1,1,1,1 +Zomba,1.12943916,1.10442287,1.332229514,1,1,1,1,1,1 +Zomba City,1.12943916,1.10442287,1.332229514,1,1,1,1,1,1 +Zomba Mental Hospital,1.521546339,0.85520169,2.173637628,1,1,1,1,1,1 diff --git a/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_district_and_officer_type/establishment_by_district_and_N.csv b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_district_and_officer_type/establishment_by_district_and_N.csv new file mode 100644 index 0000000000..dcdf638838 --- /dev/null +++ b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_district_and_officer_type/establishment_by_district_and_N.csv @@ -0,0 +1,38 @@ +District,Clinical,Nursing_and_Midwifery,Pharmacy,DCSA,Dental,Laboratory,Mental,Nutrition,Radiography +Balaka,1,1.169316612,1,1,1,1,1,1,1 +Blantyre,1,0.89730864,1,1,1,1,1,1,1 +Blantyre City,1,0.89730864,1,1,1,1,1,1,1 +Referral Hospital_Central,1,1.355424224,1,1,1,1,1,1,1 +Referral 
Hospital_Northern,1,1.070124208,1,1,1,1,1,1,1 +Referral Hospital_Southern,1,1.402012229,1,1,1,1,1,1,1 +Chikwawa,1,1.98442501,1,1,1,1,1,1,1 +Chiradzulu,1,1.900692263,1,1,1,1,1,1,1 +Chitipa,1,1.936513522,1,1,1,1,1,1,1 +Dedza,1,1.543188619,1,1,1,1,1,1,1 +Dowa,1,1.737431437,1,1,1,1,1,1,1 +Headquarter,1,3.62272938,1,1,1,1,1,1,1 +Karonga,1,1.872527654,1,1,1,1,1,1,1 +Kasungu,1,1.937984358,1,1,1,1,1,1,1 +Likoma,1,3.155265589,1,1,1,1,1,1,1 +Lilongwe,1,1.167518839,1,1,1,1,1,1,1 +Lilongwe City,1,1.167518839,1,1,1,1,1,1,1 +Machinga,1,1.853985036,1,1,1,1,1,1,1 +Mangochi,1,1.260791694,1,1,1,1,1,1,1 +Mchinji,1,1.722723439,1,1,1,1,1,1,1 +Mulanje,1,1.488502197,1,1,1,1,1,1,1 +Mwanza,1,1.646695173,1,1,1,1,1,1,1 +Mzimba,1,1.462212301,1,1,1,1,1,1,1 +Mzuzu City,1,1.462212302,1,1,1,1,1,1,1 +Neno,1,2.133827415,1,1,1,1,1,1,1 +Nkhata Bay,1,2.260579269,1,1,1,1,1,1,1 +Nkhotakota,1,1.964020582,1,1,1,1,1,1,1 +Nsanje,1,2.028728453,1,1,1,1,1,1,1 +Ntcheu,1,1.544519745,1,1,1,1,1,1,1 +Ntchisi,1,1.885047321,1,1,1,1,1,1,1 +Phalombe,1,1.561895301,1,1,1,1,1,1,1 +Rumphi,1,1.643251104,1,1,1,1,1,1,1 +Salima,1,1.449091752,1,1,1,1,1,1,1 +Thyolo,1,1.360371849,1,1,1,1,1,1,1 +Zomba,1,1.10442287,1,1,1,1,1,1,1 +Zomba City,1,1.10442287,1,1,1,1,1,1,1 +Zomba Mental Hospital,1,0.85520169,1,1,1,1,1,1,1 diff --git a/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_level_and_officer_type/establishment_staffing_CNP.csv b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_level_and_officer_type/establishment_staffing_CNP.csv new file mode 100644 index 0000000000..4577899122 --- /dev/null +++ b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_level_and_officer_type/establishment_staffing_CNP.csv @@ -0,0 +1,10 @@ +Officer_Category,L0_factor,L1a_factor,L1b_factor,L2_factor,L3_factor,L4_factor,L5_factor +Clinical,1.536932742,1.536932742,1.536932742,1.536932742,1.536932742,1.536932742,1.536932742 +DCSA,1,1,1,1,1,1,1 
+Dental,1,1,1,1,1,1,1 +Laboratory,1,1,1,1,1,1,1 +Mental,1,1,1,1,1,1,1 +Nursing_and_Midwifery,1.455369535,1.455369535,1.455369535,1.455369535,1.455369535,1.455369535,1.455369535 +Nutrition,1,1,1,1,1,1,1 +Pharmacy,1.855698791,1.855698791,1.855698791,1.855698791,1.855698791,1.855698791,1.855698791 +Radiography,1,1,1,1,1,1,1 diff --git a/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_level_and_officer_type/establishment_staffing_N.csv b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_level_and_officer_type/establishment_staffing_N.csv new file mode 100644 index 0000000000..64ca01af65 --- /dev/null +++ b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_level_and_officer_type/establishment_staffing_N.csv @@ -0,0 +1,10 @@ +Officer_Category,L0_factor,L1a_factor,L1b_factor,L2_factor,L3_factor,L4_factor,L5_factor +Clinical,1,1,1,1,1,1,1 +DCSA,1,1,1,1,1,1,1 +Dental,1,1,1,1,1,1,1 +Laboratory,1,1,1,1,1,1,1 +Mental,1,1,1,1,1,1,1 +Nursing_and_Midwifery,1.455369535,1.455369535,1.455369535,1.455369535,1.455369535,1.455369535,1.455369535 +Nutrition,1,1,1,1,1,1,1 +Pharmacy,1,1,1,1,1,1,1 +Radiography,1,1,1,1,1,1,1 diff --git a/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_level_and_officer_type/worse_staffing_N.csv b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_level_and_officer_type/worse_staffing_N.csv new file mode 100644 index 0000000000..be1dc97f7f --- /dev/null +++ b/resources/healthsystem/human_resources/scaling_capabilities/ResourceFile_HR_scaling_by_level_and_officer_type/worse_staffing_N.csv @@ -0,0 +1,10 @@ +Officer_Category,L0_factor,L1a_factor,L1b_factor,L2_factor,L3_factor,L4_factor,L5_factor +Clinical,1,1,1,1,1,1,1 +DCSA,1,1,1,1,1,1,1 +Dental,1,1,1,1,1,1,1 +Laboratory,1,1,1,1,1,1,1 +Mental,1,1,1,1,1,1,1 +Nursing_and_Midwifery,0.85,0.85,0.85,0.85,0.85,0.85,0.85 
+Nutrition,1,1,1,1,1,1,1 +Pharmacy,1,1,1,1,1,1,1 +Radiography,1,1,1,1,1,1,1 diff --git a/src/scripts/nurses_analyses/analysis_nurses_scenario.py b/src/scripts/nurses_analyses/analysis_nurses_scenario.py new file mode 100644 index 0000000000..5ae3146295 --- /dev/null +++ b/src/scripts/nurses_analyses/analysis_nurses_scenario.py @@ -0,0 +1,215 @@ +"""This file uses the results of the results of running `nurse_analyses/nurses_scenario_analyses.py` to make some summary + graphs.""" + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from scripts.nurses_analyses.nurses_scenario_analyses import StaffingScenario +from tlo.analysis.utils import ( + extract_results, + get_scenario_info, + load_pickled_dataframes, + make_age_grp_lookup, + make_age_grp_types, + summarize, +) + + +# Rename draw numbers to scenario names +def set_param_names_as_column_index_level_0(_df, param_names): + """Set column index level 0 (draw numbers) to scenario names.""" + ordered_param_names = {i: x for i, x in enumerate(param_names)} + names_of_cols_level0 = [ + ordered_param_names.get(col) + for col in _df.columns.levels[0] + ] + _df.columns = _df.columns.set_levels(names_of_cols_level0, level=0) + return _df + + +def extract_total_deaths(results_folder): + def extract_deaths_total(df: pd.DataFrame) -> pd.Series: + return pd.Series({"Total": len(df)}) + + return extract_results( + results_folder, + module="tlo.methods.demography", + key="death", + custom_generate_series=extract_deaths_total, + do_scaling=True + ) + + +def plot_summarized_total_deaths(summarized_total_deaths): + fig, ax = plt.subplots() + + scenario_names = summarized_total_deaths.columns.get_level_values(0).unique() + + means = np.array([ + summarized_total_deaths[(s, "mean")].values[0] + for s in scenario_names + ]) + lowers = np.array([ + summarized_total_deaths[(s, "lower")].values[0] + for s in scenario_names + ]) + uppers = np.array([ + 
summarized_total_deaths[(s, "upper")].values[0] + for s in scenario_names + ]) + + ax.bar( + scenario_names, + means, + yerr=[means - lowers, uppers - means], + capsize=5 + ) + + ax.set_ylabel("Total number of deaths") + ax.set_xticklabels(scenario_names, rotation=45, ha="right") + fig.tight_layout() + + return fig, ax + + +def compute_difference_in_deaths_across_runs(total_deaths, scenario_info): + deaths_difference_by_run = [ + total_deaths[0][run_number]["Total"] - total_deaths[1][run_number]["Total"] + for run_number in range(scenario_info["runs_per_draw"]) + ] + return np.mean(deaths_difference_by_run) + + +def extract_deaths_by_age(results_folder): + def extract_deaths_by_age_group(df: pd.DataFrame) -> pd.Series: + _, age_group_lookup = make_age_grp_lookup() + df["Age_Grp"] = df["age"].map(age_group_lookup).astype(make_age_grp_types()) + df = df.rename(columns={"sex": "Sex"}) + return df.groupby(["Age_Grp"])["person_id"].count() + + return extract_results( + results_folder, + module="tlo.methods.demography", + key="death", + custom_generate_series=extract_deaths_by_age_group, + do_scaling=True + ) + + +def plot_summarized_deaths_by_age(deaths_summarized_by_age): + fig, ax = plt.subplots() + + scenario_names = deaths_summarized_by_age.columns.get_level_values(0).unique() + + for i, scenario in enumerate(scenario_names): + central_values = deaths_summarized_by_age[(scenario, "mean")].values + lower_values = deaths_summarized_by_age[(scenario, "lower")].values + upper_values = deaths_summarized_by_age[(scenario, "upper")].values + + ax.plot( + deaths_summarized_by_age.index, + central_values, + label=scenario + ) + + ax.fill_between( + deaths_summarized_by_age.index, + lower_values, + upper_values, + alpha=0.3 + ) + + ax.set(xlabel="Age-Group", ylabel="Total deaths") + ax.set_xticks(deaths_summarized_by_age.index) + ax.set_xticklabels(deaths_summarized_by_age.index, rotation=90) + ax.legend() + fig.tight_layout() + return fig, ax + + +if __name__ == "__main__": 
+ + parser = argparse.ArgumentParser( + "Analyse scenario results for nurses scenario" + ) + parser.add_argument( + "--scenario-outputs-folder", + type=Path, + required=True, + help="Path to folder containing scenario outputs", + ) + parser.add_argument( + "--show-figures", + action="store_true", + help="Whether to interactively show figures", + ) + parser.add_argument( + "--save-figures", + action="store_true", + help="Whether to save figures to results folder", + ) + args = parser.parse_args() + + # results_folder = args.scenario_outputs_folder + + results_folder = Path( + './outputs/wamulwafu@kuhes.ac.mw/nurses_scenario_outputs-2026-04-20T111238Z' + ) + + # Load log (optional, but useful) + log = load_pickled_dataframes(results_folder) + + scenario_info = get_scenario_info(results_folder) + + # Get scenario names directly from Scenario class + param_names = tuple(StaffingScenario()._scenarios.keys()) + + # Keep only scenarios with Default Healthsystem Function + default_hs_scenarios = [ + "Baseline Nurses / Default Healthsystem Function", + "Fewer Nurses / Default Healthsystem Function", + "More Nurses / Default Healthsystem Function", + ] + + # Total deaths + total_deaths = extract_total_deaths(results_folder).pipe( + set_param_names_as_column_index_level_0, + param_names=param_names + ) + + summarized_total_deaths = summarize(total_deaths) + + # Filter to Default Healthsystem Function scenarios only + summarized_total_deaths = summarized_total_deaths.loc[ + :, + summarized_total_deaths.columns.get_level_values(0).isin(default_hs_scenarios) + ] + + fig_1, ax_1 = plot_summarized_total_deaths(summarized_total_deaths) + + # Deaths by age + deaths_by_age = extract_deaths_by_age(results_folder).pipe( + set_param_names_as_column_index_level_0, + param_names=param_names + ) + + summarized_deaths_by_age = summarize(deaths_by_age) + + # Filter to Default Healthsystem Function scenarios only + summarized_deaths_by_age = summarized_deaths_by_age.loc[ + :, + 
summarized_deaths_by_age.columns.get_level_values(0).isin(default_hs_scenarios) + ] + + fig_2, ax_2 = plot_summarized_deaths_by_age(summarized_deaths_by_age) + + if args.show_figures: + plt.show() + + if args.save_figures: + fig_1.savefig(results_folder / "total_deaths_across_scenarios.pdf") + fig_2.savefig(results_folder / "deaths_by_age_across_scenarios.pdf") diff --git a/src/scripts/nurses_analyses/analysis_nurses_scenario_dalys.py b/src/scripts/nurses_analyses/analysis_nurses_scenario_dalys.py new file mode 100644 index 0000000000..48ca6c3b3b --- /dev/null +++ b/src/scripts/nurses_analyses/analysis_nurses_scenario_dalys.py @@ -0,0 +1,1634 @@ +"""Plot DALYs and Deaths across nurse staffing scenarios. + +This script produces two figures for the Nurse Shortages analysis: + +""" + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from scripts.nurses_analyses.nurses_scenario_analyses import StaffingScenario +from tlo.analysis.utils import extract_results, load_pickled_dataframes, summarize + + +def find_difference_relative_to_comparison_series( + _ser: pd.Series, + comparison: str, + scaled: bool = False, + drop_comparison: bool = True, +): + return ( + _ser + .unstack(level=0) + .apply( + lambda x: ( + (x - x[comparison]) / + (x[comparison] if scaled else 1.0) + ), + axis=1, + ) + .drop( + columns=([comparison] if drop_comparison else []) + ) + .stack() + ) + + +def find_difference_relative_to_comparison_series_dataframe( + _df: pd.DataFrame, + **kwargs, +): + return pd.concat( + { + idx: find_difference_relative_to_comparison_series( + row, + **kwargs, + ) + for idx, row in _df.iterrows() + }, + axis=1, + ).T + + +def set_param_names_as_column_index_level_0(_df, param_names): + """Set column index level 0 (draw numbers) to scenario names.""" + ordered_param_names = {i: x for i, x in enumerate(param_names)} + names_of_cols_level0 = [ + ordered_param_names.get(col) + for col in 
_df.columns.levels[0] + ] + _df.columns = _df.columns.set_levels(names_of_cols_level0, level=0) + return _df + + +def extract_annual_dalys(results_folder): + def get_num_dalys_yearly(df: pd.DataFrame) -> pd.Series: + """Return total DALYs for each year.""" + # Sum all cause columns after removing metadata columns + yearly = ( + df.drop(columns=["date", "sex", "age_range"], errors="ignore") + .groupby("year") + .sum() + .sum(axis=1) + ) + return yearly + + return extract_results( + results_folder, + module="tlo.methods.healthburden", + key="dalys_stacked", + custom_generate_series=get_num_dalys_yearly, + do_scaling=True, + ) + + +# Extract annual Deaths +def extract_annual_deaths(results_folder): + def get_num_deaths_yearly(df: pd.DataFrame) -> pd.Series: + """Return total deaths for each year.""" + yearly = ( + df.assign(year=df["date"].dt.year) + .groupby("year")["person_id"] + .count() + ) + return yearly + + return extract_results( + results_folder, + module="tlo.methods.demography", + key="death", + custom_generate_series=get_num_deaths_yearly, + do_scaling=True, + ) + + +# Plot: Annual DALYs over time +def plot_annual_dalys(summarized_annual_dalys): + fig, ax = plt.subplots(figsize=(10, 6)) + + scenario_names = summarized_annual_dalys.columns.get_level_values(0).unique() + + # Short labels for legend + label_map = { + "Baseline Nurses / Default Healthsystem Function": "Baseline", + "Fewer Nurses / Default Healthsystem Function": "Fewer nurses", + "More Nurses / Default Healthsystem Function": "More nurses", + + "Baseline Nurses / Improved Healthsystem Function": "Baseline", + "Fewer Nurses / Improved Healthsystem Function": "Fewer nurses", + "More Nurses / Improved Healthsystem Function": "More nurses", + } + + for scenario in scenario_names: + years = summarized_annual_dalys.index.astype(int) + means = summarized_annual_dalys[(scenario, "mean")].values + lowers = summarized_annual_dalys[(scenario, "lower")].values + uppers = summarized_annual_dalys[(scenario, 
"upper")].values + + print(means.min(), means.max()) + + ax.plot( + years, + means, + linewidth=2, + label=label_map.get(scenario, scenario), + ) + + ax.fill_between( + years, + lowers, + uppers, + alpha=0.2, + ) + + ax.set_xlabel("Year") + ax.set_ylabel("Annual DALYs") + ax.legend() + ax.grid(alpha=0.3) + ax.set_xlim(2025, 2034) + ax.set_ylim(bottom=8e6) + # ax.set_ylim(bottom=0.8) + fig.tight_layout() + + return fig, ax + + +# Plot: Annual Deaths over time +def plot_annual_deaths(summarized_annual_deaths): + fig, ax = plt.subplots(figsize=(10, 6)) + + scenario_names = ( + summarized_annual_deaths.columns + .get_level_values(0) + .unique() + ) + + label_map = { + "Baseline Nurses / Default Healthsystem Function": "Baseline", + "Fewer Nurses / Default Healthsystem Function": "Fewer nurses", + "More Nurses / Default Healthsystem Function": "More nurses", + + "Baseline Nurses / Improved Healthsystem Function": "Baseline", + "Fewer Nurses / Improved Healthsystem Function": "Fewer nurses", + "More Nurses / Improved Healthsystem Function": "More nurses", + } + + for scenario in scenario_names: + years = summarized_annual_deaths.index.astype(int) + means = summarized_annual_deaths[(scenario, "mean")].values + lowers = summarized_annual_deaths[(scenario, "lower")].values + + uppers = summarized_annual_deaths[(scenario, "upper")].values + + ax.plot( + years, + means, + linewidth=2, + label=label_map.get(scenario, scenario), + ) + + ax.fill_between( + years, + lowers, + uppers, + alpha=0.2, + ) + + ax.set_xlabel("Year") + ax.set_ylabel("Annual deaths") + ax.legend() + ax.grid(alpha=0.3) + ax.set_xlim(2025, 2034) + fig.tight_layout() + return fig, ax + + +# Extract deaths by cause +def extract_deaths_by_cause(results_folder): + def get_deaths_by_cause(df: pd.DataFrame) -> pd.Series: + """ + Return deaths by cause aggregated across 2027–2034. 
+ """ + # Add year + df = df.assign(year=df["date"].dt.year) + # Restrict years + df = df[df["year"].between(2027, 2034)] + # Changed to "label" in order to capture group causes + # cause_col = "cause" + cause_col = "label" + deaths_by_cause = (df.groupby(cause_col)["person_id"].count()) + return deaths_by_cause + + return extract_results( + results_folder, + module="tlo.methods.demography", + key="death", + custom_generate_series=get_deaths_by_cause, + do_scaling=True, + ) + + +# Extract deaths by age group +def extract_deaths_by_age_group(results_folder): + + def get_deaths_by_age_group(df: pd.DataFrame) -> pd.Series: + """ + Return deaths by age group aggregated across 2027–2034. + """ + df = df.assign(year=df["date"].dt.year) + df = df[df["year"].between(2027, 2034)] + + # Create age groups + age_bins = [ + 0, 5, 10, 15, 20, 25, 30, 35, + 40, 45, 50, 55, 60, 65, 70, + 75, 80, np.inf + ] + + age_labels = [ + "0-4", + "5-9", + "10-14", + "15-19", + "20-24", + "25-29", + "30-34", + "35-39", + "40-44", + "45-49", + "50-54", + "55-59", + "60-64", + "65-69", + "70-74", + "75-79", + "80+", + ] + + df["age_group"] = pd.cut( + df["age"], + bins=age_bins, + labels=age_labels, + right=False, + ) + # Aggregate deaths by age group + deaths_by_age = (df.groupby("age_group")["person_id"].count()) + return deaths_by_age + + return extract_results( + results_folder, + module="tlo.methods.demography", + key="death", + custom_generate_series=get_deaths_by_age_group, + do_scaling=True, + ) + + +# Extract DALYs by cause +def extract_dalys_by_cause(results_folder): + def get_dalys_by_cause(df: pd.DataFrame) -> pd.Series: + """ + Return DALYs by cause aggregated across 2027–2034. 
+ """ + df = df.assign(year=df["date"].dt.year) + df = df[df["year"].between(2027, 2034)] + # Removing metadata columns + metadata_cols = ["date", "sex", "age_range", "year",] + cause_cols = [c for c in df.columns if c not in metadata_cols] + # Sum DALYs for each cause + return df[cause_cols].sum() + + return extract_results( + results_folder, + module="tlo.methods.healthburden", + key="dalys_stacked", + custom_generate_series=get_dalys_by_cause, + do_scaling=True, + ) + + +# Extract DALYs by age group +def extract_dalys_by_age_group(results_folder): + + def get_dalys_by_age_group(df: pd.DataFrame) -> pd.Series: + """ + Return DALYs by age group aggregated across 2027–2034. + """ + df = df.assign(year=df["date"].dt.year) + df = df[df["year"].between(2027, 2034)] + + # Metadata columns to exclude + metadata_cols = ["date", "sex", "age_range", "year",] + # DALY cause columns + cause_cols = [c for c in df.columns if c not in metadata_cols] + + # Sum DALYs across causes first + df["total_dalys"] = df[cause_cols].sum(axis=1) + # Aggregating by age group + dalys_by_age = ( + df.groupby("age_range")["total_dalys"] + .sum() + ) + return dalys_by_age + + return extract_results( + results_folder, + module="tlo.methods.healthburden", + key="dalys_stacked", + custom_generate_series=get_dalys_by_age_group, + do_scaling=True, + ) + + +# Plot: Percent DALYs averted relative to baseline (2027–2034) +def calculate_percent_dalys_averted( + annual_dalys, + baseline_scenario, + comparison_years=range(2027, 2035), +): + """ + Calculate % DALYs averted using run-to-run differences. 
+ """ + years = annual_dalys.index.astype(int) + year_mask = np.isin(years, list(comparison_years)) + + annual_dalys = annual_dalys.loc[year_mask] + annual_dalys_agg = annual_dalys.sum(axis=0) + + pct_diff = pd.DataFrame( + -100.0 + * find_difference_relative_to_comparison_series( + annual_dalys_agg, + comparison=baseline_scenario, + scaled=True, + ) + ).T + + summarized = summarize(pct_diff) + results = {} + + scenario_names = ( + summarized.columns + .get_level_values(0) + .unique() + ) + + for scenario in scenario_names: + results[scenario] = { + "mean": summarized[(scenario, "mean")].iloc[0], + "lower": summarized[(scenario, "lower")].iloc[0], + "upper": summarized[(scenario, "upper")].iloc[0], + } + + return pd.DataFrame(results).T + + +def calculate_percent_deaths_averted( + annual_deaths, + baseline_scenario, + comparison_years=range(2027, 2035), +): + """ + Calculate % deaths averted using run-to-run differences. + """ + years = annual_deaths.index.astype(int) + year_mask = np.isin(years, list(comparison_years)) + + annual_deaths = annual_deaths.loc[year_mask] + annual_deaths_agg = annual_deaths.sum(axis=0) + + pct_diff = pd.DataFrame( + -100.0 + * find_difference_relative_to_comparison_series( + annual_deaths_agg, + comparison=baseline_scenario, + scaled=True, + ) + ).T + + summarized = summarize(pct_diff) + results = {} + + scenario_names = ( + summarized.columns + .get_level_values(0) + .unique() + ) + + for scenario in scenario_names: + results[scenario] = { + "mean": summarized[(scenario, "mean")].iloc[0], + "lower": summarized[(scenario, "lower")].iloc[0], + "upper": summarized[(scenario, "upper")].iloc[0], + } + + return pd.DataFrame(results).T + + +# Calculate % deaths averted by cause +def calculate_percent_deaths_averted_by_cause( + deaths_by_cause, + baseline_scenario, +): + + pct_diff = ( + -100.0 + * find_difference_relative_to_comparison_series_dataframe( + deaths_by_cause, + comparison=baseline_scenario, + scaled=True, + ) + ) + + summarized 
= summarize(pct_diff) + results = {} + scenario_names = (summarized.columns.get_level_values(0).unique()) + + for scenario in scenario_names: + results[scenario] = pd.DataFrame({ + "mean": summarized[(scenario, "mean")], + "lower": summarized[(scenario, "lower")], + "upper": summarized[(scenario, "upper")], + }) + + return results + + +def calculate_percent_dalys_averted_by_cause( + dalys_by_cause, + baseline_scenario, +): + + pct_diff = ( + -100.0 + * find_difference_relative_to_comparison_series_dataframe( + dalys_by_cause, + comparison=baseline_scenario, + scaled=True, + ) + ) + + summarized = summarize(pct_diff) + results = {} + scenario_names = (summarized.columns.get_level_values(0).unique()) + + for scenario in scenario_names: + results[scenario] = pd.DataFrame({ + "mean": summarized[(scenario, "mean")], + "lower": summarized[(scenario, "lower")], + "upper": summarized[(scenario, "upper")], + }) + + return results + + +# Calculate % DALYs averted by age group +def calculate_percent_dalys_averted_by_age_group( + dalys_by_age_group, + baseline_scenario, +): + """ + Run-level comparison first, + then summarize. 
+ """ + + pct_diff = ( + -100.0 + * find_difference_relative_to_comparison_series_dataframe( + dalys_by_age_group, + comparison=baseline_scenario, + scaled=True, + ) + ) + + summarized = summarize(pct_diff) + results = {} + scenario_names = (summarized.columns.get_level_values(0).unique()) + + for scenario in scenario_names: + results[scenario] = pd.DataFrame({ + "mean": summarized[(scenario, "mean")], + "lower": summarized[(scenario, "lower")], + "upper": summarized[(scenario, "upper")], + }) + + return results + + +# Calculate % deaths averted by age group +def calculate_percent_deaths_averted_by_age_group( + deaths_by_age_group, + baseline_scenario, +): + + pct_diff = ( + -100.0 + * find_difference_relative_to_comparison_series_dataframe( + deaths_by_age_group, + comparison=baseline_scenario, + scaled=True, + ) + ) + + summarized = summarize(pct_diff) + results = {} + scenario_names = (summarized.columns.get_level_values(0).unique()) + + for scenario in scenario_names: + results[scenario] = pd.DataFrame({ + "mean": summarized[(scenario, "mean")], + "lower": summarized[(scenario, "lower")], + "upper": summarized[(scenario, "upper")], + }) + + return results + + +def plot_percent_dalys_averted_comparison(default_df, improved_df,): + fig, axes = plt.subplots(ncols=2, figsize=(12, 6), sharey=True,) + + panel_data = [ + (axes[0], default_df, "Default Healthsystem"), + (axes[1], improved_df, "Improved Healthsystem"), + ] + + for ax, df, title in panel_data: + ordered_scenarios = [ + s for s in df.index + if "More Nurses" in s + ] + [ + s for s in df.index + if "Fewer Nurses" in s + ] + + labels = ["More nurses" if "More Nurses" in s else "Fewer nurses" for s in ordered_scenarios] + + means = df.loc[ordered_scenarios, "mean"].values + lowers = df.loc[ordered_scenarios, "lower"].values + uppers = df.loc[ordered_scenarios, "upper"].values + + yerr = np.vstack([ + means - lowers, + uppers - means, + ]) + + colors = ["steelblue" if "More Nurses" in s else "indianred" for s 
in ordered_scenarios] + + ax.bar(labels, means, yerr=yerr, capsize=6, color=colors, width=0.55,) + ax.axhline(0, color="black", linewidth=1,) + ax.set_title(title) + ax.grid(axis="y",alpha=0.3,) + + axes[0].set_ylabel( + "% DALYs averted compared to Baseline\n" + "(total between 2027 and 2034)" + ) + + fig.suptitle( + "% DALYs averted relative to baseline (2027–2034)", + fontsize=14, + ) + fig.tight_layout() + return fig, axes + + +def plot_percent_deaths_averted_comparison(default_df,improved_df,): + fig, axes = plt.subplots(ncols=2, figsize=(12, 6), sharey=True,) + + panel_data = [ + (axes[0], default_df, "Default Healthsystem"), + (axes[1], improved_df, "Improved Healthsystem"), + ] + + for ax, df, title in panel_data: + ordered_scenarios = [ + s for s in df.index + if "More Nurses" in s + ] + [ + s for s in df.index + if "Fewer Nurses" in s + ] + + labels = ["More nurses" if "More Nurses" in s else "Fewer nurses" for s in ordered_scenarios] + + means = df.loc[ordered_scenarios, "mean"].values + lowers = df.loc[ordered_scenarios, "lower"].values + uppers = df.loc[ordered_scenarios, "upper"].values + + yerr = np.vstack([ + means - lowers, + uppers - means, + ]) + + colors = ["steelblue" if "More Nurses" in s else "indianred" for s in ordered_scenarios] + + ax.bar(labels, means, yerr=yerr, capsize=6, color=colors, width=0.55,) + ax.axhline(0, color="black", linewidth=1,) + ax.set_title(title) + ax.grid(axis="y", alpha=0.3,) + + axes[0].set_ylabel( + "% deaths averted compared to Baseline\n" + "(total between 2027 and 2034)" + ) + + fig.suptitle( + "% deaths averted relative to baseline (2027–2034)", + fontsize=14, + ) + fig.tight_layout() + return fig, axes + + +# Plot % DALYs averted by cause +def plot_percent_dalys_averted_by_cause(default_df, improved_df, top_n=30): + + # Extracting scenario dataframes + default_more = default_df[ + "More Nurses / Default Healthsystem Function" + ] + + default_fewer = default_df[ + "Fewer Nurses / Default Healthsystem Function" 
+ ] + + improved_more = improved_df[ + "More Nurses / Improved Healthsystem Function" + ] + + improved_fewer = improved_df[ + "Fewer Nurses / Improved Healthsystem Function" + ] + + # Using sum + # total_dalys = ( + # dalys_by_cause + # .xs(baseline_scenario, level="draw", axis=1) + # .sum(axis=1) + # .sort_values(ascending=False) + # ) + # + # top_causes = total_dalys.head(10).index.tolist() + # + # default_more = default_more.loc[top_causes] + # default_fewer = default_fewer.loc[top_causes] + # + # improved_more = improved_more.loc[top_causes] + # improved_fewer = improved_fewer.loc[top_causes] + # + # # Reverse so largest appears at top + # default_more = default_more.iloc[::-1] + # default_fewer = default_fewer.iloc[::-1] + # + # improved_more = improved_more.iloc[::-1] + # improved_fewer = improved_fewer.iloc[::-1] + + # Top causes for DEFAULT healthsystem + default_top = ( + default_more["mean"] + .abs() + .sort_values(ascending=False) + .head(top_n) + .index + ) + + # default_more = ( + # default_more.loc[default_top] + # .sort_values("mean", ascending=True) + # ) + default_more = default_more.reindex(cause_order) + + # default_fewer = ( + # default_fewer.loc[default_top] + # .reindex(default_more.index) + # ) + default_fewer = default_fewer.reindex(cause_order) + + # Top causes for IMPROVED healthsystem + improved_top = ( + improved_more["mean"] + .abs() + .sort_values(ascending=False) + .head(top_n) + .index + ) + + # improved_more = ( + # improved_more.loc[improved_top] + # .sort_values("mean", ascending=True) + # ) + improved_more = improved_more.reindex(cause_order) + + # improved_fewer = ( + # improved_fewer.loc[improved_top] + # .reindex(improved_more.index) + # ) + improved_fewer = improved_fewer.reindex(cause_order) + + # Plot + fig, axes = plt.subplots(ncols=2, figsize=(14, 10), sharey=True) + + panel_data = [ + ( + axes[0], + default_more, + default_fewer, + "Default Healthsystem", + ), + ( + axes[1], + improved_more, + improved_fewer, + "Improved 
Healthsystem", + ), + ] + + for ax, more, fewer, title in panel_data: + y = np.arange(len(more)) + ax.barh(y - 0.2, more["mean"], height=0.35, color="steelblue", label="More nurses",) + + ax.barh(y + 0.2, fewer["mean"], height=0.35, color="indianred", label="Fewer nurses",) + + # CI bars: More nurses + ax.errorbar( + more["mean"], + y - 0.2, + xerr=[ + more["mean"] - more["lower"], + more["upper"] - more["mean"], + ], + fmt="none", + capsize=2, + color="black", + alpha=0.5, + ) + + # CI bars: Fewer nurses + ax.errorbar( + fewer["mean"], + y + 0.2, + xerr=[ + fewer["mean"] - fewer["lower"], + fewer["upper"] - fewer["mean"], + ], + fmt="none", + capsize=2, + color="black", + alpha=0.5, + ) + + ax.axvline(0, color="black", linewidth=1) + ax.set_yticks(y) + ax.set_yticklabels(more.index) + ax.set_xlabel("% DALYs averted") + ax.set_title(title) + ax.grid(axis="x", alpha=0.3) + + handles, labels = axes[0].get_legend_handles_labels() + + fig.legend(handles, labels, loc="lower center", ncol=2, bbox_to_anchor=(0.5, -0.02),) + + fig.suptitle( + "% DALYs averted by causes on national level\n(2027–2034)" + ) + fig.tight_layout() + return fig, axes + + +# Plot % deaths averted by cause +def plot_percent_deaths_averted_by_cause(default_df, improved_df, top_n=30): + + # Extracting scenario dataframes + default_more = default_df[ + "More Nurses / Default Healthsystem Function" + ] + + default_fewer = default_df[ + "Fewer Nurses / Default Healthsystem Function" + ] + + improved_more = improved_df[ + "More Nurses / Improved Healthsystem Function" + ] + + improved_fewer = improved_df[ + "Fewer Nurses / Improved Healthsystem Function" + ] + + # Top causes for DEFAULT healthsystem + default_top = ( + default_more["mean"] + .abs() + .sort_values(ascending=False) + .head(top_n) + .index + ) + + # default_more = ( + # default_more.loc[default_top] + # .sort_values("mean", ascending=True) + # ) + default_more = default_more.reindex(death_order) + + # default_fewer = ( + # 
default_fewer.loc[default_top] + # .reindex(default_more.index) + # ) + default_fewer = default_fewer.reindex(death_order) + + # Top causes for IMPROVED healthsystem + improved_top = ( + improved_more["mean"] + .abs() + .sort_values(ascending=False) + .head(top_n) + .index + ) + + # improved_more = ( + # improved_more.loc[improved_top] + # .sort_values("mean", ascending=True) + # ) + improved_more = improved_more.reindex(death_order) + + # improved_fewer = ( + # improved_fewer.loc[improved_top] + # .reindex(improved_more.index) + # ) + improved_fewer = improved_fewer.reindex(death_order) + + # Plot + fig, axes = plt.subplots(ncols=2, figsize=(14, 10), sharey=True) + + panel_data = [ + ( + axes[0], + default_more, + default_fewer, + "Default Healthsystem", + ), + ( + axes[1], + improved_more, + improved_fewer, + "Improved Healthsystem", + ), + ] + + for ax, more, fewer, title in panel_data: + y = np.arange(len(more)) + + ax.barh(y - 0.2, more["mean"], height=0.35, color="steelblue", label="More nurses",) + ax.barh(y + 0.2, fewer["mean"], height=0.35, color="indianred", label="Fewer nurses",) + + ax.errorbar( + more["mean"], + y - 0.2, + xerr=[ + more["mean"] - more["lower"], + more["upper"] - more["mean"], + ], + fmt="none", + capsize=2, + color="black", + alpha=0.5, + ) + + ax.errorbar( + fewer["mean"], + y + 0.2, + xerr=[ + fewer["mean"] - fewer["lower"], + fewer["upper"] - fewer["mean"], + ], + fmt="none", + capsize=2, + color="black", + alpha=0.5 + ) + + ax.axvline(0, color="black", linewidth=1) + ax.set_yticks(y) + ax.set_yticklabels(more.index) + ax.set_xlabel("% deaths averted") + ax.set_title(title) + ax.grid(axis="x", alpha=0.3) + + handles, labels = axes[0].get_legend_handles_labels() + + fig.legend(handles, labels, loc="lower center", ncol=2, bbox_to_anchor=(0.5, -0.02),) + fig.suptitle( + "% deaths averted by causes on national level\n(2027–2034)" + ) + fig.tight_layout() + return fig, axes + + +# Plot % DALYs averted by age group +def 
plot_percent_dalys_averted_by_age_group(default_df,improved_df,): + + default_more = default_df[ + "More Nurses / Default Healthsystem Function" + ] + + default_fewer = default_df[ + "Fewer Nurses / Default Healthsystem Function" + ] + + improved_more = improved_df[ + "More Nurses / Improved Healthsystem Function" + ] + + improved_fewer = improved_df[ + "Fewer Nurses / Improved Healthsystem Function" + ] + + # Ordering age groups + age_order = [ + "0-4", + "5-9", + "10-14", + "15-19", + "20-24", + "25-29", + "30-34", + "35-39", + "40-44", + "45-49", + "50-54", + "55-59", + "60-64", + "65-69", + "70-74", + "75-79", + "80+", + ] + + for df in [default_more, default_fewer, improved_more, improved_fewer,]: + df = df.reindex(age_order) + + default_more = default_more.reindex(age_order) + default_fewer = default_fewer.reindex(age_order) + + improved_more = improved_more.reindex(age_order) + improved_fewer = improved_fewer.reindex(age_order) + + # Reverse so oldest ages appear at top + default_more = default_more.iloc[::-1] + default_fewer = default_fewer.iloc[::-1] + + improved_more = improved_more.iloc[::-1] + improved_fewer = improved_fewer.iloc[::-1] + + # Plot + fig, axes = plt.subplots(ncols=2, figsize=(14, 8), sharey=True,) + + panel_data = [ + ( + axes[0], + default_more, + default_fewer, + "Default Healthsystem", + ), + ( + axes[1], + improved_more, + improved_fewer, + "Improved Healthsystem", + ), + ] + + for ax, more, fewer, title in panel_data: + y = np.arange(len(more)) + + # More nurses + ax.barh(y - 0.2, more["mean"], height=0.35, color="steelblue", label="More Nurses",) + + # Fewer nurses + ax.barh(y + 0.2, fewer["mean"], height=0.35, color="indianred", label="Fewer Nurses",) + + # CI for More Nurses + ax.errorbar( + more["mean"], + y - 0.2, + xerr=[ + more["mean"] - more["lower"], + more["upper"] - more["mean"], + ], + fmt="none", + color="black", + capsize=3, + ) + + # CI for Fewer Nurses + ax.errorbar( + fewer["mean"], + y + 0.2, + xerr=[ + 
fewer["mean"] - fewer["lower"], + fewer["upper"] - fewer["mean"], + ], + fmt="none", + color="black", + capsize=3, + ) + + ax.axvline(0, color="black") + ax.set_yticks(y) + ax.set_yticklabels(more.index) + ax.set_xlabel("% DALYs averted") + ax.set_title(title) + ax.grid(axis="x", alpha=0.3) + + fig.suptitle( + "% DALYs averted by age group on national level\n(2027–2034)" + ) + + # Add legend + handles, labels = axes[0].get_legend_handles_labels() + + fig.legend(handles, labels, loc="lower center", ncol=2, frameon=False,) + fig.tight_layout() + return fig, axes + + +# Plot % deaths averted by age group +def plot_percent_deaths_averted_by_age_group(default_df,improved_df,): + + default_more = default_df[ + "More Nurses / Default Healthsystem Function" + ] + + default_fewer = default_df[ + "Fewer Nurses / Default Healthsystem Function" + ] + + improved_more = improved_df[ + "More Nurses / Improved Healthsystem Function" + ] + + improved_fewer = improved_df[ + "Fewer Nurses / Improved Healthsystem Function" + ] + + age_order = [ + "0-4", + "5-9", + "10-14", + "15-19", + "20-24", + "25-29", + "30-34", + "35-39", + "40-44", + "45-49", + "50-54", + "55-59", + "60-64", + "65-69", + "70-74", + "75-79", + "80+", + ] + + default_more = default_more.reindex(age_order) + default_fewer = default_fewer.reindex(age_order) + + improved_more = improved_more.reindex(age_order) + improved_fewer = improved_fewer.reindex(age_order) + + # Reverse so oldest age groups appear at top + default_more = default_more.iloc[::-1] + default_fewer = default_fewer.iloc[::-1] + + improved_more = improved_more.iloc[::-1] + improved_fewer = improved_fewer.iloc[::-1] + + fig, axes = plt.subplots(ncols=2, figsize=(14, 8), sharey=True,) + + panel_data = [ + ( + axes[0], + default_more, + default_fewer, + "Default Healthsystem", + ), + ( + axes[1], + improved_more, + improved_fewer, + "Improved Healthsystem", + ), + ] + + for ax, more, fewer, title in panel_data: + y = np.arange(len(more)) + + ax.barh(y - 
0.2, more["mean"], height=0.35, color="steelblue", label="More nurses",) + ax.barh(y + 0.2, fewer["mean"], height=0.35, color="indianred", label="Fewer nurses",) + + # More nurses CI + ax.errorbar( + more["mean"], + y - 0.2, + xerr=[ + more["mean"] - more["lower"], + more["upper"] - more["mean"], + ], + fmt="none", + capsize=4, + color="black", + ) + + # Fewer nurses CI + ax.errorbar( + fewer["mean"], + y + 0.2, + xerr=[ + fewer["mean"] - fewer["lower"], + fewer["upper"] - fewer["mean"], + ], + fmt="none", + capsize=4, + color="black", + ) + + ax.axvline(0, color="black", linewidth=1) + ax.set_yticks(y) + ax.set_yticklabels(more.index) + ax.set_xlabel("% deaths averted") + ax.set_title(title) + ax.grid(axis="x", alpha=0.3) + + handles, labels = axes[0].get_legend_handles_labels() + fig.legend(handles, labels, loc="lower center", ncol=2, bbox_to_anchor=(0.5, -0.02),) + fig.suptitle( + "% deaths averted by age group on national level\n(2027–2034)" + ) + fig.tight_layout() + return fig, axes + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + "Analyse DALYs/Deaths across nurse staffing scenarios" + ) + parser.add_argument( + "--scenario-outputs-folder", + type=Path, + required=True, + help="Path to folder containing scenario outputs", + ) + parser.add_argument( + "--show-figures", + action="store_true", + help="Whether to interactively show figures", + ) + parser.add_argument( + "--save-figures", + action="store_true", + help="Whether to save figures to results folder", + ) + args = parser.parse_args() + + # Use command-line folder + results_folder = args.scenario_outputs_folder + + # Optional: load logs + log = load_pickled_dataframes(results_folder) + + # Getting scenario names from scenario class + param_names = tuple(StaffingScenario()._scenarios.keys()) + + # Scnarios to keep (Default Healthsystem Function only) + default_hs_scenarios = [ + "Baseline Nurses / Default Healthsystem Function", + "Fewer Nurses / Default Healthsystem Function", + 
"More Nurses / Default Healthsystem Function", + ] + + baseline_scenario = "Baseline Nurses / Default Healthsystem Function" + + improved_hs_scenarios = [ + "Baseline Nurses / Improved Healthsystem Function", + "Fewer Nurses / Improved Healthsystem Function", + "More Nurses / Improved Healthsystem Function", + ] + + baseline_improved_scenario = ("Baseline Nurses / Improved Healthsystem Function") + + # Extract annual DALYs + annual_dalys = extract_annual_dalys(results_folder).pipe( + set_param_names_as_column_index_level_0, + param_names=param_names, + ) + + # Summarize across runs + # Filter to Default Healthsystem Function scenarios only + summarized_annual_dalys = summarize(annual_dalys) + + # Filter to Default Healthsystem Function scenarios only + summarized_annual_dalys_default = summarized_annual_dalys.loc[ + :, + summarized_annual_dalys.columns.get_level_values(0).isin( + default_hs_scenarios + ), + ] + + # Filter to Improved Healthsystem Function scenarios only + summarized_annual_dalys_improved = summarized_annual_dalys.loc[ + :, + summarized_annual_dalys.columns.get_level_values(0).isin( + improved_hs_scenarios + ), + ] + + # Plot 1: Annual DALYs over time + fig_1, ax_1 = plot_annual_dalys(summarized_annual_dalys_default) + + # Plot 2: Percent DALYs averted relative to baseline (2027–2034) + percent_dalys_averted = calculate_percent_dalys_averted( + annual_dalys.loc[ + :, + annual_dalys.columns.get_level_values(0).isin(default_hs_scenarios) + ], + baseline_scenario=baseline_scenario, + comparison_years=range(2027, 2035), + ) + + percent_dalys_averted_improved = calculate_percent_dalys_averted( + annual_dalys.loc[ + :, + annual_dalys.columns.get_level_values(0).isin(improved_hs_scenarios) + ], + baseline_scenario=baseline_improved_scenario, + comparison_years=range(2027, 2035), + ) + + fig_2, ax_2 = plot_percent_dalys_averted_comparison( + percent_dalys_averted, + percent_dalys_averted_improved, + ) + + # Sensitivity analysis: DALYs under Improved 
Healthsystem Function + fig_5, ax_5 = plot_annual_dalys( + summarized_annual_dalys_improved + ) + + # Extract annual deaths + annual_deaths = extract_annual_deaths(results_folder).pipe( + set_param_names_as_column_index_level_0, + param_names=param_names, + ) + + summarized_annual_deaths = summarize(annual_deaths) + + # Default Healthsystem Function deaths + summarized_annual_deaths_default = summarized_annual_deaths.loc[ + :, + summarized_annual_deaths.columns.get_level_values(0).isin( + default_hs_scenarios + ), + ] + + # Improved Healthsystem Function deaths + summarized_annual_deaths_improved = summarized_annual_deaths.loc[ + :, + summarized_annual_deaths.columns.get_level_values(0).isin( + improved_hs_scenarios + ), + ] + + # Plot annual deaths + fig_3, ax_3 = plot_annual_deaths( + summarized_annual_deaths_default + ) + + # Plot % deaths averted + percent_deaths_averted = calculate_percent_deaths_averted( + annual_deaths.loc[ + :, + annual_deaths.columns.get_level_values(0).isin(default_hs_scenarios) + ], + baseline_scenario=baseline_scenario, + comparison_years=range(2027, 2035), + ) + + percent_deaths_averted_improved = calculate_percent_deaths_averted( + annual_deaths.loc[ + :, + annual_deaths.columns.get_level_values(0).isin(improved_hs_scenarios) + ], + baseline_scenario=baseline_improved_scenario, + comparison_years=range(2027, 2035), + ) + + fig_4, ax_4 = plot_percent_deaths_averted_comparison( + percent_deaths_averted, + percent_deaths_averted_improved, + ) + + # Sensitivity analysis: deaths under Improved Healthsystem Function + fig_7, ax_7 = plot_annual_deaths( + summarized_annual_deaths_improved + ) + + # Extract deaths by cause + deaths_by_cause = extract_deaths_by_cause(results_folder).pipe( + set_param_names_as_column_index_level_0, + param_names=param_names, + ) + + # check that total deaths equal to sum of deaths by cause + total_deaths = annual_deaths.loc[ + (annual_deaths.index >= 2027) & (annual_deaths.index <= 2034) + ].sum(axis=0) + 
total_deaths_cause = deaths_by_cause.sum(axis=0) + assert (total_deaths.index == total_deaths_cause.index).all() + assert (abs(total_deaths.values - total_deaths_cause.values) < 1e-7).all() + + # find the descending order of causes in terms of total deaths in baseline scenario + mean_deaths_by_cause = deaths_by_cause.groupby(axis=1, level="draw").mean().sort_values( + by="Baseline Nurses / Default Healthsystem Function", + ascending=True, + ) + death_order = mean_deaths_by_cause.index.tolist() + + deaths_by_cause_default = ( + deaths_by_cause.loc[ + :, + deaths_by_cause.columns + .get_level_values(0) + .isin(default_hs_scenarios) + ] + ) + + percent_deaths_by_cause_default = ( + calculate_percent_deaths_averted_by_cause( + deaths_by_cause_default, + baseline_scenario=baseline_scenario, + ) + ) + + deaths_by_cause_improved = ( + deaths_by_cause.loc[ + :, + deaths_by_cause.columns + .get_level_values(0) + .isin(improved_hs_scenarios) + ] + ) + + percent_deaths_by_cause_improved = ( + calculate_percent_deaths_averted_by_cause( + deaths_by_cause_improved, + baseline_scenario=baseline_improved_scenario, + ) + ) + + fig_10, ax_10 = plot_percent_deaths_averted_by_cause( + percent_deaths_by_cause_default, + percent_deaths_by_cause_improved, + top_n=30, + ) + + # Extract deaths by age group + deaths_by_age_group = extract_deaths_by_age_group( + results_folder + ).pipe( + set_param_names_as_column_index_level_0, + param_names=param_names, + ) + + # check that total deaths equal to sum of deaths by age group + total_deaths_age = deaths_by_age_group.sum(axis=0) + assert (total_deaths.index == total_deaths_age.index).all() + assert (abs(total_deaths.values - total_deaths_age.values) < 1e-7).all() + + deaths_by_age_group_default = ( + deaths_by_age_group.loc[ + :, + deaths_by_age_group.columns + .get_level_values(0) + .isin(default_hs_scenarios) + ] + ) + + percent_deaths_by_age_default = ( + calculate_percent_deaths_averted_by_age_group( + deaths_by_age_group_default, + 
baseline_scenario=baseline_scenario, + ) + ) + + deaths_by_age_group_improved = ( + deaths_by_age_group.loc[ + :, + deaths_by_age_group.columns + .get_level_values(0) + .isin(improved_hs_scenarios) + ] + ) + + percent_deaths_by_age_improved = ( + calculate_percent_deaths_averted_by_age_group( + deaths_by_age_group_improved, + baseline_scenario=baseline_improved_scenario, + ) + ) + + fig_12, ax_12 = plot_percent_deaths_averted_by_age_group( + percent_deaths_by_age_default, + percent_deaths_by_age_improved, + ) + + # Extract DALYs by cause + dalys_by_cause = extract_dalys_by_cause(results_folder).pipe( + set_param_names_as_column_index_level_0, + param_names=param_names, + ) + + # check that total dalys equal to sum of dalys by cause + total_dalys = annual_dalys.loc[ + (annual_dalys.index >= 2027) & (annual_dalys.index <= 2034) + ].sum(axis=0) + total_dalys_cause = dalys_by_cause.sum(axis=0) + assert (total_dalys.index == total_dalys_cause.index).all() + assert (abs(total_dalys.values - total_dalys_cause.values) < 1e-7).all() + + # find the descending order of causes in terms of total dalys in baseline scenario + mean_dalys_by_cause = dalys_by_cause.groupby(axis=1, level="draw").mean().sort_values( + by="Baseline Nurses / Default Healthsystem Function", + ascending=True, + ) + cause_order = mean_dalys_by_cause.index.tolist() + + # Default Healthsystem + dalys_by_cause_default = ( + dalys_by_cause.loc[ + :, + dalys_by_cause.columns + .get_level_values(0) + .isin(default_hs_scenarios) + ] + ) + + percent_by_cause_default = ( + calculate_percent_dalys_averted_by_cause( + dalys_by_cause_default, + baseline_scenario=baseline_scenario, + ) + ) + + # Improved Healthsystem + dalys_by_cause_improved = ( + dalys_by_cause.loc[ + :, + dalys_by_cause.columns + .get_level_values(0) + .isin(improved_hs_scenarios) + ] + ) + + percent_by_cause_improved = ( + calculate_percent_dalys_averted_by_cause( + dalys_by_cause_improved, + baseline_scenario=baseline_improved_scenario, + ) + ) + 
+ fig_9, ax_9 = plot_percent_dalys_averted_by_cause( + percent_by_cause_default, + percent_by_cause_improved, + top_n=30, + ) + + # Extract DALYs by age group + dalys_by_age_group = extract_dalys_by_age_group( + results_folder + ).pipe( + set_param_names_as_column_index_level_0, + param_names=param_names, + ) + + # check that total dalys equal to sum of dalys by age groups + total_dalys_age = dalys_by_age_group.sum(axis=0) + assert (total_dalys.index == total_dalys_age.index).all() + assert (abs(total_dalys.values - total_dalys_age.values) < 1e-7).all() + + dalys_by_age_group_default = ( + dalys_by_age_group.loc[ + :, + dalys_by_age_group.columns + .get_level_values(0) + .isin(default_hs_scenarios) + ] + ) + + percent_dalys_by_age_default = ( + calculate_percent_dalys_averted_by_age_group( + dalys_by_age_group_default, + baseline_scenario=baseline_scenario, + ) + ) + + dalys_by_age_group_improved = ( + dalys_by_age_group.loc[ + :, + dalys_by_age_group.columns + .get_level_values(0) + .isin(improved_hs_scenarios) + ] + ) + + percent_dalys_by_age_improved = ( + calculate_percent_dalys_averted_by_age_group( + dalys_by_age_group_improved, + baseline_scenario=baseline_improved_scenario, + ) + ) + + fig_11, ax_11 = plot_percent_dalys_averted_by_age_group( + percent_dalys_by_age_default, + percent_dalys_by_age_improved, + ) + + # Showing figures + if args.show_figures: + plt.show() + + # Saving figures + if args.save_figures: + fig_1.savefig( + results_folder / "annual_dalys_across_scenarios.pdf", + bbox_inches="tight", + ) + + fig_2.savefig( + results_folder / "percent_dalys_averted_vs_baseline_2027_2034_comparison.pdf", + bbox_inches="tight", + ) + + fig_3.savefig( + results_folder / "annual_deaths_across_scenarios.pdf", + bbox_inches="tight", + ) + + fig_4.savefig( + results_folder / "percent_deaths_averted_vs_baseline_2027_2034_comparison.pdf", + bbox_inches="tight", + ) + + # Sensitivity-analysis DALY figures + fig_5.savefig( + results_folder / + 
"annual_dalys_across_scenarios_improved_healthsystem.pdf", + bbox_inches="tight", + ) + + # fig_6.savefig( + # results_folder / + # "percent_dalys_averted_vs_baseline_2027_2034_improved_healthsystem.pdf", + # bbox_inches="tight", + # ) + + # Sensitivity-analysis death figures + fig_7.savefig( + results_folder / + "annual_deaths_across_scenarios_improved_healthsystem.pdf", + bbox_inches="tight", + ) + + # fig_8.savefig( + # results_folder / + # "percent_deaths_averted_vs_baseline_2027_2034_improved_healthsystem.pdf", + # bbox_inches="tight", + # ) + + fig_9.savefig( + results_folder / + "percent_dalys_averted_by_cause_national_level.pdf", + bbox_inches="tight", + ) + + fig_10.savefig( + results_folder / + "percent_deaths_averted_by_cause_national_level.pdf", + bbox_inches="tight", + ) + + fig_11.savefig( + results_folder / + "percent_dalys_averted_by_age_group_national_level.pdf", + bbox_inches="tight", + ) + + fig_12.savefig( + results_folder / + "percent_deaths_averted_by_age_group_national_level.pdf", + bbox_inches="tight", + ) diff --git a/src/scripts/nurses_analyses/analysis_nurses_scenario_detailed.py b/src/scripts/nurses_analyses/analysis_nurses_scenario_detailed.py new file mode 100644 index 0000000000..ae638f8a4a --- /dev/null +++ b/src/scripts/nurses_analyses/analysis_nurses_scenario_detailed.py @@ -0,0 +1,532 @@ +"""This file uses the results of the results of running `nurse_analyses/nurses_scenario_analyses.py` to make plots of +nurse counts over time and appointments over time for each scenario/draw name from 2010 to 2034.""" + +import argparse +from pathlib import Path +from typing import Tuple, Dict +import pickle + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from scripts.nurses_analyses.nurses_scenario_analyses import StaffingScenario +from tlo import Date +from tlo.analysis.utils import ( + load_pickled_dataframes, + summarize, +) + + +# Rename draw numbers to scenario names +def 
set_param_names_as_column_index_level_0(_df, param_names): + """Set column index level 0 (draw numbers) to scenario names.""" + ordered_param_names = {i: x for i, x in enumerate(param_names)} + names_of_cols_level0 = [ + ordered_param_names.get(col) + for col in _df.columns.levels[0] + ] + _df.columns = _df.columns.set_levels(names_of_cols_level0, level=0) + return _df + + +def load_data_manually(results_folder: Path) -> Dict: + """ + Manually load data from the folder structure we observed. + Folder structure: draw_folder/run_folder/pickle_files + """ + data_by_draw = {} + + # Find all draw folders (0, 1, 2, 3, 4, 5) + draw_folders = [d for d in results_folder.iterdir() if d.is_dir() and d.name.isdigit()] + draw_folders.sort(key=lambda x: int(x.name)) + + print(f"\nFound {len(draw_folders)} draw folders: {[d.name for d in draw_folders]}") + + for draw_folder in draw_folders: + draw_num = int(draw_folder.name) + data_by_draw[draw_num] = {} + + # Find run folders (0, 1) + run_folders = [r for r in draw_folder.iterdir() if r.is_dir() and r.name.isdigit()] + run_folders.sort(key=lambda x: int(x.name)) + + print(f"\nDraw {draw_num} - Found {len(run_folders)} run folders: {[r.name for r in run_folders]}") + + for run_folder in run_folders: + run_num = int(run_folder.name) + + # Load all pickle files in this run folder + pickle_files = list(run_folder.glob("*.pickle")) + + run_data = {} + for pickle_file in pickle_files: + try: + with open(pickle_file, 'rb') as f: + data = pickle.load(f) + + # Store by module name (filename without extension) + module_name = pickle_file.stem + run_data[module_name] = data + print(f" Loaded {module_name} from run {run_num}") + + except Exception as e: + print(f" Error loading {pickle_file.name}: {e}") + + data_by_draw[draw_num][run_num] = run_data + + return data_by_draw + + +def extract_nurse_counts_from_run(run_data: Dict, target_years=range(2010, 2035)) -> pd.Series: + """ + Extract nurse counts from a single run's data. 
+ Looking for the right data source - probably not number_of_hcw_staff directly. + """ + # Look for healthsystem summary data + for module_name, data in run_data.items(): + if 'healthsystem.summary' in module_name: + if isinstance(data, dict): + print(f" Examining {module_name} - keys: {list(data.keys())}") + + # First, let's check what DataFrames are available + for key in data.keys(): + if isinstance(data[key], pd.DataFrame): + df = data[key] + print(f" DataFrame '{key}' has columns: {list(df.columns)}") + + # Check if this might have nurse count data + if 'date' in df.columns: + # Look for columns that might contain nurse counts + for col in df.columns: + if 'Nursing' in str(col) or 'Midwifery' in str(col) or 'staff' in str(col).lower(): + print(f" Found potential nurse column: {col}") + + # If we find a promising DataFrame, try to extract + if 'Capacity' in key or 'staff' in key.lower(): + df['year'] = pd.to_datetime(df['date']).dt.year + df_filtered = df[df['year'].isin(target_years)] + + if not df_filtered.empty: + # Look for nursing columns + nursing_cols = [col for col in df_filtered.columns + if 'Nursing' in str(col) or 'Midwifery' in str(col)] + + if nursing_cols: + # Sum across all nursing columns + result = df_filtered.groupby('year')[nursing_cols].sum().sum(axis=1) + print(f" Found nursing columns: {nursing_cols}") + print(f" Sample values: {result.head()}") + return result + + # If no nursing columns, look for staff columns + staff_cols = [col for col in df_filtered.columns + if 'staff' in str(col).lower() or 'count' in str(col).lower()] + + if staff_cols and len(staff_cols) > 0: + # Try to get the first staff column + result = df_filtered.groupby('year')[staff_cols[0]].mean() + print(f" Using staff column: {staff_cols[0]}") + print(f" Sample values: {result.head()}") + return result + + return pd.Series(dtype=float) + + +def extract_appointments_from_run(run_data: Dict, target_years=range(2010, 2035)) -> pd.Series: + """ + Extract appointments from a 
single run's data. + """ + # Look for healthsystem summary data + for module_name, data in run_data.items(): + if 'healthsystem.summary' in module_name: + if isinstance(data, dict): + # Look for HSI_Event data + for key in ['HSI_Event', 'HSI_Event_non_blank_appt_footprint']: + if key in data: + df = data[key] + if isinstance(df, pd.DataFrame): + if 'date' in df.columns and 'Number_By_Appt_Type_Code' in df.columns: + df['year'] = pd.to_datetime(df['date']).dt.year + + # Filter to target years + df_filtered = df[df['year'].isin(target_years)] + + if not df_filtered.empty: + # Expand appointment counts + appts_expanded = df_filtered['Number_By_Appt_Type_Code'].apply(pd.Series) + + # Group by year and sum + appts_expanded['year'] = df_filtered['year'].values + yearly = appts_expanded.groupby('year').sum() + + return yearly.sum(axis=1) + + return pd.Series(dtype=float) + + +def process_all_draws(data_by_draw: Dict, target_years=range(2010, 2035)): + """ + Process all draws to get nurse counts and appointments. + Returns DataFrames with draws as columns and years as index. 
+ """ + nurse_data = {} + appt_data = {} + + for draw_num, run_data_dict in data_by_draw.items(): + draw_nurse_series = [] + draw_appt_series = [] + + for run_num, run_data in run_data_dict.items(): + print(f"\n Processing Draw {draw_num}, Run {run_num}") + + # Extract nurse counts for this run + nurse_series = extract_nurse_counts_from_run(run_data, target_years) + if not nurse_series.empty: + draw_nurse_series.append(nurse_series) + print(f" ✓ Found nurse data with years: {list(nurse_series.index)[:5]}...") + print(f" ✓ Sample values: {list(nurse_series.values)[:5]}...") + + # Extract appointments for this run + appt_series = extract_appointments_from_run(run_data, target_years) + if not appt_series.empty: + draw_appt_series.append(appt_series) + print(f" ✓ Found appointment data with years: {list(appt_series.index)[:5]}...") + + # Average across runs for this draw + if draw_nurse_series: + # Convert list of Series to DataFrame and compute mean + nurse_df = pd.DataFrame(draw_nurse_series) + nurse_data[draw_num] = nurse_df.mean() + print(f" Draw {draw_num}: Averaged nurse data from {len(draw_nurse_series)} runs") + + if draw_appt_series: + appt_df = pd.DataFrame(draw_appt_series) + appt_data[draw_num] = appt_df.mean() + print(f" Draw {draw_num}: Averaged appointment data from {len(draw_appt_series)} runs") + + # Convert to DataFrames with draws as columns + nurse_df = pd.DataFrame(nurse_data) if nurse_data else pd.DataFrame() + appt_df = pd.DataFrame(appt_data) if appt_data else pd.DataFrame() + + return nurse_df, appt_df + + +# ============================================================================= +# PLOTTING FUNCTIONS +# ============================================================================= + +def plot_nurse_counts(nurse_df, param_names, output_folder, target_period_str): + """ + Plot nurse counts over time for all scenarios. 
+ """ + if nurse_df.empty: + print("No nurse count data to plot") + return None, None + + fig, ax = plt.subplots(figsize=(14, 8)) + + # Define colors and line styles + colors = plt.cm.tab10(np.linspace(0, 1, 10)) + + # Use different markers to distinguish overlapping lines + markers = ['o', 's', '^', 'D', 'v', '<'] + + plots_made = False + + for draw_idx, scenario in enumerate(param_names): + if draw_idx in nurse_df.columns: + series = nurse_df[draw_idx] + + if series is not None and not series.empty: + plots_made = True + + # Determine label, color, and line style based on scenario name + if 'Baseline' in scenario: + level = 'Baseline' + elif 'Fewer' in scenario: + level = 'Fewer' + elif 'More' in scenario: + level = 'More' + else: + level = 'Unknown' + + if 'Default' in scenario: + hs_type = 'Default' + color = colors[0] # Blue for Default + else: # Improved + hs_type = 'Improved' + color = colors[1] # Orange for Improved + + # Line styles based on nurse level + if level == 'Baseline': + linestyle = '-' + elif level == 'Fewer': + linestyle = '--' + elif level == 'More': + linestyle = ':' + else: + linestyle = '-' + + label = f"{level} - {hs_type}" + + # Use different markers for each draw to see if lines are overlapping + marker = markers[draw_idx % len(markers)] + + ax.plot( + series.index, + series.values, + label=label, + color=color, + linestyle=linestyle, + marker=marker, + markersize=6, + markevery=3, + linewidth=2 + ) + print(f" ✓ Plotted Draw {draw_idx}: {label}") + + if not plots_made: + print(" No data to plot") + plt.close(fig) + return None, None + + ax.set_xlabel('Year', fontsize=12) + ax.set_ylabel('Number of Nurses', fontsize=12) + ax.set_title(f'Nurse Counts Over Time by Scenario ({target_period_str})', fontsize=14) + ax.grid(True, alpha=0.3) + ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), fontsize=10) + + # Set x-ticks + all_years = [] + for col in nurse_df.columns: + all_years.extend(nurse_df[col].index) + + if all_years: + all_years = 
sorted(set(all_years)) + tick_years = all_years[::2] if len(all_years) > 10 else all_years + ax.set_xticks(tick_years) + ax.set_xticklabels(tick_years, rotation=45) + + fig.tight_layout() + + # Save figures + fig.savefig(output_folder / "nurse_counts_over_time.pdf", bbox_inches='tight') + fig.savefig(output_folder / "nurse_counts_over_time.png", bbox_inches='tight', dpi=300) + + return fig, ax + + +def plot_appointments(appt_df, param_names, output_folder, target_period_str): + """ + Plot appointments over time for all scenarios. + """ + if appt_df.empty: + print("No appointment data to plot") + return None, None + + fig, ax = plt.subplots(figsize=(14, 8)) + + # Define colors and line styles + colors = plt.cm.tab10(np.linspace(0, 1, 10)) + + # Use different markers to distinguish overlapping lines + markers = ['o', 's', '^', 'D', 'v', '<'] + + plots_made = False + + for draw_idx, scenario in enumerate(param_names): + if draw_idx in appt_df.columns: + series = appt_df[draw_idx] + + if series is not None and not series.empty: + plots_made = True + + # Determine label, color, and line style based on scenario name + if 'Baseline' in scenario: + level = 'Baseline' + elif 'Fewer' in scenario: + level = 'Fewer' + elif 'More' in scenario: + level = 'More' + else: + level = 'Unknown' + + if 'Default' in scenario: + hs_type = 'Default' + color = colors[0] # Blue for Default + else: # Improved + hs_type = 'Improved' + color = colors[1] # Orange for Improved + + # Line styles based on nurse level + if level == 'Baseline': + linestyle = '-' + elif level == 'Fewer': + linestyle = '--' + elif level == 'More': + linestyle = ':' + else: + linestyle = '-' + + label = f"{level} - {hs_type}" + + # Use different markers for each draw + marker = markers[draw_idx % len(markers)] + + # Convert to millions for plotting + values_millions = series.values / 1_000_000 + + ax.plot( + series.index, + values_millions, + label=label, + color=color, + linestyle=linestyle, + marker=marker, + 
markersize=6, + markevery=3, + linewidth=2 + ) + print(f" ✓ Plotted Draw {draw_idx}: {label}") + + if not plots_made: + print(" No data to plot") + plt.close(fig) + return None, None + + ax.set_xlabel('Year', fontsize=12) + ax.set_ylabel('Appointments (millions)', fontsize=12) + ax.set_title(f'Total Appointments Delivered Over Time by Scenario ({target_period_str})', fontsize=14) + ax.grid(True, alpha=0.3) + ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), fontsize=10) + + # Set x-ticks + all_years = [] + for col in appt_df.columns: + all_years.extend(appt_df[col].index) + + if all_years: + all_years = sorted(set(all_years)) + tick_years = all_years[::2] if len(all_years) > 10 else all_years + ax.set_xticks(tick_years) + ax.set_xticklabels(tick_years, rotation=45) + + fig.tight_layout() + + # Save figures + fig.savefig(output_folder / "appointments_over_time.pdf", bbox_inches='tight') + fig.savefig(output_folder / "appointments_over_time.png", bbox_inches='tight', dpi=300) + + return fig, ax + + +# ============================================================================= +# MAIN +# ============================================================================= + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + "Plot nurse counts and appointments from nurses scenario" + ) + parser.add_argument( + "--scenario-outputs-folder", + type=Path, + required=True, + help="Path to folder containing scenario outputs", + ) + parser.add_argument( + "--show-figures", + action="store_true", + help="Whether to interactively show figures", + ) + parser.add_argument( + "--save-figures", + action="store_true", + help="Whether to save figures", + ) + args = parser.parse_args() + + results_folder = args.scenario_outputs_folder + + print(f"\n{'='*60}") + print(f"Loading results from: {results_folder}") + print(f"{'='*60}") + + # Get scenario names + param_names = tuple(StaffingScenario()._scenarios.keys()) + print(f"\nFound {len(param_names)} scenarios:") + for 
i, name in enumerate(param_names): + print(f" {i}: {name}") + + # Create output folder + output_folder = results_folder / "analysis_output" + output_folder.mkdir(exist_ok=True) + + # Define target period + target_years = range(2010, 2035) + target_period_str = "2010-2034" + + # Manually load all data + print(f"\n{'='*60}") + print("MANUALLY LOADING DATA FROM FOLDER STRUCTURE") + print(f"{'='*60}") + + data_by_draw = load_data_manually(results_folder) + + # Process all draws to extract nurse counts and appointments + print(f"\n{'='*60}") + print("EXTRACTING NURSE COUNTS AND APPOINTMENTS") + print(f"{'='*60}") + + nurse_df, appt_df = process_all_draws(data_by_draw, target_years) + + # Print summary of extracted data + print(f"\n{'='*60}") + print("EXTRACTION SUMMARY") + print(f"{'='*60}") + + if not nurse_df.empty: + print(f"\n✓ Nurse count data shape: {nurse_df.shape}") + print(f"Draws with nurse data: {list(nurse_df.columns)}") + for col in nurse_df.columns: + print(f" Draw {col}: years {list(nurse_df[col].index)[:5]}...") + print(f" Values: {list(nurse_df[col].values)[:5]}...") + else: + print("\n✗ No nurse count data found") + + if not appt_df.empty: + print(f"\n✓ Appointment data shape: {appt_df.shape}") + print(f"Draws with appointment data: {list(appt_df.columns)}") + for col in appt_df.columns: + print(f" Draw {col}: years {list(appt_df[col].index)[:5]}...") + else: + print("\n✗ No appointment data found") + + # Generate plots + print(f"\n{'='*60}") + print("GENERATING PLOTS") + print(f"{'='*60}") + + if not nurse_df.empty: + print("\nPlotting nurse counts...") + fig1, ax1 = plot_nurse_counts(nurse_df, param_names, output_folder, target_period_str) + if fig1 is not None: + print(f"✓ Nurse counts plot saved to {output_folder}/nurse_counts_over_time.png") + else: + print("\n✗ Cannot plot nurse counts - no data available") + + if not appt_df.empty: + print("\nPlotting appointments...") + fig2, ax2 = plot_appointments(appt_df, param_names, output_folder, 
target_period_str) + if fig2 is not None: + print(f"✓ Appointments plot saved to {output_folder}/appointments_over_time.png") + else: + print("\n✗ Cannot plot appointments - no data available") + + print(f"\n{'='*60}") + print("Analysis complete!") + print(f"{'='*60}") + + if args.show_figures: + plt.show() diff --git a/src/scripts/nurses_analyses/analysis_staff_num_more.py b/src/scripts/nurses_analyses/analysis_staff_num_more.py new file mode 100644 index 0000000000..c6a5feb66b --- /dev/null +++ b/src/scripts/nurses_analyses/analysis_staff_num_more.py @@ -0,0 +1,267 @@ +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +from scripts.nurses_analyses.nurses_scenario_analyses import StaffingScenario + +from tlo.analysis.utils import ( + extract_results, + summarize, +) + + +# ----------------------------------------------------------------------------- +# Rename draw numbers to scenario names +# ----------------------------------------------------------------------------- +def set_param_names_as_column_index_level_0(_df, param_names): + + ordered_param_names = { + i: x for i, x in enumerate(param_names) + } + + names_of_cols_level0 = [ + ordered_param_names.get(col) + for col in _df.columns.levels[0] + ] + + _df.columns = _df.columns.set_levels( + names_of_cols_level0, + level=0 + ) + + return _df + + +# ----------------------------------------------------------------------------- +# Extract annual staffing counts +# ----------------------------------------------------------------------------- +def get_yearly_hr_count(df): + + if 'GenericClinic' not in df.columns: + return None + + df['year'] = df['date'].dt.year + + # Expand dictionary + staff_df = df['GenericClinic'].apply(pd.Series) + + # Keep cadre names only + staff_df.columns = [ + c.split('Officer_')[-1] + for c in staff_df.columns + ] + + # Sum facilities within cadre + staff_df = staff_df.groupby(level=0, axis=1).sum() + + # Add year + staff_df['year'] = 
df['year'] + + # Annual totals + staff_df = staff_df.groupby('year').sum() + + # Scale population + # POP_SCALE = 145.39609 + # staff_df = staff_df * POP_SCALE + + return staff_df.stack() + + +def extract_staff_counts(results_folder): + + return extract_results( + results_folder, + module="tlo.methods.healthsystem.summary", + key="number_of_hcw_staff", + custom_generate_series=get_yearly_hr_count, + do_scaling=False, + ) + + +# ----------------------------------------------------------------------------- +# Prepare plotting dataframe +# ----------------------------------------------------------------------------- +def prepare_staffing_totals(summary_df): + + scenarios = ( + summary_df.columns + .get_level_values(0) + .unique() + ) + + results = {} + + for scenario in scenarios: + + mean_df = summary_df[(scenario, "mean")].unstack() + + # Nurses + nurses = mean_df["Nursing_and_Midwifery"] + + # Other cadres + other_cadres = mean_df.drop( + columns=["Nursing_and_Midwifery"], + errors="ignore" + ).sum(axis=1) + + results[scenario] = pd.DataFrame({ + "Nurses": nurses, + "Other cadres": other_cadres, + }) + + return results + + +# ----------------------------------------------------------------------------- +# Plot staffing counts +# ----------------------------------------------------------------------------- +def plot_staffing_counts( + staffing_results, + scenarios, + title, +): + + fig, ax = plt.subplots(figsize=(10, 6)) + + label_map = { + "Baseline Nurses": "Baseline nurses", + "Fewer Nurses": "Fewer nurses", + "More Nurses": "More nurses", + } + + # Plot nurse scenarios + for scenario in scenarios: + + df = staffing_results[scenario] + + label = None + + for key in label_map: + if key in scenario: + label = f"Nurses, {label_map[key]}" + + ax.plot( + df.index, + df["Nurses"], + linewidth=2, + label=label, + ) + + # Plot other cadres once + other_df = staffing_results[scenarios[0]] + + ax.plot( + other_df.index, + other_df["Other cadres"], + linewidth=2.5, + 
linestyle="--", + color="black", + label="Other cadres total", + ) + + ax.set_xlabel("Year") + ax.set_ylabel("Annual staff count") + + ax.set_title(title) + + ax.legend() + + ax.grid(alpha=0.3) + + fig.tight_layout() + + return fig, ax + + +# ----------------------------------------------------------------------------- +# Main +# ----------------------------------------------------------------------------- +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + parser.add_argument( + "--scenario-outputs-folder", + type=Path, + required=True, + ) + + parser.add_argument( + "--show-figures", + action="store_true", + ) + + parser.add_argument( + "--save-figures", + action="store_true", + ) + + args = parser.parse_args() + + results_folder = args.scenario_outputs_folder + + # Scenario names + param_names = tuple( + StaffingScenario()._scenarios.keys() + ) + + # Extract + staff_counts = extract_staff_counts( + results_folder + ).pipe( + set_param_names_as_column_index_level_0, + param_names=param_names, + ) + + # Summarize + summarized_staff_counts = summarize( + staff_counts + ) + + # Prepare totals + staffing_results = prepare_staffing_totals( + summarized_staff_counts + ) + + # Scenario groups + default_hs_scenarios = [ + "Baseline Nurses / Default Healthsystem Function", + "Fewer Nurses / Default Healthsystem Function", + "More Nurses / Default Healthsystem Function", + ] + + improved_hs_scenarios = [ + "Baseline Nurses / Improved Healthsystem Function", + "Fewer Nurses / Improved Healthsystem Function", + "More Nurses / Improved Healthsystem Function", + ] + + # Plot default HS + fig1, ax1 = plot_staffing_counts( + staffing_results, + default_hs_scenarios, + title="Annual staffing count\nDefault Healthsystem", + ) + + # Plot improved HS + fig2, ax2 = plot_staffing_counts( + staffing_results, + improved_hs_scenarios, + title="Annual staffing count\nImproved Healthsystem", + ) + + if args.save_figures: + fig1.savefig( + results_folder / 
"annual_staffing_default_hs.pdf", + bbox_inches="tight" + ) + + fig2.savefig( + results_folder / "annual_staffing_improved_hs.pdf", + bbox_inches="tight" + ) + + if args.show_figures: + plt.show() diff --git a/src/scripts/nurses_analyses/analysis_staff_num_more_districts.py b/src/scripts/nurses_analyses/analysis_staff_num_more_districts.py new file mode 100644 index 0000000000..d308279c6e --- /dev/null +++ b/src/scripts/nurses_analyses/analysis_staff_num_more_districts.py @@ -0,0 +1,536 @@ +import argparse +from collections import Counter, defaultdict +from pathlib import Path +from typing import Dict, Tuple + +import numpy as np +import pandas as pd +import squarify +from matplotlib import pyplot as plt + +from tlo import Date +from tlo.analysis.utils import ( + COARSE_APPT_TYPE_TO_COLOR_MAP, + SHORT_TREATMENT_ID_TO_COLOR_MAP, + _standardize_short_treatment_id, + # DON'T import bin_hsi_event_details from utils + compute_mean_across_runs, + extract_results, + get_coarse_appt_type, + get_color_short_treatment_id, + load_pickled_dataframes, + order_of_short_treatment_ids, + plot_stacked_bar_chart, + squarify_neat, + summarize, + unflatten_flattened_multi_index_in_logging, +) +import re +from scripts.nurses_analyses.nurses_scenario_analyses import StaffingScenario + +# Declare period for which the results will be generated (defined inclusively) +TARGET_PERIOD = (Date(2010, 1, 1), Date(2034, 12, 31)) + + +def drop_outside_period(_df): + """Return a dataframe which only includes for which the date is within the limits defined by TARGET_PERIOD""" + return _df.drop(index=_df.index[~_df['date'].between(*TARGET_PERIOD)]) + + +def figure4_hr_use_overall(results_folder: Path, output_folder: Path, resourcefilepath: Path): + """ 'Figure 4': The level of usage of the HealthSystem HR Resources """ + + make_graph_file_name = lambda stub: output_folder / f"Fig4_{stub}.png" # noqa: E731 + + def get_share_of_time_for_hw_in_each_facility_by_short_treatment_id(_df): + + _df = 
drop_outside_period(_df) + _df = _df.set_index("date") + + nurse_cols = [ + c for c in _df.columns + if "Officer_Nursing_and_Midwifery" in c + ] + + if len(nurse_cols) == 0: + return None + + nurse_df = _df[nurse_cols] + + # Mean usage across all nurse facilities + nurse_df = nurse_df.copy() + nurse_df.loc[:, "All"] = nurse_df.mean(axis=1) + # nurse_df["All"] = nurse_df.mean(axis=1) + + return nurse_df.resample("M").mean().stack() + + def get_share_of_time_used_for_each_officer_at_each_level(_df): + + _df = drop_outside_period(_df) + _df = _df.set_index("date") + + # Columns look like: + # clinic=GenericClinic|facID_and_officer=FacilityID_0_Officer_Nursing_and_Midwifery + + officer_cols = [ + c for c in _df.columns if "FacilityID_" in c and "Officer_" in c + ] + + if len(officer_cols) == 0: + return None + + officer_df = _df[officer_cols].copy() + + # Load Master Facility List + mfl = pd.read_csv( + Path("./resources/healthsystem/organisation/ResourceFile_Master_Facilities_List.csv") + ).set_index("Facility_ID") + + results = [] + + for col in officer_cols: + + col_string = str(col) + + # Extract facility ID + fac_match = re.search(r'FacilityID_(\d+)', col_string) + if fac_match is None: + continue + fid = int(fac_match.group(1)) + + # Extract cadre + officer_match = re.search(r'Officer_(.*)', col_string) + if officer_match is None: + continue + cadre = officer_match.group(1) + + # Get facility level + if fid not in mfl.index: + continue + + level = mfl.loc[fid, "Facility_Level"] + level = "2" if level == "1b" else level + + # Compute mean usage + mean_val = officer_df[col].mean() + + results.append((cadre, level, mean_val)) + + if len(results) == 0: + return None + + result_df = pd.DataFrame(results, columns=["Cadre", "Facility_Level", "Usage"]) + + return result_df.groupby(["Cadre", "Facility_Level"])["Usage"].mean() + + capacity_by_facility = summarize( + extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + 
key='Capacity_By_FacID_and_Officer', + custom_generate_series=get_share_of_time_for_hw_in_each_facility_by_short_treatment_id, + do_scaling=False + ), + only_mean=True, + collapse_columns=True + ) + + capacity_by_officer = summarize( + extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + key='Capacity_By_FacID_and_Officer', + custom_generate_series=get_share_of_time_used_for_each_officer_at_each_level, + do_scaling=False + ), + only_mean=True, + collapse_columns=True + ) + + # Find the levels of each facility + mfl = pd.read_csv( + resourcefilepath / 'healthsystem' / 'organisation' / 'ResourceFile_Master_Facilities_List.csv' + ).set_index('Facility_ID') + + def find_level_for_facility(col_name): + # Skip aggregated column + if col_name == "All": + return None + + match = re.search(r'FacilityID_(\d+)', str(col_name)) + + if match is None: + return None + + fid = int(match.group(1)) + + level = mfl.loc[fid, "Facility_Level"] + + return "2" if level == "1b" else level + + # def find_level_for_facility(col_tuple): + # # Extract the text part + # col_string = col_tuple[2] + # + # # Extract facility ID number + # match = re.search(r'FacilityID_(\d+)', col_string) + # fid = int(match.group(1)) + # + # level = mfl.loc[fid, "Facility_Level"] + # return "2" if level == "1b" else level + # def find_level_for_facility(id): + # return mfl.loc[id].Facility_Level if mfl.loc[id].Facility_Level != '1b' else '2' + # def find_level_for_facility(fid): + # level = mfl.loc[fid, "Facility_Level"] + # return "2" if level == "1b" else level + + color_for_level = {'0': 'blue', '1a': 'yellow', '1b': 'green', '2': 'grey', '3': 'orange', '4': 'black', + '5': 'white'} + + fig, ax = plt.subplots() + name_of_plot = 'Usage of Healthcare Worker Time By Month' + capacity_unstacked = capacity_by_facility.unstack() + for i in capacity_unstacked.columns: + + level = find_level_for_facility(i) + + if level is None: + continue + + h1, = ax.plot( + capacity_unstacked[i].index, 
+ capacity_unstacked[i].values, + color=color_for_level[level], + linewidth=0.5, + label=f'Facility_Level {level}' + ) + # for i in capacity_unstacked.columns: + # if i != 'All': + # level = find_level_for_facility(i) + # h1, = ax.plot(capacity_unstacked[i].index, capacity_unstacked[i].values, + # color=color_for_level[level], linewidth=0.5, label=f'Facility_Level {level}') + + if 'All' in capacity_unstacked.columns: + h2, = ax.plot( + capacity_unstacked['All'].index, + capacity_unstacked['All'].values, + color='red', + linewidth=1.5 + ) + ax.legend([h1, h2], ['Each Facility', 'All Facilities']) + else: + ax.legend([h1], ['Each Facility']) + + ax.set_title(name_of_plot) + ax.set_xlabel('Month') + ax.set_ylabel('Fraction of all time used\n(Average for the month)') + + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + fig.tight_layout() + fig.savefig(make_graph_file_name(name_of_plot.replace(' ', '_'))) + plt.close(fig) + + fig, ax = plt.subplots() + name_of_plot = 'Usage of Healthcare Worker Time (Average)' + capacity_unstacked_average = capacity_by_facility.unstack().mean() + # levels = [find_level_for_facility(i) if i != 'All' else 'All' for i in capacity_unstacked_average.index] + xpos_for_level = dict(zip((color_for_level.keys()), range(len(color_for_level)))) + xpos_for_level.update({'1b': 2, '2': 2, '3': 3, '4': 4, '5': 5}) + for id, val in capacity_unstacked_average.items(): + if id != 'All': + _level = find_level_for_facility(id) + + # Skip if facility level could not be determined + if _level is None: + continue + + if _level != '5': + xpos = xpos_for_level[_level] + scatter = (np.random.rand() - 0.5) * 0.25 + h1, = ax.plot(xpos + scatter, val * 100, color=color_for_level[_level], + marker='.', markersize=15, label='Each Facility', linestyle='none') + if 'All' in capacity_unstacked_average.index: + h2 = ax.axhline( + y=capacity_unstacked_average['All'] * 100, + color='red', + linestyle='--', + label='Average' + ) + 
ax.set_title(name_of_plot) + ax.set_xlabel('Facility_Level') + ax.set_xticks(list(xpos_for_level.values())) + ax.set_xticklabels(xpos_for_level.keys()) + ax.set_ylabel('Percent of Time Available That is Used\n') + ax.legend(handles=[h1, h2]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + fig.tight_layout() + fig.savefig(make_graph_file_name(name_of_plot.replace(' ', '_'))) + plt.close(fig) + + fig, ax = plt.subplots() + name_of_plot = 'Usage of Healthcare Worker Time by Cadre and Facility_Level' + (100.0 * capacity_by_officer.unstack()).T.plot.bar(ax=ax) + ax.legend() + ax.set_xlabel('Facility_Level') + ax.set_ylabel('Percent of time that is used') + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.set_title(name_of_plot) + fig.tight_layout() + fig.savefig(make_graph_file_name(name_of_plot.replace(' ', '_'))) + plt.close(fig) + + +def get_yearly_hr_count(_df): + + if 'GenericClinic' not in _df.columns: + return None + + years = _df['date'].dt.year.rename("year") + + # Expand facility dictionary + staff_df = _df['GenericClinic'].apply(pd.Series) + + # Extract facility IDs + facility_ids = [ + int(c.split("FacilityID_")[1].split("_")[0]) + for c in staff_df.columns + ] + + # Extract cadre names + cadres = [ + c.split("Officer_")[-1] + for c in staff_df.columns + ] + + # Load Master Facility List + mfl = pd.read_csv( + Path("./resources/healthsystem/organisation/ResourceFile_Master_Facilities_List.csv") + ).set_index("Facility_ID") + + # Add district info for facilities at levels 3+ that have nan district info, + # to avoid these facilities being dropped + for fid in {128, 129, 130, 131, 132}: + mfl.loc[fid, "District"] = mfl.loc[fid, "Facility_Name"] + + # Map facilities to districts + districts = [ + mfl.loc[fid, "District"] if fid in mfl.index else "Unknown" + for fid in facility_ids + ] + + # Create MultiIndex columns + staff_df.columns = pd.MultiIndex.from_arrays( + [districts, cadres], + 
names=["District", "Cadre"] + ) + + # Sum yearly + staff_df = staff_df.groupby(years).sum() + + # Sum facilities within district/cadre + staff_df = staff_df.T.groupby(level=[0, 1]).sum().T + + # POP_SCALE = 145.39609 + # staff_df = staff_df * POP_SCALE + + # Convert columns to index + return staff_df.stack([0, 1]) + + +def extract_staff_counts(results_folder): + return extract_results( + results_folder, + module="tlo.methods.healthsystem.summary", + key="number_of_hcw_staff", + custom_generate_series=get_yearly_hr_count, + do_scaling=False + ) + + +def set_param_names_as_column_index_level_0(_df, param_names): + """Set column index level 0 (draw numbers) to scenario names.""" + ordered_param_names = {i: x for i, x in enumerate(param_names)} + names_of_cols_level0 = [ + ordered_param_names.get(col) + for col in _df.columns.levels[0] + ] + _df.columns = _df.columns.set_levels(names_of_cols_level0, level=0) + return _df + + +def plot_staff_counts_by_cadre_across_scenarios_by_district( + staff_counts_summary, + output_folder +): + + scenario_names = staff_counts_summary.columns.get_level_values(0).unique() + + districts = ( + staff_counts_summary.index + .get_level_values("District") + .unique() + ) + + cadres = ( + staff_counts_summary.index + .get_level_values("Cadre") + .unique() + ) + + for district in districts: + + district_df = staff_counts_summary.xs( + district, + level="District" + ) + + for cadre in cadres: + + if cadre not in district_df.index.get_level_values("Cadre"): + continue + + fig, ax = plt.subplots() + + for scenario in scenario_names: + + central = district_df[(scenario, "mean")].xs( + cadre, + level="Cadre" + ) + + lower = district_df[(scenario, "lower")].xs( + cadre, + level="Cadre" + ) + + upper = district_df[(scenario, "upper")].xs( + cadre, + level="Cadre" + ) + + years = central.index + + ax.plot( + years, + central.values, + label=scenario + ) + + ax.fill_between( + years, + np.maximum(lower.values, 0), + upper.values, + alpha=0.25 + ) + + 
ax.set_title( + f"{cadre} Staff Counts Across Scenarios ({district})" + ) + + ax.set_xlabel("Year") + ax.set_ylabel("Average Number of Health Workers") + + ax.legend() + + fig.tight_layout() + + fig.savefig( + output_folder / + f"{district}_{cadre}_staff_counts_across_scenarios.png" + ) + + plt.close(fig) + + +def apply(results_folder: Path, output_folder: Path, resourcefilepath: Path = None): + """Description of the usage of healthcare system resources.""" + + # figure2_appointments_used( + # results_folder=results_folder, output_folder=output_folder, resourcefilepath=resourcefilepath + # ) + log = load_pickled_dataframes(results_folder, 0, 0) + print(log.keys()) + + print(log['tlo.methods.healthsystem.summary'].keys()) + + # STEP 1: extract staff counts + staff_counts = extract_staff_counts(results_folder) + + # STEP 2: rename draws to scenario names + param_names = tuple(StaffingScenario()._scenarios.keys()) + + staff_counts = staff_counts.pipe( + set_param_names_as_column_index_level_0, + param_names=param_names + ) + + # STEP 3: summarize runs + print(type(staff_counts)) + print(staff_counts.head()) + staff_counts_summary = summarize(staff_counts) + + print("\n=== Staff counts summary ===") + print(staff_counts_summary.index.names) + + print("\n=== Staff counts from 2025–2034 ===") + + # Select years 2025–2034 + years_to_check = range(2025, 2035) + + export_df = staff_counts_summary.reset_index() + + # Filter the years + export_df = export_df[export_df["year"].isin(years_to_check)] + + # Save to Excel + export_path = output_folder / "debug_staff_counts_2025_2034.xlsx" + export_df.to_excel(export_path) + + print(f"Staff counts exported to: {export_path}") + + # STEP 4: plot + plot_staff_counts_by_cadre_across_scenarios_by_district( + staff_counts_summary, + output_folder + ) + + figure4_hr_use_overall( + results_folder=results_folder, output_folder=output_folder, resourcefilepath=resourcefilepath + ) + + +if __name__ == "__main__": + parser = 
argparse.ArgumentParser() + + parser.add_argument( + "--scenario-outputs-folder", + type=Path, + required=True, + help="Path to folder containing scenario outputs", + ) + parser.add_argument( + "--show-figures", + action="store_true", + help="Whether to interactively show figures", + ) + parser.add_argument( + "--save-figures", + action="store_true", + help="Whether to save figures", + ) + args = parser.parse_args() + + # Use the command-line argument instead of hardcoded path + results_folder = args.scenario_outputs_folder + # results_folder = Path( + # './outputs/wamulwafu@kuhes.ac.mw/nurses_scenario_outputs-2026-02-09T110530Z' + # ) + + apply( + results_folder=results_folder, # or directly: args.scenario_outputs_folder + output_folder=results_folder, + resourcefilepath=Path('./resources') + ) diff --git a/src/scripts/nurses_analyses/analysis_time_and_appts.py b/src/scripts/nurses_analyses/analysis_time_and_appts.py new file mode 100644 index 0000000000..a7c22e69a9 --- /dev/null +++ b/src/scripts/nurses_analyses/analysis_time_and_appts.py @@ -0,0 +1,333 @@ +import argparse +from collections import Counter, defaultdict +from pathlib import Path +from typing import Dict, Tuple + +import numpy as np +import pandas as pd +import squarify +from matplotlib import pyplot as plt + +from tlo import Date +from tlo.analysis.utils import ( + COARSE_APPT_TYPE_TO_COLOR_MAP, + SHORT_TREATMENT_ID_TO_COLOR_MAP, + _standardize_short_treatment_id, + # DON'T import bin_hsi_event_details from utils + compute_mean_across_runs, + extract_results, + get_coarse_appt_type, + get_color_short_treatment_id, + # load_pickled_dataframes, + order_of_short_treatment_ids, + plot_stacked_bar_chart, + squarify_neat, + summarize, + unflatten_flattened_multi_index_in_logging, +) +import re + +# Declare period for which the results will be generated (defined inclusively) +TARGET_PERIOD = (Date(2010, 1, 1), Date(2034, 12, 31)) + + +def drop_outside_period(_df): + """Return a dataframe which only 
includes for which the date is within the limits defined by TARGET_PERIOD""" + return _df.drop(index=_df.index[~_df['date'].between(*TARGET_PERIOD)]) + + +def figure4_hr_use_overall(results_folder: Path, output_folder: Path, resourcefilepath: Path): + """ 'Figure 4': The level of usage of the HealthSystem HR Resources """ + + make_graph_file_name = lambda stub: output_folder / f"Fig4_{stub}.png" # noqa: E731 + + def get_share_of_time_for_hw_in_each_facility_by_short_treatment_id(_df): + + _df = drop_outside_period(_df) + _df = _df.set_index("date") + + nurse_cols = [ + c for c in _df.columns + if "Officer_Nursing_and_Midwifery" in c + ] + + if len(nurse_cols) == 0: + return None + + nurse_df = _df[nurse_cols] + + # Mean usage across all nurse facilities + nurse_df = nurse_df.copy() + nurse_df.loc[:, "All"] = nurse_df.mean(axis=1) + # nurse_df["All"] = nurse_df.mean(axis=1) + + return nurse_df.resample("M").mean().stack() + + def get_share_of_time_used_for_each_officer_at_each_level(_df): + + _df = drop_outside_period(_df) + _df = _df.set_index("date") + + # Columns look like: + # clinic=GenericClinic|facID_and_officer=FacilityID_0_Officer_Nursing_and_Midwifery + + officer_cols = [ + c for c in _df.columns if "FacilityID_" in c and "Officer_" in c + ] + + if len(officer_cols) == 0: + return None + + officer_df = _df[officer_cols].copy() + + # Load Master Facility List + mfl = pd.read_csv( + Path("./resources/healthsystem/organisation/ResourceFile_Master_Facilities_List.csv") + ).set_index("Facility_ID") + + results = [] + + for col in officer_cols: + + col_string = str(col) + + # Extract facility ID + fac_match = re.search(r'FacilityID_(\d+)', col_string) + if fac_match is None: + continue + fid = int(fac_match.group(1)) + + # Extract cadre + officer_match = re.search(r'Officer_(.*)', col_string) + if officer_match is None: + continue + cadre = officer_match.group(1) + + # Get facility level + if fid not in mfl.index: + continue + + level = mfl.loc[fid, 
"Facility_Level"] + level = "2" if level == "1b" else level + + # Compute mean usage + mean_val = officer_df[col].mean() + + results.append((cadre, level, mean_val)) + + if len(results) == 0: + return None + + result_df = pd.DataFrame(results, columns=["Cadre", "Facility_Level", "Usage"]) + + return result_df.groupby(["Cadre", "Facility_Level"])["Usage"].mean() + + capacity_by_facility = summarize( + extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + key='Capacity_By_FacID_and_Officer', + custom_generate_series=get_share_of_time_for_hw_in_each_facility_by_short_treatment_id, + do_scaling=False + ), + only_mean=True, + collapse_columns=True + ) + + capacity_by_officer = summarize( + extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + key='Capacity_By_FacID_and_Officer', + custom_generate_series=get_share_of_time_used_for_each_officer_at_each_level, + do_scaling=False + ), + only_mean=True, + collapse_columns=True + ) + + # Find the levels of each facility + mfl = pd.read_csv( + resourcefilepath / 'healthsystem' / 'organisation' / 'ResourceFile_Master_Facilities_List.csv' + ).set_index('Facility_ID') + + def find_level_for_facility(col_name): + # Skip aggregated column + if col_name == "All": + return None + + match = re.search(r'FacilityID_(\d+)', str(col_name)) + + if match is None: + return None + + fid = int(match.group(1)) + + level = mfl.loc[fid, "Facility_Level"] + + return "2" if level == "1b" else level + + # def find_level_for_facility(col_tuple): + # # Extract the text part + # col_string = col_tuple[2] + # + # # Extract facility ID number + # match = re.search(r'FacilityID_(\d+)', col_string) + # fid = int(match.group(1)) + # + # level = mfl.loc[fid, "Facility_Level"] + # return "2" if level == "1b" else level + # def find_level_for_facility(id): + # return mfl.loc[id].Facility_Level if mfl.loc[id].Facility_Level != '1b' else '2' + # def find_level_for_facility(fid): + # level = mfl.loc[fid, 
"Facility_Level"] + # return "2" if level == "1b" else level + + color_for_level = {'0': 'blue', '1a': 'yellow', '1b': 'green', '2': 'grey', '3': 'orange', '4': 'black', + '5': 'white'} + + fig, ax = plt.subplots() + name_of_plot = 'Usage of Healthcare Worker Time By Month' + capacity_unstacked = capacity_by_facility.unstack() + for i in capacity_unstacked.columns: + + level = find_level_for_facility(i) + + if level is None: + continue + + h1, = ax.plot( + capacity_unstacked[i].index, + capacity_unstacked[i].values, + color=color_for_level[level], + linewidth=0.5, + label=f'Facility_Level {level}' + ) + # for i in capacity_unstacked.columns: + # if i != 'All': + # level = find_level_for_facility(i) + # h1, = ax.plot(capacity_unstacked[i].index, capacity_unstacked[i].values, + # color=color_for_level[level], linewidth=0.5, label=f'Facility_Level {level}') + + if 'All' in capacity_unstacked.columns: + h2, = ax.plot( + capacity_unstacked['All'].index, + capacity_unstacked['All'].values, + color='red', + linewidth=1.5 + ) + ax.legend([h1, h2], ['Each Facility', 'All Facilities']) + else: + ax.legend([h1], ['Each Facility']) + + ax.set_title(name_of_plot) + ax.set_xlabel('Month') + ax.set_ylabel('Fraction of all time used\n(Average for the month)') + + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + fig.tight_layout() + fig.savefig(make_graph_file_name(name_of_plot.replace(' ', '_'))) + plt.close(fig) + + fig, ax = plt.subplots() + name_of_plot = 'Usage of Healthcare Worker Time (Average)' + capacity_unstacked_average = capacity_by_facility.unstack().mean() + # levels = [find_level_for_facility(i) if i != 'All' else 'All' for i in capacity_unstacked_average.index] + xpos_for_level = dict(zip((color_for_level.keys()), range(len(color_for_level)))) + xpos_for_level.update({'1b': 2, '2': 2, '3': 3, '4': 4, '5': 5}) + for id, val in capacity_unstacked_average.items(): + if id != 'All': + _level = find_level_for_facility(id) + + # Skip if 
facility level could not be determined + if _level is None: + continue + + if _level != '5': + xpos = xpos_for_level[_level] + scatter = (np.random.rand() - 0.5) * 0.25 + h1, = ax.plot(xpos + scatter, val * 100, color=color_for_level[_level], + marker='.', markersize=15, label='Each Facility', linestyle='none') + if 'All' in capacity_unstacked_average.index: + h2 = ax.axhline( + y=capacity_unstacked_average['All'] * 100, + color='red', + linestyle='--', + label='Average' + ) + ax.set_title(name_of_plot) + ax.set_xlabel('Facility_Level') + ax.set_xticks(list(xpos_for_level.values())) + ax.set_xticklabels(xpos_for_level.keys()) + ax.set_ylabel('Percent of Time Available That is Used\n') + ax.legend(handles=[h1, h2]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + fig.tight_layout() + fig.savefig(make_graph_file_name(name_of_plot.replace(' ', '_'))) + plt.close(fig) + + fig, ax = plt.subplots() + name_of_plot = 'Usage of Healthcare Worker Time by Cadre and Facility_Level' + (100.0 * capacity_by_officer.unstack()).T.plot.bar(ax=ax) + ax.legend() + ax.set_xlabel('Facility_Level') + ax.set_ylabel('Percent of time that is used') + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.set_title(name_of_plot) + fig.tight_layout() + fig.savefig(make_graph_file_name(name_of_plot.replace(' ', '_'))) + plt.close(fig) + + +def apply(results_folder: Path, output_folder: Path, resourcefilepath: Path = None): + """Description of the usage of healthcare system resources.""" + + # figure2_appointments_used( + # results_folder=results_folder, output_folder=output_folder, resourcefilepath=resourcefilepath + # ) + from tlo.analysis.utils import load_pickled_dataframes + log = load_pickled_dataframes(results_folder, 0, 0) + print(log.keys()) + + print(log['tlo.methods.healthsystem.summary'].keys()) + + figure4_hr_use_overall( + results_folder=results_folder, output_folder=output_folder, resourcefilepath=resourcefilepath + ) + + 
+if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--scenario-outputs-folder", + type=Path, + required=True, + help="Path to folder containing scenario outputs", + ) + parser.add_argument( + "--show-figures", + action="store_true", + help="Whether to interactively show figures", + ) + parser.add_argument( + "--save-figures", + action="store_true", + help="Whether to save figures", + ) + args = parser.parse_args() + + # Use the command-line argument instead of hardcoded path + results_folder = args.scenario_outputs_folder + # results_folder = Path( + # './outputs/wamulwafu@kuhes.ac.mw/nurses_scenario_outputs-2026-02-09T110530Z' + # ) + + apply( + results_folder=results_folder, # or directly: args.scenario_outputs_folder + output_folder=results_folder, + resourcefilepath=Path('./resources') + ) diff --git a/src/scripts/nurses_analyses/nurses_scenario_analyses.py b/src/scripts/nurses_analyses/nurses_scenario_analyses.py new file mode 100644 index 0000000000..83d7280646 --- /dev/null +++ b/src/scripts/nurses_analyses/nurses_scenario_analyses.py @@ -0,0 +1,234 @@ +""" +This scenario file sets up the scenarios for simulating the effects of nursing staffing levels. 
+ +Run on the batch system using: +``` +tlo batch-submit src/scripts/nurses_analyses/nurses_scenario_analyses.py +``` + +or locally using: +``` +tlo scenario-run src/scripts/nurses_analyses/nurses_scenario_analyses.py + ``` + + + +""" + +from pathlib import Path +from typing import Dict + +from tlo import Date, logging +from tlo.analysis.utils import ( + get_parameters_for_hrh_historical_scaling_and_rescaling_for_mode2, + get_root_path, + mix_scenarios, +) +from tlo.methods.fullmodel import fullmodel +from tlo.methods.scenario_switcher import ImprovedHealthSystemAndCareSeekingScenarioSwitcher +from tlo.scenario import BaseScenario + + +class StaffingScenario(BaseScenario): + def __init__(self): + super().__init__() + self.resources = get_root_path() / "resources" + self.seed = 0 + self.start_date = Date(2010, 1, 1) + self.end_date = Date(2035, 1, 1) + self.pop_size = 100_000 + self._scenarios = self._get_scenarios() + self.number_of_draws = len(self._scenarios) + self.runs_per_draw = 1 + + def log_configuration(self): + return { + 'filename': 'nurses_scenario_outputs', + 'directory': Path('./outputs'), # <- (specified only for local running) + 'custom_levels': { + '*': logging.WARNING, + 'tlo.methods.demography': logging.INFO, + 'tlo.methods.demography.detail': logging.WARNING, + 'tlo.methods.healthburden': logging.INFO, + 'tlo.methods.healthsystem.summary': logging.INFO, + } + } + + def modules(self): + return fullmodel() + [ + ImprovedHealthSystemAndCareSeekingScenarioSwitcher()] + + def draw_parameters(self, draw_number, rng): + if draw_number < self.number_of_draws: + return list(self._scenarios.values())[draw_number] + + def draw_name(self, draw_number) -> str: + """Store scenario name. + (This name can be retrieved by the plotting scripts to make the graphs be labelled nicely). 
+ """ + if draw_number < self.number_of_draws: + return list(self._scenarios.keys())[draw_number] + + @property + def _default_of_all_scenarios(self) -> Dict: + """Base set of parameters is the standard historical scaling and transition into Mode 2.""" + return get_parameters_for_hrh_historical_scaling_and_rescaling_for_mode2() + + @property + def _default_of_all_max_healthsystem_scenarios(self) -> Dict: + """Improved Health System Performance: the same as the default for scenarios, but increases health system + function and healthcare seeking behaviour in 2027""" + return mix_scenarios( + self._default_of_all_scenarios, # <-- start with the same default set of parameters (to avoid repeating them) + { + 'ImprovedHealthSystemAndCareSeekingScenarioSwitcher': { + 'max_healthcare_seeking': [False, True], + 'max_healthsystem_function': [False, True], + 'year_of_switch': 2027, + }, + }, + ) + + def _get_scenarios(self) -> Dict[str, Dict]: + """Return the Dict with values for the parameters that are changed, keyed by a name for the scenario. 
+ """ + year_of_hr_scaling = 2027 + return { + # "Baseline Nurses / Default Healthsystem Function": + # mix_scenarios( + # self._default_of_all_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_level_and_officer_type_mode': "default", + # "year_HR_scaling_by_level_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + # + # "Fewer Nurses / Default Healthsystem Function": + # mix_scenarios( + # self._default_of_all_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_level_and_officer_type_mode': "worse_staffing_N", + # "year_HR_scaling_by_level_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + # + # "More Nurses / Default Healthsystem Function": + # mix_scenarios( + # self._default_of_all_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_level_and_officer_type_mode': "establishment_staffing_N", + # "year_HR_scaling_by_level_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + # + # "Baseline Nurses / Improved Healthsystem Function": + # mix_scenarios( + # self._default_of_all_max_healthsystem_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_level_and_officer_type_mode': "default", + # "year_HR_scaling_by_level_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + # + # "Fewer Nurses / Improved Healthsystem Function": + # mix_scenarios( + # self._default_of_all_max_healthsystem_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_level_and_officer_type_mode': "worse_staffing_N", + # "year_HR_scaling_by_level_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + # + # "More Nurses / Improved Healthsystem Function": + # mix_scenarios( + # self._default_of_all_max_healthsystem_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_level_and_officer_type_mode': "establishment_staffing_N", + # "year_HR_scaling_by_level_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + # + # "More CNP staff / Default Healthsystem Function": + # mix_scenarios( + # self._default_of_all_scenarios, + 
# { + # "HealthSystem": { + # 'HR_scaling_by_level_and_officer_type_mode': "establishment_staffing_CNP", + # "year_HR_scaling_by_level_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + # + # "More CNP staff / Improved Healthsystem Function": + # mix_scenarios( + # self._default_of_all_max_healthsystem_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_level_and_officer_type_mode': "establishment_staffing_CNP", + # "year_HR_scaling_by_level_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + # + # "More Nurses by District / Default Healthsystem Function": + # mix_scenarios( + # self._default_of_all_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_district_and_officer_type_mode': "establishment_by_district_and_N", + # "year_HR_scaling_by_district_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + # + # "More Nurses by District / Improved Healthsystem Function": + # mix_scenarios( + # self._default_of_all_max_healthsystem_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_district_and_officer_type_mode': "establishment_by_district_and_N", + # "year_HR_scaling_by_district_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + + "More CNP staff by District / Default Healthsystem Function": + mix_scenarios( + self._default_of_all_scenarios, + { + "HealthSystem": { + 'HR_scaling_by_district_and_officer_type_mode': "establishment_by_district_and_CNP", + "year_HR_scaling_by_district_and_officer_type": year_of_hr_scaling, + }, + }, + ), + + # "More CNP staff by District / Improved Healthsystem Function": + # mix_scenarios( + # self._default_of_all_max_healthsystem_scenarios, + # { + # "HealthSystem": { + # 'HR_scaling_by_district_and_officer_type_mode': "establishment_by_district_and_CNP", + # "year_HR_scaling_by_district_and_officer_type": year_of_hr_scaling, + # }, + # }, + # ), + } + + +if __name__ == '__main__': + from tlo.cli import scenario_run + + scenario_run([__file__]) diff --git 
a/src/tlo/methods/demography.py b/src/tlo/methods/demography.py index 2acaad75eb..2c5d998fc4 100644 --- a/src/tlo/methods/demography.py +++ b/src/tlo/methods/demography.py @@ -555,6 +555,8 @@ def do_death(self, individual_id: int, cause: str, originating_module: Module): wealth=person['li_wealth'], date_of_birth=person['date_of_birth'], age_range=person['age_range'], + district_of_residence=person[ + 'district_of_residence'], cause_of_death=cause, ) diff --git a/src/tlo/methods/healthburden.py b/src/tlo/methods/healthburden.py index 54db8bf8fb..cc2ec65fb4 100644 --- a/src/tlo/methods/healthburden.py +++ b/src/tlo/methods/healthburden.py @@ -33,7 +33,7 @@ def __init__(self, name=None): super().__init__(name) # instance variables - self.multi_index_for_age_and_wealth_and_time = None + self.multi_index_for_age_and_wealth_and_time_and_region= None self.years_life_lost = None self.years_life_lost_stacked_time = None self.years_life_lost_stacked_age_and_time = None @@ -89,15 +89,18 @@ def initialise_simulation(self, sim): age_index = self.sim.modules['Demography'].AGE_RANGE_CATEGORIES wealth_index = sim.modules['Lifestyle'].PROPERTIES['li_wealth'].categories year_index = list(range(self.sim.start_date.year, self.sim.end_date.year + 1)) + district_index = sim.modules['Demography'].PROPERTIES['district_of_residence'].categories - self.multi_index_for_age_and_wealth_and_time = pd.MultiIndex.from_product( - [sex_index, age_index, wealth_index, year_index], names=['sex', 'age_range', 'li_wealth', 'year']) + self.multi_index_for_age_and_wealth_and_time_and_region = pd.MultiIndex.from_product( + [sex_index, age_index, wealth_index, district_index, year_index], + names=['sex', 'age_range', 'li_wealth', 'district_of_residence', 'year']) # Create the YLL and YLD storage data-frame (using sex/age_range/year multi-index) - self.years_life_lost = pd.DataFrame(index=self.multi_index_for_age_and_wealth_and_time) - self.years_life_lost_stacked_time = 
pd.DataFrame(index=self.multi_index_for_age_and_wealth_and_time) - self.years_life_lost_stacked_age_and_time = pd.DataFrame(index=self.multi_index_for_age_and_wealth_and_time) - self.years_lived_with_disability = pd.DataFrame(index=self.multi_index_for_age_and_wealth_and_time) + self.years_life_lost = pd.DataFrame(index=self.multi_index_for_age_and_wealth_and_time_and_region) + self.years_life_lost_stacked_time = pd.DataFrame(index=self.multi_index_for_age_and_wealth_and_time_and_region) + self.years_life_lost_stacked_age_and_time = ( + pd.DataFrame(index=self.multi_index_for_age_and_wealth_and_time_and_region)) + self.years_lived_with_disability = pd.DataFrame(index=self.multi_index_for_age_and_wealth_and_time_and_region) # 2) Collect the module that will use this HealthBurden module self.recognised_modules_names = [ @@ -168,6 +171,7 @@ def process_causes_of_dalys(self): 3) Output to the log mappers for causes of disability to the label """ ... + # 1) Collect causes of death and disability that are reported by each disease module, # merging the gbd_causes declared for deaths or disabilities under the same label, @@ -192,13 +196,13 @@ def merge_dicts_of_causes(d1: Dict, d2: Dict) -> Dict: return merged_causes causes_of_death = collect_causes_from_disease_modules( - all_modules=self.sim.modules.values(), - collect='CAUSES_OF_DEATH', - acceptable_causes=self.sim.modules['Demography'].gbd_causes_of_death) + all_modules=self.sim.modules.values(), + collect='CAUSES_OF_DEATH', + acceptable_causes=self.sim.modules['Demography'].gbd_causes_of_death) causes_of_disability = collect_causes_from_disease_modules( - all_modules=self.sim.modules.values(), - collect='CAUSES_OF_DISABILITY', - acceptable_causes=set(self.parameters['gbd_causes_of_disability'])) + all_modules=self.sim.modules.values(), + collect='CAUSES_OF_DISABILITY', + acceptable_causes=set(self.parameters['gbd_causes_of_disability'])) causes_of_death_and_disability = merge_dicts_of_causes( causes_of_death, @@ 
-299,7 +303,8 @@ def get_daly_weight(self, sequlae_code): return daly_wt - def report_live_years_lost(self, sex=None, wealth=None, date_of_birth=None, age_range=None, cause_of_death=None): + def report_live_years_lost(self, sex=None, wealth=None, date_of_birth=None, + age_range=None, district_of_residence=None, cause_of_death=None): """ Calculate and store the period for which there is 'years of lost life' when someone dies (assuming that the person has died on today's date in the simulation). @@ -313,15 +318,16 @@ def report_live_years_lost(self, sex=None, wealth=None, date_of_birth=None, age_ def _format_for_multi_index(_yll: pd.Series): """Returns pd.Series which is the same as in the argument `_yll` except that the multi-index has been expanded to include sex and li_wealth and rearranged so that it matched the expected multi-index format - (sex/age_range/li_wealth/year).""" - return pd.DataFrame(_yll)\ - .assign(sex=sex, li_wealth=wealth)\ - .set_index(['sex', 'li_wealth'], append=True)\ - .reorder_levels(['sex', 'age_range', 'li_wealth', 'year'])[_yll.name] + (sex/age_range/li_wealth/district_of_residence/year).""" + return pd.DataFrame(_yll) \ + .assign(sex=sex, li_wealth=wealth, district_of_residence=district_of_residence) \ + .set_index(['sex', 'li_wealth', 'district_of_residence'], append=True) \ + .reorder_levels(['sex', 'age_range', 'li_wealth', 'district_of_residence', 'year'])[_yll.name] - assert self.years_life_lost.index.equals(self.multi_index_for_age_and_wealth_and_time) - assert self.years_life_lost_stacked_time.index.equals(self.multi_index_for_age_and_wealth_and_time) - assert self.years_life_lost_stacked_age_and_time.index.equals(self.multi_index_for_age_and_wealth_and_time) + assert self.years_life_lost.index.equals(self.multi_index_for_age_and_wealth_and_time_and_region) + assert self.years_life_lost_stacked_time.index.equals(self.multi_index_for_age_and_wealth_and_time_and_region) + assert (self.years_life_lost_stacked_age_and_time.index. 
+ equals(self.multi_index_for_age_and_wealth_and_time_and_region)) # date from which years of life are lost date_of_death = self.sim.date @@ -345,19 +351,19 @@ def _format_for_multi_index(_yll: pd.Series): end_date=( date_of_birth + pd.DateOffset(years=self.parameters['Age_Limit_For_YLL']) - pd.DateOffset(days=1)), date_of_birth=date_of_birth - ).groupby(level=1).sum()\ - .assign(year=date_of_death.year)\ - .set_index(['year'], append=True)['person_years']\ - .pipe(_format_for_multi_index) + ).groupby(level=1).sum() \ + .assign(year=date_of_death.year) \ + .set_index(['year'], append=True)['person_years'] \ + .pipe(_format_for_multi_index) # Get the years of live lost "stacked by age and time", whereby all the life-years lost up to the age_limit are # ascribed to the age of death and to the year of death. This is computed by collapsing the age-dimension of # `yll_stacked_by_time` onto the age(-range) of death. age_range_to_stack_to = age_range - yll_stacked_by_age_and_time = pd.DataFrame(yll_stacked_by_time.groupby(level=[0, 2, 3]).sum())\ - .assign(age_range=age_range_to_stack_to)\ - .set_index(['age_range'], append=True)['person_years']\ - .reorder_levels(['sex', 'age_range', 'li_wealth', 'year']) + yll_stacked_by_age_and_time = pd.DataFrame(yll_stacked_by_time.groupby(level=[0, 2, 3, 4]).sum()) \ + .assign(age_range=age_range_to_stack_to) \ + .set_index(['age_range'], append=True)['person_years'] \ + .reorder_levels(['sex', 'age_range', 'li_wealth', 'district_of_residence', 'year']) # Add the years-of-life-lost from this death to the overall YLL dataframe keeping track if cause_of_death not in self.years_life_lost.columns: @@ -369,15 +375,17 @@ def _format_for_multi_index(_yll: pd.Series): # Add the life-years-lost from this death to the running total in LifeYearsLost dataframe self.years_life_lost[cause_of_death] = self.years_life_lost[cause_of_death].add( yll, fill_value=0) + self.years_life_lost_stacked_time[cause_of_death] = 
self.years_life_lost_stacked_time[cause_of_death].add( yll_stacked_by_time, fill_value=0) self.years_life_lost_stacked_age_and_time[cause_of_death] = \ self.years_life_lost_stacked_age_and_time[cause_of_death].add(yll_stacked_by_age_and_time, fill_value=0) # Check that the index of the YLL dataframe is not changed - assert self.years_life_lost.index.equals(self.multi_index_for_age_and_wealth_and_time) - assert self.years_life_lost_stacked_time.index.equals(self.multi_index_for_age_and_wealth_and_time) - assert self.years_life_lost_stacked_age_and_time.index.equals(self.multi_index_for_age_and_wealth_and_time) + assert self.years_life_lost.index.equals(self.multi_index_for_age_and_wealth_and_time_and_region) + assert self.years_life_lost_stacked_time.index.equals(self.multi_index_for_age_and_wealth_and_time_and_region) + assert (self.years_life_lost_stacked_age_and_time.index. + equals(self.multi_index_for_age_and_wealth_and_time_and_region)) def decompose_yll_by_age_and_time(self, start_date, end_date, date_of_birth): """ @@ -414,15 +422,15 @@ def write_to_log(self, year: int): if year in self._years_written_to_log: return # Skip if the year has already been logged. - def summarise_results_for_this_year(df, level=[0, 1]) -> pd.DataFrame: + def summarise_results_for_this_year(df, level=[0, 1, 2, 3]) -> pd.DataFrame: """Return pd.DataFrame that gives the summary of the `df` for the `year` by certain levels in the df's multi-index. 
The `level` argument gives a list of levels to use in `groupby`: e.g., level=[0,1] gives a summary of sex/age-group; and level=[2] gives a summary only by wealth category.""" - return df.loc[(slice(None), slice(None), slice(None), year)] \ - .groupby(level=level) \ - .sum() \ - .reset_index() \ - .assign(year=year) + return df.loc[(slice(None), slice(None), slice(None), slice(None), year)] \ + .groupby(level=level) \ + .sum() \ + .reset_index() \ + .assign(year=year) def log_df_line_by_line(key, description, df, force_cols=None) -> None: """Log each line of a dataframe to `logger.info`. Each row of the dataframe is one logged entry. @@ -533,10 +541,11 @@ def log_df_line_by_line(key, description, df, force_cols=None) -> None: def check_multi_index(self): """Check that the multi-index of the dataframes are as expected""" - assert self.years_life_lost.index.equals(self.multi_index_for_age_and_wealth_and_time) - assert self.years_life_lost_stacked_time.index.equals(self.multi_index_for_age_and_wealth_and_time) - assert self.years_life_lost_stacked_age_and_time.index.equals(self.multi_index_for_age_and_wealth_and_time) - assert self.years_lived_with_disability.index.equals(self.multi_index_for_age_and_wealth_and_time) + assert self.years_life_lost.index.equals(self.multi_index_for_age_and_wealth_and_time_and_region) + assert self.years_life_lost_stacked_time.index.equals(self.multi_index_for_age_and_wealth_and_time_and_region) + assert (self.years_life_lost_stacked_age_and_time.index. 
+ equals(self.multi_index_for_age_and_wealth_and_time_and_region)) + assert self.years_lived_with_disability.index.equals(self.multi_index_for_age_and_wealth_and_time_and_region) class Get_Current_DALYS(RegularEvent, PopulationScopeEventMixin): @@ -617,20 +626,22 @@ def apply(self, population): # 4) Summarise the results for this month wrt sex/age/wealth # - merge in age/wealth/sex information disease_specific_daly_values_this_month = disease_specific_daly_values_this_month.merge( - df.loc[idx_alive, ['sex', 'li_wealth', 'age_range']], left_index=True, right_index=True, how='left') + df.loc[idx_alive, ['sex', 'li_wealth', 'district_of_residence', 'age_range']], + left_index=True, right_index=True, how='left') # - sum of daly_weight, by sex/age/wealth disability_monthly_summary = pd.DataFrame( - disease_specific_daly_values_this_month.groupby(['sex', 'age_range', 'li_wealth']).sum().fillna(0)) + disease_specific_daly_values_this_month. + groupby(['sex', 'age_range', 'district_of_residence', 'li_wealth']).sum().fillna(0)) # - add the year into the multi-index disability_monthly_summary['year'] = self.sim.date.year disability_monthly_summary.set_index('year', append=True, inplace=True) disability_monthly_summary = disability_monthly_summary.reorder_levels( - ['sex', 'age_range', 'li_wealth', 'year']) + ['sex', 'age_range', 'li_wealth', 'district_of_residence', 'year']) # 5) Add the monthly summary to the overall dataframe for YearsLivedWithDisability - dalys_to_add = disability_monthly_summary.sum().sum() # for checking + dalys_to_add = disability_monthly_summary.sum().sum() # for checking dalys_current = self.module.years_lived_with_disability.sum().sum() # for checking # (Nb. this will add columns that are not otherwise present and add values to columns where they are.) 
@@ -642,11 +653,12 @@ def apply(self, population): # Merge into a dataframe with the correct multi-index (the multi-index from combine is subtly different) self.module.years_lived_with_disability = \ - pd.DataFrame(index=self.module.multi_index_for_age_and_wealth_and_time)\ - .merge(combined, left_index=True, right_index=True, how='left') + pd.DataFrame(index=self.module.multi_index_for_age_and_wealth_and_time_and_region) \ + .merge(combined, left_index=True, right_index=True, how='left') # Check multi-index is in check and that the addition of DALYS has worked - assert self.module.years_lived_with_disability.index.equals(self.module.multi_index_for_age_and_wealth_and_time) + assert (self.module.years_lived_with_disability.index. + equals(self.module.multi_index_for_age_and_wealth_and_time_and_region)) assert abs(self.module.years_lived_with_disability.sum().sum() - (dalys_to_add + dalys_current)) < 1e-5 self.module.check_multi_index() @@ -660,3 +672,4 @@ def __init__(self, module): def apply(self, population): self.module.write_to_log(year=self.sim.date.year) + diff --git a/src/tlo/methods/healthsystem.py b/src/tlo/methods/healthsystem.py index f8b0f55a03..3d48928cac 100644 --- a/src/tlo/methods/healthsystem.py +++ b/src/tlo/methods/healthsystem.py @@ -308,6 +308,26 @@ class HealthSystem(Module): "(factors informed by survey data); and, `custom` (user can freely set these factors as " "parameters in the analysis).", ), + "HR_scaling_by_district_and_officer_type_table": Parameter( + Types.DICT, + "Factors by which daily capabilities of difference cadres in different districts will be" + "scaled at the start of the year specified by year_HR_scaling_by_district_officer_type to simulate" + "(e.g., through catastrophic event disrupting delivery of services in particular district(s))." + "This is the import of a folder of csv resource files: keys are the file names and values are in the " + "csv files in the format of pd.DataFrames. 
Additional scenarios can be added by adding " + "csv files to this folder: the value of `HR_scaling_by_district_officer_type_mode` indicates which" + "csv file is used.", + ), + "year_HR_scaling_by_district_and_officer_type": Parameter( + Types.INT, + "Year in which scaling of daily capabilities by district and cadre will take place. " + "(The change happens on 1st January of that year.)", + ), + "HR_scaling_by_district_and_officer_type_mode": Parameter( + Types.STRING, + "Mode of scaling of daily capabilities by district and cadre. This corresponds to the name of the " + "worksheet in the file `ResourceFile_HR_scaling_by_district.xlsx`.", + ), "HR_scaling_by_district_table": Parameter( Types.DICT, "Factors by which daily capabilities in different districts will be" @@ -694,6 +714,22 @@ def read_consumables(filename): f"{self.parameters['HR_scaling_by_level_and_officer_type_mode']}" ) + self.parameters["HR_scaling_by_district_and_officer_type_table"]: Dict = read_csv_files( + path_to_resourcefiles_for_healthsystem + / "human_resources" + / "scaling_capabilities" + / "ResourceFile_HR_scaling_by_district_and_officer_type", + files=None, # all sheets read in + ) + # Ensure the mode of HR scaling to be considered is included in the tables loaded + assert ( + self.parameters["HR_scaling_by_district_and_officer_type_mode"] + in self.parameters["HR_scaling_by_district_and_officer_type_table"] + ), ( + f"Value of `HR_scaling_by_district_and_officer_type_mode` not recognised: " + f"{self.parameters['HR_scaling_by_district_and_officer_type_mode']}" + ) + self.parameters["HR_scaling_by_district_table"]: Dict = read_csv_files( path_to_resourcefiles_for_healthsystem / "human_resources" @@ -904,6 +940,13 @@ def initialise_simulation(self, sim): Date(self.parameters["year_HR_scaling_by_level_and_officer_type"], 1, 1), ) + # Schedule a one-off rescaling of _daily_capabilities broken down by district and officer type.
+ # This occurs on 1st January of the year specified in the parameters. + sim.schedule_event( + RescaleHRCapabilities_ByDistrictAndOfficerType(self), + Date(self.parameters["year_HR_scaling_by_district_and_officer_type"], 1, 1), + ) + # Schedule a one-off rescaling of _daily_capabilities broken down by district # This occurs on 1st January of the year specified in the parameters. sim.schedule_event( @@ -2983,6 +3026,44 @@ def apply(self, population): officer_type, f"L{level}_factor" ] +class RescaleHRCapabilities_ByDistrictAndOfficerType(Event, PopulationScopeEventMixin): + """This event exists to scale the daily capabilities, with a factor for each (district, cadre) pair.""" + + def __init__(self, module): + super().__init__(module) + + def apply(self, population): + # Get the set of scaling_factors that are specified by 'HR_scaling_by_district_and_officer_type_mode' + HR_scaling_factor_by_district_and_officer_type = ( + self.module.parameters["HR_scaling_by_district_and_officer_type_table"][ + self.module.parameters["HR_scaling_by_district_and_officer_type_mode"] + ] + .set_index("District") + ) + + pattern = r"FacilityID_(\w+)_Officer_(\w+)" + for clinic, clinic_cl in self.module._daily_capabilities.items(): + for officer in clinic_cl.keys(): + matches = re.match(pattern, officer) + # Extract facility ID and officer type from the capability key + facility_id = int(matches.group(1)) + officer_type = matches.group(2) + # Extract district + if facility_id in range(128): + district = self.module._facility_by_facility_id[facility_id].name.split('_')[-1] + elif facility_id in {128, 129, 130, 131, 132}: + district = self.module._facility_by_facility_id[facility_id].name + else: + district = "N/A" + # Scaling + if ( + (district in HR_scaling_factor_by_district_and_officer_type.index) and + (officer_type in HR_scaling_factor_by_district_and_officer_type.columns) + ): + self.module._daily_capabilities[clinic][officer] *= ( + HR_scaling_factor_by_district_and_officer_type.loc[district, officer_type]
+ ) + class RescaleHRCapabilities_ByDistrict(Event, PopulationScopeEventMixin): """This event exists to scale the daily capabilities, with a factor for each district.""" diff --git a/tests/test_healthburden.py b/tests/test_healthburden.py index 4b28b3fd85..799f1cedc6 100644 --- a/tests/test_healthburden.py +++ b/tests/test_healthburden.py @@ -77,17 +77,54 @@ def test_run_with_healthburden_with_dummy_diseases(tmpdir, seed): dalys = output['tlo.methods.healthburden']['dalys'] dalys = dalys.drop(columns=['date']) + # Columns that are not DALY causes + index_cols = ['sex', 'age_range', 'li_wealth', 'district_of_residence', 'year'] + + # All remaining columns are DALY values + daly_cols = [c for c in dalys.columns if c not in index_cols] + + # Total national DALYs + national_totals = dalys[daly_cols].sum() + + # Total district DALYs + district_totals = (dalys.groupby('district_of_residence')[daly_cols].sum().sum()) + + pd.testing.assert_series_equal(national_totals.sort_index(), district_totals.sort_index(), check_dtype=False) + age_index = sim.modules['Demography'].AGE_RANGE_CATEGORIES sex_index = ['M', 'F'] + wealth_index = sim.modules['Lifestyle'].PROPERTIES['li_wealth'].categories + district_index = sim.modules['Demography'].PROPERTIES['district_of_residence'].categories year_index = list(range(start_date.year, end_date.year + 1)) - correct_multi_index = pd.MultiIndex.from_product([sex_index, age_index, year_index], - names=['sex', 'age_range', 'year']) - output_multi_index = dalys.set_index(['sex', 'age_range', 'year']).index + + correct_multi_index = pd.MultiIndex.from_product( + [sex_index, age_index, wealth_index, district_index, year_index], + names=['sex', 'age_range', 'li_wealth', 'district_of_residence', 'year'] + ) + + output_multi_index = dalys.set_index( + ['sex', 'age_range', 'li_wealth', 'district_of_residence', 'year']).index pd.testing.assert_index_equal(output_multi_index, correct_multi_index, check_order=False) + # Check total deaths in district are 
equal to total deaths at national level + yll = output['tlo.methods.healthburden']['yll_by_causes_of_death'] + yll = yll.drop(columns=['date']) + + index_cols = ['sex', 'age_range', 'li_wealth', 'district_of_residence', 'year'] + death_cols = [c for c in yll.columns if c not in index_cols] + + # Total national deaths + national_deaths = yll[death_cols].sum() + + # Total district deaths + district_deaths = (yll.groupby('district_of_residence')[death_cols].sum().sum()) + + pd.testing.assert_series_equal(national_deaths.sort_index(), district_deaths.sort_index(), check_dtype=False) + # check that there is a column for each 'label' that is registered - assert set(dalys.set_index(['sex', 'age_range', 'year']).columns) == \ - {'Other', 'Mockitis_Disability_And_Death', 'ChronicSyndrome_Disability_And_Death'} + assert (set( + dalys.set_index(['sex', 'age_range', 'li_wealth', 'district_of_residence', 'year']).columns) == + {'Other', 'Mockitis_Disability_And_Death', 'ChronicSyndrome_Disability_And_Death'}) @pytest.mark.slow @@ -386,9 +423,12 @@ def test_airthmetic_of_lifeyearslost(seed, tmpdir): assert yll.sum().sum() == approx(1.0) # check that age-range is correct (0.5 ly lost among 0-4 year-olds; 0.5 ly lost to 5-9 year-olds) - assert yll.loc[('F', '0-4', slice(None), 2010)].sum().sum() == approx(0.5, abs=2.0 / DAYS_IN_YEAR) - assert yll.loc[('F', '5-9', slice(None), 2010)].sum().sum() == approx(0.5, abs=2.0 / DAYS_IN_YEAR) - assert yll.loc[('F', ['0-4', '5-9'], slice(None), 2010)].sum().sum() == approx(1.0, abs=0.5 / DAYS_IN_YEAR) + assert (yll.loc[('F', '0-4', slice(None), slice(None), 2010)].sum().sum() + == approx(0.5, abs=2.0 / DAYS_IN_YEAR)) + assert (yll.loc[('F', '5-9', slice(None), slice(None), 2010)].sum().sum() + == approx(0.5, abs=2.0 / DAYS_IN_YEAR)) + assert (yll.loc[('F', ['0-4', '5-9'], slice(None), slice(None), 2010)].sum().sum() + == approx(1.0, abs=0.5 / DAYS_IN_YEAR)) @pytest.mark.slow @@ -486,8 +526,8 @@ def apply(self, individual_id): & 
(yld.age_range == age_range_at_disability_onset) & (yld.sex == sex) ) - assert (yld.loc[marker_for_disability, 'cause_of_disability_A'] == daly_wt * 1.0).all() - assert (yld.loc[~marker_for_disability, 'cause_of_disability_A'] == 0.0).all() + assert (yld.loc[marker_for_disability, 'cause_of_disability_A'].sum() == approx(daly_wt * 1.0)) + assert (yld.loc[~marker_for_disability, 'cause_of_disability_A'].sum() == approx(0.0)) # For the Non-Stacked Results # -- YLL