20 changes: 13 additions & 7 deletions src/eventdisplay_ml/data_processing.py
@@ -730,14 +730,20 @@ def load_training_data(model_configs, file_list, analysis_type):
             observatory=model_configs.get("observatory", "veritas"),
         )
         if analysis_type == "stereo_analysis":
-            df_flat["MCxoff"] = _to_numpy_1d(df["MCxoff"], np.float32)
-            df_flat["MCyoff"] = _to_numpy_1d(df["MCyoff"], np.float32)
-            df_flat["MCe0"] = np.log10(_to_numpy_1d(df["MCe0"], np.float32))
+            new_cols = {
+                "MCxoff": _to_numpy_1d(df["MCxoff"], np.float32),
+                "MCyoff": _to_numpy_1d(df["MCyoff"], np.float32),
+                "MCe0": np.log10(_to_numpy_1d(df["MCe0"], np.float32)),
+            }
         elif analysis_type == "classification":
-            df_flat["ze_bin"] = zenith_in_bins(
-                90.0 - _to_numpy_1d(df["ArrayPointing_Elevation"], np.float32),
-                model_configs.get("zenith_bins_deg", []),
-            )
+            new_cols = {
+                "ze_bin": zenith_in_bins(
+                    90.0 - _to_numpy_1d(df["ArrayPointing_Elevation"], np.float32),
+                    model_configs.get("zenith_bins_deg", []),
+                )
+            }
+        for col_name, values in new_cols.items():
+            df_flat[col_name] = values

         dfs.append(df_flat)

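The refactor above replaces per-column assignments with a single new_cols dict applied in one loop, so both analysis branches share the same insertion path. A minimal, self-contained sketch of the pattern (toy data; "MCxoff"/"MCe0" stand in for the real columns):

    import numpy as np
    import pandas as pd

    # Build the frame first, then collect all derived columns in one place.
    df_flat = pd.DataFrame({"existing": np.arange(4, dtype=np.float32)})
    new_cols = {
        "MCxoff": np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32),
        "MCe0": np.log10(np.array([0.5, 1.0, 2.0, 4.0], dtype=np.float32)),
    }
    # Single insertion point, as in the loop added by this PR.
    for col_name, values in new_cols.items():
        df_flat[col_name] = values
    print(df_flat.columns.tolist())  # ['existing', 'MCxoff', 'MCe0']
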
150 changes: 149 additions & 1 deletion src/eventdisplay_ml/evaluate.py
@@ -68,7 +68,9 @@ def evaluate_classification_model(model, x_test, y_test, df, x_cols, name):
     shap_feature_importance(model, x_test, ["label"])


-def evaluate_regression_model(model, x_test, y_test, df, x_cols, y_data, name):
+def evaluate_regression_model(
+    model, x_test, y_test, df, x_cols, y_data, name, shap_per_energy=False
+):
     """Evaluate the trained model on the test set and log performance metrics."""
     score = model.score(x_test, y_test)
     _logger.info(f"XGBoost Multi-Target R^2 Score (Testing Set): {score:.4f}")
@@ -82,6 +84,8 @@ def evaluate_regression_model(model, x_test, y_test, df, x_cols, y_data, name):
     feature_importance(model, x_cols, y_data.columns, name)
     if name == "xgboost":
         shap_feature_importance(model, x_test, y_data.columns)
+        if shap_per_energy:
+            shap_feature_importance_by_energy(model, x_test, df, y_test, y_data.columns)

     df_pred = pd.DataFrame(y_pred, columns=target_features("stereo_analysis"))
     calculate_resolution(
@@ -253,3 +257,147 @@ def shap_feature_importance(model, x_data, target_names, max_points=1000, n_top=
     for j in idx[:n_top]:
         if j < n_features:
             _logger.info(f"{x_data.columns[j]:25s} {imp[j]:.6e}")
+
+
+def shap_feature_importance_by_energy(
+    model,
+    x_test,
+    df,
+    y_test,
+    target_names,
+    log_e_min=-2.0,
+    log_e_max=2.5,
+    n_bins=9,
+    max_points=1000,
+    n_top=5,
+):
+    """Calculate SHAP feature importance for each energy bin.
+
+    Computes SHAP values separately for events in different energy ranges,
+    allowing analysis of feature importance as a function of energy.
+    Uses the same energy binning as calculate_resolution for consistency.
+    Outputs results in tabular format for easy comparison across energy bins.
+    """
+    # Extract energy values and create bins
+    mce0_values = df.loc[y_test.index, "MCe0"].values
+    bins = np.linspace(log_e_min, log_e_max, n_bins + 1)
+    # Use pd.cut with include_lowest=True to match calculate_resolution binning
+    bin_categories = pd.cut(mce0_values, bins=bins, include_lowest=True, right=True)
+    # pd.cut on an ndarray returns a Categorical, whose integer codes are read
+    # via .codes (a Categorical has no .cat accessor); shift to 1-based bin
+    # indices (out-of-range values get code -1, which becomes 0)
+    bin_indices = bin_categories.codes + 1
+
+    n_features = len(x_test.columns)
+    n_targets = len(target_names)
+
+    # Store importance values for each target across all bins
+    target_importance_data = {target: {} for target in target_names}
+    bin_info = []
+
+    # Collect stratified samples for all bins, then compute SHAP once
+    sampled_frames = []
+    sampled_bin_labels = []
+
+    for bin_idx in range(1, n_bins + 1):
+        mask = bin_indices == bin_idx
+        n_events = mask.sum()
+
+        if n_events == 0:
+            continue
+
+        bin_lower = bins[bin_idx - 1]
+        bin_upper = bins[bin_idx]
+        mean_log_e = mce0_values[mask].mean()
+
+        # Use a stable, unique bin label based on the explicit energy range
+        bin_label = f"[{bin_lower:.2f}, {bin_upper:.2f}]"
+        bin_info.append(
+            {
+                "label": bin_label,
+                "mean_log_e": mean_log_e,
+                "n_events": n_events,
+                "range": bin_label,
+            }
+        )
+
+        x_bin = x_test.iloc[mask]
+        n_sample = min(len(x_bin), max_points)
+        x_sample = x_bin.sample(n=n_sample, random_state=None)
+
+        sampled_frames.append(x_sample)
+        sampled_bin_labels.extend([bin_label] * len(x_sample))
+
+    if not sampled_frames:
+        _logger.info("No events found in any energy bin for SHAP calculation.")
+        return
+
+    x_sampled_all = pd.concat(sampled_frames, axis=0)
+    dmatrix = xgb.DMatrix(x_sampled_all)
+    shap_vals = model.get_booster().predict(dmatrix, pred_contribs=True)
+    shap_vals = shap_vals.reshape(len(x_sampled_all), n_targets, n_features + 1)
+
+    # Aggregate SHAP importance per bin from the single SHAP run
+    sampled_bin_labels = np.array(sampled_bin_labels)
+    for i, target in enumerate(target_names):
+        target_shap = shap_vals[:, i, :-1]
+        for info in bin_info:
+            bin_label = info["label"]
+            bin_mask = sampled_bin_labels == bin_label
+            if not np.any(bin_mask):
+                continue
+
+            imp = np.abs(target_shap[bin_mask]).mean(axis=0)
+            for j, feature_name in enumerate(x_test.columns):
+                if feature_name not in target_importance_data[target]:
+                    target_importance_data[target][feature_name] = {}
+                target_importance_data[target][feature_name][bin_label] = imp[j]
+
+    # Create and display tables for each target
+    _logger.info(f"\n{'=' * 100}")
+    _logger.info("SHAP Feature Importance by Energy Bin (Tabular Format)")
+    _logger.info(f"Calculated over {n_bins} bins [{log_e_min}, {log_e_max}]")
+    _logger.info(f"{'=' * 100}")
+
+    # Display bin information
+    _logger.info("\nEnergy Bin Information:")
+    for info in bin_info:
+        _logger.info(f"  {info['label']:12s}: Range {info['range']:15s}, N = {info['n_events']:6d}")
+
+    for target in target_names:
+        _logger.info(f"\n\n=== SHAP Importance for {target} ===")
+
+        # Find top N features in each bin, then take union of all top features
+        all_top_features = set()
+        for info in bin_info:
+            bin_label = info["label"]
+            # Get importance values for this bin
+            bin_importance = {
+                feature: values.get(bin_label, 0)
+                for feature, values in target_importance_data[target].items()
+            }
+            # Get top N features for this bin
+            top_in_bin = sorted(bin_importance.items(), key=lambda x: x[1], reverse=True)[:n_top]
+            all_top_features.update([f[0] for f in top_in_bin])
+
+        # Sort features by their average importance across all bins
+        feature_avg_importance = {}
+        for feature_name in all_top_features:
+            values = [
+                target_importance_data[target][feature_name].get(info["label"], 0)
+                for info in bin_info
+            ]
+            feature_avg_importance[feature_name] = np.mean(values)
+
+        sorted_features = sorted(feature_avg_importance.items(), key=lambda x: x[1], reverse=True)
+
+        # Build DataFrame with all features that were top N in at least one bin
+        data_rows = []
+        for feature_name, _ in sorted_features:
+            row = {"Feature": feature_name}
+            for info in bin_info:
+                bin_label = info["label"]
+                value = target_importance_data[target][feature_name].get(bin_label, np.nan)
+                row[bin_label] = value
+            data_rows.append(row)
+
+        df_table = pd.DataFrame(data_rows)
+        _logger.info(f"\n{df_table.to_markdown(index=False, floatfmt='.4e')}")