diff --git a/src/causal_validation/plotters.py b/src/causal_validation/plotters.py index b588350..390c301 100644 --- a/src/causal_validation/plotters.py +++ b/src/causal_validation/plotters.py @@ -45,7 +45,7 @@ def plot( else cols[1 + i % (len(cols) - 2)] ) unit_label = ( - "Treated" if len(data.treated_unit_indices) == 1 else f"Treated {unit_idx}" + "Treated" if len(data.treated_unit_indices) == 1 else f"Treated {i + 1}" ) ax.plot(idx, Y_treated[:, i], color=unit_color, label=unit_label) @@ -55,7 +55,7 @@ def plot( line_label = ( "Intervention" if len(data.treated_unit_indices) == 1 - else f"Intervention {unit_idx}" + else f"Intervention {i + 1}" ) ax.axvline( x=treatment_date,