From 0305e90e79c652083091aab4beecbd480f1a681a Mon Sep 17 00:00:00 2001 From: monte-flora Date: Wed, 1 Apr 2026 19:05:48 +0000 Subject: [PATCH] Add SAGE integration and bump version to 1.0.0 - Add ExplainToolkit.sage() method for SAGE global feature importance (Covert et al. 2020). Supports individual and grouped SAGE via sage.MarginalImputer and sage.GroupedMarginalImputer. - Add sage-importance as optional dependency (pip install scikit-explain[sage]) - Add group_sage() utility for post-hoc grouping of SAGE values - Export group_local_values, group_feature_values, to_skexplain_importance, group_sage from skexplain top-level for advanced workflows - Fix to_skexplain_importance to handle grouped_sage method name - Fix group_local_values to handle list-type attrs (method, estimator name) - Add "grouped_sage" to PlotImportance SINGLE_VAR_METHODS and DISPLAY_NAMES_DICT - Add tutorial notebook: 13_sage_and_global_ranking.ipynb (SAGE computation, SHAP+SAGE side-by-side, grouped importance) - Bump version to 1.0.0 Co-Authored-By: Claude Opus 4.6 (1M context) --- setup.py | 3 +- skexplain/__init__.py | 6 +- skexplain/common/contrib_utils.py | 4 + skexplain/common/importance_utils.py | 51 ++- skexplain/main/_importance_mixin.py | 164 +++++++++- skexplain/plot/plot_permutation_importance.py | 4 +- .../13_sage_and_global_ranking.ipynb | 305 ++++++++++++++++++ 7 files changed, 528 insertions(+), 9 deletions(-) create mode 100644 tutorial_notebooks/13_sage_and_global_ranking.ipynb diff --git a/setup.py b/setup.py index 2523750..b42e522 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ EMAIL = "monte.flora@noaa.gov" AUTHOR = "Montgomery Flora" REQUIRES_PYTHON = ">=3.8.0" -VERSION = "0.1.7" +VERSION = "1.0.0" # What packages are required for this module to be executed? REQUIRED = [ @@ -48,6 +48,7 @@ # What packages are optional? EXTRAS = { "interactive": ["jupyter"], + "sage": ["sage-importance"], } if sys.platform == "darwin": diff --git a/skexplain/__init__.py b/skexplain/__init__.py index 3c50770..7b83727 100644 --- a/skexplain/__init__.py +++ b/skexplain/__init__.py @@ -10,4 +10,8 @@ from .common.models import load_models from .common.dataset import load_data -__version__ = "0.1.2" +# Import utilities for advanced workflows +from .common.importance_utils import to_skexplain_importance, group_sage +from .common.contrib_utils import group_local_values, group_feature_values + +__version__ = "1.0.0" diff --git a/skexplain/common/contrib_utils.py b/skexplain/common/contrib_utils.py index 4aee284..a0785ae 100644 --- a/skexplain/common/contrib_utils.py +++ b/skexplain/common/contrib_utils.py @@ -18,7 +18,11 @@ def group_local_values(explain_ds, groups, X, inds=None): inds = np.arange(len(X)) estimator_name = explain_ds.attrs["estimators used"] + if isinstance(estimator_name, (list, np.ndarray)): + estimator_name = estimator_name[0] method = explain_ds.attrs["method"] + if isinstance(method, (list, np.ndarray)): + method = method[0] explain_df = pd.DataFrame( explain_ds[f"{method}_values__{estimator_name}"], columns=explain_ds.attrs["features"] diff --git a/skexplain/common/importance_utils.py b/skexplain/common/importance_utils.py index c348cbf..1e6cfd2 100644 --- a/skexplain/common/importance_utils.py +++ b/skexplain/common/importance_utils.py @@ -227,7 +227,8 @@ def to_skexplain_importance( This is useful when comparing importance across different methods. """ bootstrap = False - if method == "sage": + importances_std = None + if method in ("sage", "grouped_sage") and hasattr(importances, "values") and hasattr(importances, "std"): importances_std = importances.std importances = importances.values elif method == "coefs": @@ -267,7 +268,7 @@ def to_skexplain_importance( else: scores_ranked = importances[ranked_indices] - if method == "sage": + if method in ("sage", "grouped_sage") and importances_std is not None: std_ranked = importances_std[ranked_indices] features_ranked = np.array(feature_names)[ranked_indices] @@ -293,9 +294,9 @@ def to_skexplain_importance( scores_ranked, ) - if method == "sage": - data[f"sage_scores_std__{estimator_name}"] = ( - [f"n_vars_sage"], + if method in ("sage", "grouped_sage") and importances_std is not None: + data[f"{method}_scores_std__{estimator_name}"] = ( + [f"n_vars_{method}"], std_ranked, ) @@ -307,6 +308,46 @@ def to_skexplain_importance( return data +def group_sage(sage_results, groups, estimator_name=None): + """Group SAGE importance values by feature groups. + + Parameters + ---------- + sage_results : xarray.Dataset + Results from ``ExplainToolkit.sage()`` or ``to_skexplain_importance`` + with method='sage'. + groups : dict + Feature groups. Keys are group names, values are lists of feature names. + estimator_name : str, optional + Estimator name. If None, inferred from the dataset. + + Returns + ------- + xarray.Dataset + Grouped SAGE importance values. + """ + if estimator_name is None: + estimator_name = list(sage_results.data_vars)[0].split("__")[-1] + + features = list(sage_results[f"sage_rankings__{estimator_name}"].values) + scores = sage_results[f"sage_scores__{estimator_name}"].values.flatten() + + group_vals = np.zeros(len(groups)) + group_names = [] + for i, (group_name, group_features) in enumerate(groups.items()): + indices = [features.index(f) for f in group_features if f in features] + group_vals[i] = np.sum(scores[indices]) + group_names.append(group_name) + + return to_skexplain_importance( + group_vals, + estimator_name=estimator_name, + feature_names=group_names, + method="grouped_sage", + normalize=False, + ) + + def combine_top_features(results_dict, n_vars=None): """Combines the list of top features from different estimators into a single list where duplicates are removed. diff --git a/skexplain/main/_importance_mixin.py b/skexplain/main/_importance_mixin.py index bf8f17a..5bdf75d 100644 --- a/skexplain/main/_importance_mixin.py +++ b/skexplain/main/_importance_mixin.py @@ -1,8 +1,12 @@ import itertools +import numpy as np import xarray as xr from ..common.utils import is_str, to_xarray, check_all_features_for_ale -from ..common.importance_utils import retrieve_important_vars, combine_top_features, compute_importance +from ..common.importance_utils import ( + retrieve_important_vars, combine_top_features, compute_importance, + to_skexplain_importance, +) from ._validation import normalize_features, normalize_estimator_names, track_timing @@ -547,3 +551,161 @@ def get_important_vars(self, perm_imp_data, multipass=True, n_vars=10, combine=F return results else: return combine_top_features(results, n_vars=n_vars) + + @track_timing + def sage( + self, + background=None, + groups=None, + n_background=50, + n_jobs=1, + random_state=42, + loss=None, + **sage_kws, + ): + """ + Compute SAGE (Shapley Additive Global importancE) values [16]_. + + SAGE measures each feature's global importance by estimating its + contribution to model performance using Shapley values. Unlike + permutation importance, SAGE properly accounts for feature interactions. + + Requires the optional ``sage-importance`` package:: + + pip install sage-importance + + Parameters + ---------- + background : array-like, optional + Background dataset for the marginal imputer. If None, uses ``self.X``. + + groups : dict, optional + Feature groups for grouped SAGE. Keys are group names, values are lists + of feature names. When provided, uses ``sage.GroupedMarginalImputer``. + E.g., ``{'temperature': ['temp2m', 'sfc_temp'], 'wind': ['wind10m', 'fric_vel']}`` + + n_background : int, default=50 + Number of random samples from the background data to use for the imputer. + + n_jobs : int, default=1 + Number of parallel jobs for the SAGE estimator. + + random_state : int, default=42 + Random seed for reproducibility. + + loss : str, optional + Loss function for the SAGE estimator. If None, auto-detected: + ``'cross entropy'`` for classifiers, ``'mse'`` for regressors. + + **sage_kws + Additional keyword arguments passed to ``sage.PermutationEstimator.__call__``. + E.g., ``batch_size``, ``detect_convergence``, ``thresh``, ``n_permutations``. + + Returns + ------- + results : xarray.Dataset + Dataset with SAGE importance rankings and scores for each estimator. + Variables: ``sage_rankings__{est_name}``, ``sage_scores__{est_name}``, + ``sage_scores_std__{est_name}``. + + When ``groups`` is provided, uses method name ``grouped_sage``. + + References + ---------- + .. [16] Covert, I., Lundberg, S., and Lee, S.-I., 2020: + Understanding Global Feature Contributions With Additive + Importance Measures. NeurIPS. + + Examples + -------- + >>> import skexplain + >>> estimators = skexplain.load_models() + >>> X, y = skexplain.load_data() + >>> explainer = skexplain.ExplainToolkit(estimators=estimators, X=X, y=y) + >>> sage_results = explainer.sage() + >>> explainer.plot_importance( + ... data=sage_results, + ... panels=[('sage', 'Random Forest')], + ... ) + """ + try: + import sage + except ImportError: + raise ImportError( + "The 'sage-importance' package is required for SAGE computation. " + "Install it with: pip install sage-importance" + ) + + if background is None: + background = self.X + + rs = np.random.RandomState(random_state) + n_bg = min(n_background, len(background)) + random_inds = rs.choice(len(background), size=n_bg, replace=False) + try: + X_bg = background.values[random_inds, :] + except AttributeError: + X_bg = background[random_inds, :] + + method_name = "grouped_sage" if groups is not None else "sage" + + results_list = [] + for estimator_name, estimator in self.estimators.items(): + # Determine model function and loss + if loss is not None: + loss_ = loss + elif hasattr(estimator, "predict_proba"): + loss_ = "cross entropy" + else: + loss_ = "mse" + + model_fn = ( + estimator.predict_proba + if hasattr(estimator, "predict_proba") + else estimator.predict + ) + + # Set up the imputer + if groups is not None: + # Convert group names → list of index lists + group_indices = [ + [self.feature_names.index(f) for f in feats] + for feats in groups.values() + ] + imputer = sage.GroupedMarginalImputer(model_fn, X_bg, group_indices) + feature_names = list(groups.keys()) + else: + imputer = sage.MarginalImputer(model_fn, X_bg) + feature_names = self.feature_names + + # Compute SAGE + estimator_sage = sage.PermutationEstimator( + imputer, loss_, n_jobs=n_jobs, random_state=rs, + ) + + try: + X_vals = self.X.values + except AttributeError: + X_vals = self.X + + sage_values = estimator_sage(X_vals, self.y, **sage_kws) + + # Convert to skexplain format + result_ds = to_skexplain_importance( + sage_values, + estimator_name=estimator_name, + feature_names=feature_names, + method=method_name, + normalize=False, + ) + results_list.append(result_ds) + + # Merge results from all estimators + results_ds = xr.merge(results_list, combine_attrs="override") + + self.attrs_dict["method"] = method_name + if groups is not None: + self.attrs_dict["feature_groups"] = {k: list(v) for k, v in groups.items()} + results_ds = self._append_attributes(results_ds) + + return results_ds diff --git a/skexplain/plot/plot_permutation_importance.py b/skexplain/plot/plot_permutation_importance.py index e391b15..8592757 100644 --- a/skexplain/plot/plot_permutation_importance.py +++ b/skexplain/plot/plot_permutation_importance.py @@ -25,6 +25,7 @@ class PlotImportance(PlotStructure): "gini", "combined", "sage", + "grouped_sage", "grouped", "grouped_only", "lime", @@ -47,7 +48,8 @@ class PlotImportance(PlotStructure): "hstat": "H-Stat", "gini": "Gini", "combined": "Method-Average Ranking", - "sage": "SAGE Importance Scores", + "sage": "SAGE Importance", + "grouped_sage": "Grouped SAGE Importance", "grouped": "Grouped Importance", "grouped_only": "Grouped Only Importance", "sobol_total": "Sobol Total", diff --git a/tutorial_notebooks/13_sage_and_global_ranking.ipynb b/tutorial_notebooks/13_sage_and_global_ranking.ipynb new file mode 100644 index 0000000..f525b65 --- /dev/null +++ b/tutorial_notebooks/13_sage_and_global_ranking.ipynb @@ -0,0 +1,305 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SAGE and Global Feature Ranking\n", + "\n", + "This notebook demonstrates **SAGE** (Shapley Additive Global importancE) — a Shapley-based\n", + "method for global feature importance that properly accounts for feature interactions.\n", + "\n", + "We also show how to:\n", + "- Compare SAGE with SHAP-based importance and permutation importance side-by-side\n", + "- Group features for grouped SAGE and grouped SHAP\n", + "\n", + "SAGE requires the optional `sage-importance` package:\n", + "```\n", + "pip install sage-importance\n", + "```\n", + "\n", + "**Reference:** Covert, I., Lundberg, S., and Lee, S.-I., 2020: Understanding Global Feature\n", + "Contributions With Additive Importance Measures. NeurIPS." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.ensemble import GradientBoostingClassifier\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "import skexplain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create Synthetic Dataset\n", + "\n", + "A weather-inspired binary classification task with known feature importance structure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "N = 2000\n", + "X = pd.DataFrame({\n", + " 'CAPE': np.random.exponential(1500, N),\n", + " 'Shear': np.random.gamma(3, 5, N),\n", + " 'Freezing_Lvl': 2500 + np.random.randn(N) * 500,\n", + " 'Moisture': np.random.beta(3, 2, N) * 20,\n", + " 'Temperature': 25 + np.random.randn(N) * 8,\n", + " 'Noise': np.random.randn(N) * 10,\n", + "})\n", + "\n", + "logit = 0.002*X['CAPE'] + 0.08*X['Shear'] - 0.001*X['Freezing_Lvl'] + 0.05*X['Moisture'] - 5.0\n", + "y = (np.random.rand(N) < 1/(1+np.exp(-logit))).astype(int)\n", + "\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)\n", + "X_test = X_test.reset_index(drop=True)\n", + "\n", + "gb = GradientBoostingClassifier(n_estimators=100, max_depth=5, random_state=42)\n", + "gb.fit(X_train, y_train)\n", + "print(f'Test accuracy: {gb.score(X_test, y_test):.3f}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute SAGE Values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "explainer = skexplain.ExplainToolkit(\n", + " estimators=[('GB', gb)],\n", + " X=X_test, y=y_test,\n", + ")\n", + "\n", + "sage_results = explainer.sage(n_background=50, n_jobs=1)\n", + "print('SAGE rankings:', sage_results['sage_rankings__GB'].values)\n", + "print(f'Computation time: {sage_results.attrs[\"computation_time_seconds\"]}s')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot SAGE Importance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = explainer.plot_importance(\n", + " data=sage_results,\n", + " panels=[('sage', 'GB')],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compare SAGE vs SHAP vs Permutation Importance\n", + "\n", + "Compute all three global ranking methods and plot them side-by-side." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute SHAP and convert to importance\n", + "shap_results = explainer.local_attributions(method='shap')\n", + "shap_importance = skexplain.to_skexplain_importance(\n", + " shap_results['shap_values__GB'].values,\n", + " estimator_name='GB',\n", + " feature_names=list(X_test.columns),\n", + " method='shap_sum',\n", + ")\n", + "\n", + "# Compute permutation importance\n", + "perm_imp = explainer.permutation_importance(\n", + " n_vars=6, evaluation_fn='auc', n_permute=5,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot all three side by side\n", + "fig, axes = explainer.plot_importance(\n", + " data=[sage_results, shap_importance, perm_imp],\n", + " panels=[\n", + " ('sage', 'GB'),\n", + " ('shap_sum', 'GB'),\n", + " ('backward_multipass', 'GB'),\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SHAP Summary + SAGE Side-by-Side\n", + "\n", + "A common visualization pattern: SHAP beeswarm (feature relevance) next to SAGE bars (feature importance)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(dpi=300, ncols=2, figsize=(12, 6))\n", + "\n", + "# Left panel: SHAP summary plot\n", + "explainer.scatter_plot(\n", + " dataset=shap_results,\n", + " estimator_name='GB',\n", + " method='shap',\n", + " plot_type='summary',\n", + " ax=axes[0],\n", + " fig=fig,\n", + " add_colorbar=False,\n", + " max_display=6,\n", + ")\n", + "axes[0].set_xlabel('SHAP Value')\n", + "axes[0].set_title('Feature Relevance (SHAP)', fontsize=12)\n", + "\n", + "# Right panel: SAGE importance\n", + "explainer.plot_importance(\n", + " data=sage_results,\n", + " panels=[('sage', 'GB')],\n", + " ax=axes[1],\n", + " xlabels=['SAGE Value'],\n", + " show_method_subtitle=False,\n", + ")\n", + "axes[1].set_title('Feature Importance (SAGE)', fontsize=12)\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Grouped SAGE\n", + "\n", + "Group features into categories and compute SAGE at the group level.\n", + "This can be done in two ways:\n", + "1. **Directly** via `explainer.sage(groups=...)` — groups features during computation\n", + "2. **Post-hoc** via `skexplain.group_sage(sage_results, groups)` — sums individual SAGE values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "groups = {\n", + " 'Thermodynamic': ['CAPE', 'Moisture', 'Temperature'],\n", + " 'Kinematic': ['Shear'],\n", + " 'Environmental': ['Freezing_Lvl', 'Noise'],\n", + "}\n", + "\n", + "# Method 1: Direct grouped computation\n", + "grouped_sage = explainer.sage(groups=groups, n_background=50, n_jobs=1)\n", + "print('Direct grouped rankings:', grouped_sage['grouped_sage_rankings__GB'].values)\n", + "\n", + "# Method 2: Post-hoc grouping\n", + "grouped_post = skexplain.group_sage(sage_results, groups, estimator_name='GB')\n", + "print('Post-hoc grouped rankings:', grouped_post['grouped_sage_rankings__GB'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = explainer.plot_importance(\n", + " data=[grouped_sage, grouped_post],\n", + " panels=[\n", + " ('grouped_sage', 'GB'),\n", + " ('grouped_sage', 'GB'),\n", + " ],\n", + " xlabels=['Direct Grouped SAGE', 'Post-hoc Grouped SAGE'],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Grouped SHAP\n", + "\n", + "You can also group SHAP values for grouped beeswarm plots using the\n", + "built-in utilities `group_local_values` and `group_feature_values`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Group SHAP values\n", + "X_grouped = skexplain.group_feature_values(X_test, groups)\n", + "grouped_shap = skexplain.group_local_values(shap_results, groups, X_grouped)\n", + "\n", + "# Create explainer with grouped features\n", + "explainer_grouped = skexplain.ExplainToolkit(X=X_grouped)\n", + "\n", + "# Plot grouped SHAP summary\n", + "explainer_grouped.scatter_plot(\n", + " dataset=grouped_shap,\n", + " estimator_name='GB',\n", + " method='shap',\n", + " plot_type='summary',\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}