diff --git a/src/eventdisplay_ml/data_processing.py b/src/eventdisplay_ml/data_processing.py
index 3a49b43..f2246d8 100644
--- a/src/eventdisplay_ml/data_processing.py
+++ b/src/eventdisplay_ml/data_processing.py
@@ -730,14 +730,21 @@ def load_training_data(model_configs, file_list, analysis_type):
             observatory=model_configs.get("observatory", "veritas"),
         )
+        new_cols = {}  # no extra columns for other analysis types
         if analysis_type == "stereo_analysis":
-            df_flat["MCxoff"] = _to_numpy_1d(df["MCxoff"], np.float32)
-            df_flat["MCyoff"] = _to_numpy_1d(df["MCyoff"], np.float32)
-            df_flat["MCe0"] = np.log10(_to_numpy_1d(df["MCe0"], np.float32))
+            new_cols = {
+                "MCxoff": _to_numpy_1d(df["MCxoff"], np.float32),
+                "MCyoff": _to_numpy_1d(df["MCyoff"], np.float32),
+                "MCe0": np.log10(_to_numpy_1d(df["MCe0"], np.float32)),
+            }
         elif analysis_type == "classification":
-            df_flat["ze_bin"] = zenith_in_bins(
-                90.0 - _to_numpy_1d(df["ArrayPointing_Elevation"], np.float32),
-                model_configs.get("zenith_bins_deg", []),
-            )
+            new_cols = {
+                "ze_bin": zenith_in_bins(
+                    90.0 - _to_numpy_1d(df["ArrayPointing_Elevation"], np.float32),
+                    model_configs.get("zenith_bins_deg", []),
+                )
+            }
+        for col_name, values in new_cols.items():
+            df_flat[col_name] = values
 
         dfs.append(df_flat)
diff --git a/src/eventdisplay_ml/evaluate.py b/src/eventdisplay_ml/evaluate.py
index 30acd83..cd49336 100644
--- a/src/eventdisplay_ml/evaluate.py
+++ b/src/eventdisplay_ml/evaluate.py
@@ -68,7 +68,9 @@ def evaluate_classification_model(model, x_test, y_test, df, x_cols, name):
     shap_feature_importance(model, x_test, ["label"])
 
 
-def evaluate_regression_model(model, x_test, y_test, df, x_cols, y_data, name):
+def evaluate_regression_model(
+    model, x_test, y_test, df, x_cols, y_data, name, shap_per_energy=False
+):
     """Evaluate the trained model on the test set and log performance metrics."""
     score = model.score(x_test, y_test)
     _logger.info(f"XGBoost Multi-Target R^2 Score (Testing Set): {score:.4f}")
@@ -82,6 +84,8 @@ def evaluate_regression_model(model, x_test, y_test, df, x_cols, y_data, name):
     feature_importance(model, x_cols, y_data.columns, name)
     if name == "xgboost":
         shap_feature_importance(model, x_test, y_data.columns)
+        if shap_per_energy:
+            shap_feature_importance_by_energy(model, x_test, df, y_test, y_data.columns)
 
     df_pred = pd.DataFrame(y_pred, columns=target_features("stereo_analysis"))
     calculate_resolution(
@@ -253,3 +257,162 @@ def shap_feature_importance(model, x_data, target_names, max_points=1000, n_top=5):
     for j in idx[:n_top]:
         if j < n_features:
             _logger.info(f"{x_data.columns[j]:25s} {imp[j]:.6e}")
+
+
+def shap_feature_importance_by_energy(
+    model,
+    x_test,
+    df,
+    y_test,
+    target_names,
+    log_e_min=-2.0,
+    log_e_max=2.5,
+    n_bins=9,
+    max_points=1000,
+    n_top=5,
+):
+    """Calculate SHAP feature importance for each energy bin.
+
+    Computes SHAP values separately for events in different energy ranges,
+    allowing analysis of feature importance as a function of energy.
+    Uses the same energy binning as calculate_resolution for consistency.
+    Outputs results in tabular format for easy comparison across energy bins.
+    """
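+    # With the defaults (log_e_min=-2.0, log_e_max=2.5, n_bins=9),
+    # np.linspace yields bin edges every 0.5 in log10(MCe0):
+    # [-2.0, -1.5, ..., 2.5].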
+ """ + # Extract energy values and create bins + mce0_values = df.loc[y_test.index, "MCe0"].values + bins = np.linspace(log_e_min, log_e_max, n_bins + 1) + # Use pd.cut with include_lowest=True to match calculate_resolution binning + bin_categories = pd.cut(mce0_values, bins=bins, include_lowest=True, right=True) + # Convert categorical bins to 1-based integer indices (NaN -> code -1, becomes 0) + bin_indices = bin_categories.cat.codes + 1 + + n_features = len(x_test.columns) + n_targets = len(target_names) + + # Store importance values for each target across all bins + target_importance_data = {target: {} for target in target_names} + bin_info = [] + + # Collect stratified samples for all bins, then compute SHAP once + sampled_frames = [] + sampled_bin_labels = [] + + for bin_idx in range(1, n_bins + 1): + mask = bin_indices == bin_idx + n_events = mask.sum() + + if n_events == 0: + continue + + bin_lower = bins[bin_idx - 1] + bin_upper = bins[bin_idx] + mean_log_e = mce0_values[mask].mean() + + # Use a stable, unique bin label based on the explicit energy range + bin_label = f"[{bin_lower:.2f}, {bin_upper:.2f}]" + bin_info.append( + { + "label": bin_label, + "mean_log_e": mean_log_e, + "n_events": n_events, + "range": bin_label, + } + ) + + x_bin = x_test.iloc[mask] + n_sample = min(len(x_bin), max_points) + x_sample = x_bin.sample(n=n_sample, random_state=None) + + sampled_frames.append(x_sample) + sampled_bin_labels.extend([bin_label] * len(x_sample)) + + if not sampled_frames: + _logger.info("No events found in any energy bin for SHAP calculation.") + return + + x_sampled_all = pd.concat(sampled_frames, axis=0) + dmatrix = xgb.DMatrix(x_sampled_all) + shap_vals = model.get_booster().predict(dmatrix, pred_contribs=True) + shap_vals = shap_vals.reshape(len(x_sampled_all), n_targets, n_features + 1) + + # Aggregate SHAP importance per bin from the single SHAP run + sampled_bin_labels = np.array(sampled_bin_labels) + for i, target in enumerate(target_names): + target_shap = shap_vals[:, i, :-1] + for info in bin_info: + bin_label = info["label"] + bin_mask = sampled_bin_labels == bin_label + if not np.any(bin_mask): + continue + + imp = np.abs(target_shap[bin_mask]).mean(axis=0) + for j, feature_name in enumerate(x_test.columns): + if feature_name not in target_importance_data[target]: + target_importance_data[target][feature_name] = {} + target_importance_data[target][feature_name][bin_label] = imp[j] + + # Create and display tables for each target + _logger.info(f"\n{'=' * 100}") + _logger.info("SHAP Feature Importance by Energy Bin (Tabular Format)") + _logger.info(f"Calculated over {n_bins} bins [{log_e_min}, {log_e_max}]") + _logger.info(f"{'=' * 100}") + + # Display bin information + _logger.info("\nEnergy Bin Information:") + for info in bin_info: + _logger.info(f" {info['label']:12s}: Range {info['range']:15s}, N = {info['n_events']:6d}") + + for target in target_names: + _logger.info(f"\n\n=== SHAP Importance for {target} ===") + + # Find top N features in each bin, then take union of all top features + all_top_features = set() + for info in bin_info: + bin_label = info["label"] + # Get importance values for this bin + bin_importance = { + feature: values.get(bin_label, 0) + for feature, values in target_importance_data[target].items() + } + # Get top N features for this bin + top_in_bin = sorted(bin_importance.items(), key=lambda x: x[1], reverse=True)[:n_top] + all_top_features.update([f[0] for f in top_in_bin]) + + # Sort features by their average importance across all bins + 
+
+    # Aggregate SHAP importance per bin from the single SHAP run
+    sampled_bin_labels = np.array(sampled_bin_labels)
+    for i, target in enumerate(target_names):
+        target_shap = shap_vals[:, i, :-1]
+        for info in bin_info:
+            bin_label = info["label"]
+            bin_mask = sampled_bin_labels == bin_label
+            if not np.any(bin_mask):
+                continue
+
+            imp = np.abs(target_shap[bin_mask]).mean(axis=0)
+            for j, feature_name in enumerate(x_test.columns):
+                if feature_name not in target_importance_data[target]:
+                    target_importance_data[target][feature_name] = {}
+                target_importance_data[target][feature_name][bin_label] = imp[j]
+
+    # Create and display tables for each target
+    _logger.info(f"\n{'=' * 100}")
+    _logger.info("SHAP Feature Importance by Energy Bin (Tabular Format)")
+    _logger.info(f"Calculated over {n_bins} bins [{log_e_min}, {log_e_max}]")
+    _logger.info(f"{'=' * 100}")
+
+    # Display bin information
+    _logger.info("\nEnergy Bin Information:")
+    for info in bin_info:
+        _logger.info(
+            f"  {info['label']:16s}: <log10(E)> = {info['mean_log_e']:6.2f}, "
+            f"N = {info['n_events']:6d}"
+        )
+
+    for target in target_names:
+        _logger.info(f"\n\n=== SHAP Importance for {target} ===")
+
+        # Find top N features in each bin, then take union of all top features
+        all_top_features = set()
+        for info in bin_info:
+            bin_label = info["label"]
+            # Get importance values for this bin
+            bin_importance = {
+                feature: values.get(bin_label, 0)
+                for feature, values in target_importance_data[target].items()
+            }
+            # Get top N features for this bin
+            top_in_bin = sorted(bin_importance.items(), key=lambda x: x[1], reverse=True)[:n_top]
+            all_top_features.update([f[0] for f in top_in_bin])
+
+        # Sort features by their average importance across all bins
+        feature_avg_importance = {}
+        for feature_name in all_top_features:
+            values = [
+                target_importance_data[target][feature_name].get(info["label"], 0)
+                for info in bin_info
+            ]
+            feature_avg_importance[feature_name] = np.mean(values)
+
+        sorted_features = sorted(feature_avg_importance.items(), key=lambda x: x[1], reverse=True)
+
+        # Build DataFrame with all features that were top N in at least one bin
+        data_rows = []
+        for feature_name, _ in sorted_features:
+            row = {"Feature": feature_name}
+            for info in bin_info:
+                bin_label = info["label"]
+                value = target_importance_data[target][feature_name].get(bin_label, np.nan)
+                row[bin_label] = value
+            data_rows.append(row)
+
+        df_table = pd.DataFrame(data_rows)
+        # to_markdown requires the optional tabulate dependency.
+        _logger.info(f"\n{df_table.to_markdown(index=False, floatfmt='.4e')}")
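+
+    # Called from evaluate_regression_model for XGBoost models when
+    # shap_per_energy=True; all bins share one booster predict call over
+    # at most n_bins * max_points sampled rows.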