Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
EMAIL = "monte.flora@noaa.gov"
AUTHOR = "Montgomery Flora"
REQUIRES_PYTHON = ">=3.8.0"
VERSION = "0.1.7"
VERSION = "1.0.0"

# What packages are required for this module to be executed?
REQUIRED = [
Expand All @@ -48,6 +48,7 @@
# What packages are optional?
EXTRAS = {
"interactive": ["jupyter"],
"sage": ["sage-importance"],
}

if sys.platform == "darwin":
Expand Down
6 changes: 5 additions & 1 deletion skexplain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@
from .common.models import load_models
from .common.dataset import load_data

__version__ = "0.1.2"
# Import utilities for advanced workflows
from .common.importance_utils import to_skexplain_importance, group_sage
from .common.contrib_utils import group_local_values, group_feature_values

__version__ = "1.0.0"
4 changes: 4 additions & 0 deletions skexplain/common/contrib_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ def group_local_values(explain_ds, groups, X, inds=None):
inds = np.arange(len(X))

estimator_name = explain_ds.attrs["estimators used"]
if isinstance(estimator_name, (list, np.ndarray)):
estimator_name = estimator_name[0]
method = explain_ds.attrs["method"]
if isinstance(method, (list, np.ndarray)):
method = method[0]

explain_df = pd.DataFrame(
explain_ds[f"{method}_values__{estimator_name}"], columns=explain_ds.attrs["features"]
Expand Down
51 changes: 46 additions & 5 deletions skexplain/common/importance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ def to_skexplain_importance(
This is useful when comparing importance across different methods.
"""
bootstrap = False
if method == "sage":
importances_std = None
if method in ("sage", "grouped_sage") and hasattr(importances, "values") and hasattr(importances, "std"):
importances_std = importances.std
importances = importances.values
elif method == "coefs":
Expand Down Expand Up @@ -267,7 +268,7 @@ def to_skexplain_importance(
else:
scores_ranked = importances[ranked_indices]

if method == "sage":
if method in ("sage", "grouped_sage") and importances_std is not None:
std_ranked = importances_std[ranked_indices]

features_ranked = np.array(feature_names)[ranked_indices]
Expand All @@ -293,9 +294,9 @@ def to_skexplain_importance(
scores_ranked,
)

if method == "sage":
data[f"sage_scores_std__{estimator_name}"] = (
[f"n_vars_sage"],
if method in ("sage", "grouped_sage") and importances_std is not None:
data[f"{method}_scores_std__{estimator_name}"] = (
[f"n_vars_{method}"],
std_ranked,
)

Expand All @@ -307,6 +308,46 @@ def to_skexplain_importance(
return data


def group_sage(sage_results, groups, estimator_name=None):
"""Group SAGE importance values by feature groups.

Parameters
----------
sage_results : xarray.Dataset
Results from ``ExplainToolkit.sage()`` or ``to_skexplain_importance``
with method='sage'.
groups : dict
Feature groups. Keys are group names, values are lists of feature names.
estimator_name : str, optional
Estimator name. If None, inferred from the dataset.

Returns
-------
xarray.Dataset
Grouped SAGE importance values.
"""
if estimator_name is None:
estimator_name = list(sage_results.data_vars)[0].split("__")[-1]

features = list(sage_results[f"sage_rankings__{estimator_name}"].values)
scores = sage_results[f"sage_scores__{estimator_name}"].values.flatten()

group_vals = np.zeros(len(groups))
group_names = []
for i, (group_name, group_features) in enumerate(groups.items()):
indices = [features.index(f) for f in group_features if f in features]
group_vals[i] = np.sum(scores[indices])
group_names.append(group_name)

return to_skexplain_importance(
group_vals,
estimator_name=estimator_name,
feature_names=group_names,
method="grouped_sage",
normalize=False,
)


def combine_top_features(results_dict, n_vars=None):
"""Combines the list of top features from different estimators
into a single list where duplicates are removed.
Expand Down
164 changes: 163 additions & 1 deletion skexplain/main/_importance_mixin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import itertools
import numpy as np
import xarray as xr

from ..common.utils import is_str, to_xarray, check_all_features_for_ale
from ..common.importance_utils import retrieve_important_vars, combine_top_features, compute_importance
from ..common.importance_utils import (
retrieve_important_vars, combine_top_features, compute_importance,
to_skexplain_importance,
)
from ._validation import normalize_features, normalize_estimator_names, track_timing


Expand Down Expand Up @@ -547,3 +551,161 @@ def get_important_vars(self, perm_imp_data, multipass=True, n_vars=10, combine=F
return results
else:
return combine_top_features(results, n_vars=n_vars)

@track_timing
def sage(
self,
background=None,
groups=None,
n_background=50,
n_jobs=1,
random_state=42,
loss=None,
**sage_kws,
):
"""
Compute SAGE (Shapley Additive Global importancE) values [16]_.

SAGE measures each feature's global importance by estimating its
contribution to model performance using Shapley values. Unlike
permutation importance, SAGE properly accounts for feature interactions.

Requires the optional ``sage-importance`` package::

pip install sage-importance

Parameters
----------
background : array-like, optional
Background dataset for the marginal imputer. If None, uses ``self.X``.

groups : dict, optional
Feature groups for grouped SAGE. Keys are group names, values are lists
of feature names. When provided, uses ``sage.GroupedMarginalImputer``.
E.g., ``{'temperature': ['temp2m', 'sfc_temp'], 'wind': ['wind10m', 'fric_vel']}``

n_background : int, default=50
Number of random samples from the background data to use for the imputer.

n_jobs : int, default=1
Number of parallel jobs for the SAGE estimator.

random_state : int, default=42
Random seed for reproducibility.

loss : str, optional
Loss function for the SAGE estimator. If None, auto-detected:
``'cross entropy'`` for classifiers, ``'mse'`` for regressors.

**sage_kws
Additional keyword arguments passed to ``sage.PermutationEstimator.__call__``.
E.g., ``batch_size``, ``detect_convergence``, ``thresh``, ``n_permutations``.

Returns
-------
results : xarray.Dataset
Dataset with SAGE importance rankings and scores for each estimator.
Variables: ``sage_rankings__{est_name}``, ``sage_scores__{est_name}``,
``sage_scores_std__{est_name}``.

When ``groups`` is provided, uses method name ``grouped_sage``.

References
----------
.. [16] Covert, I., Lundberg, S., and Lee, S.-I., 2020:
Understanding Global Feature Contributions With Additive
Importance Measures. NeurIPS.

Examples
--------
>>> import skexplain
>>> estimators = skexplain.load_models()
>>> X, y = skexplain.load_data()
>>> explainer = skexplain.ExplainToolkit(estimators=estimators, X=X, y=y)
>>> sage_results = explainer.sage()
>>> explainer.plot_importance(
... data=sage_results,
... panels=[('sage', 'Random Forest')],
... )
"""
try:
import sage
except ImportError:
raise ImportError(
"The 'sage-importance' package is required for SAGE computation. "
"Install it with: pip install sage-importance"
)

if background is None:
background = self.X

rs = np.random.RandomState(random_state)
n_bg = min(n_background, len(background))
random_inds = rs.choice(len(background), size=n_bg, replace=False)
try:
X_bg = background.values[random_inds, :]
except AttributeError:
X_bg = background[random_inds, :]

method_name = "grouped_sage" if groups is not None else "sage"

results_list = []
for estimator_name, estimator in self.estimators.items():
# Determine model function and loss
if loss is not None:
loss_ = loss
elif hasattr(estimator, "predict_proba"):
loss_ = "cross entropy"
else:
loss_ = "mse"

model_fn = (
estimator.predict_proba
if hasattr(estimator, "predict_proba")
else estimator.predict
)

# Set up the imputer
if groups is not None:
# Convert group names → list of index lists
group_indices = [
[self.feature_names.index(f) for f in feats]
for feats in groups.values()
]
imputer = sage.GroupedMarginalImputer(model_fn, X_bg, group_indices)
feature_names = list(groups.keys())
else:
imputer = sage.MarginalImputer(model_fn, X_bg)
feature_names = self.feature_names

# Compute SAGE
estimator_sage = sage.PermutationEstimator(
imputer, loss_, n_jobs=n_jobs, random_state=rs,
)

try:
X_vals = self.X.values
except AttributeError:
X_vals = self.X

sage_values = estimator_sage(X_vals, self.y, **sage_kws)

# Convert to skexplain format
result_ds = to_skexplain_importance(
sage_values,
estimator_name=estimator_name,
feature_names=feature_names,
method=method_name,
normalize=False,
)
results_list.append(result_ds)

# Merge results from all estimators
results_ds = xr.merge(results_list, combine_attrs="override")

self.attrs_dict["method"] = method_name
if groups is not None:
self.attrs_dict["feature_groups"] = {k: list(v) for k, v in groups.items()}
results_ds = self._append_attributes(results_ds)

return results_ds
4 changes: 3 additions & 1 deletion skexplain/plot/plot_permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class PlotImportance(PlotStructure):
"gini",
"combined",
"sage",
"grouped_sage",
"grouped",
"grouped_only",
"lime",
Expand All @@ -47,7 +48,8 @@ class PlotImportance(PlotStructure):
"hstat": "H-Stat",
"gini": "Gini",
"combined": "Method-Average Ranking",
"sage": "SAGE Importance Scores",
"sage": "SAGE Importance",
"grouped_sage": "Grouped SAGE Importance",
"grouped": "Grouped Importance",
"grouped_only": "Grouped Only Importance",
"sobol_total": "Sobol Total",
Expand Down
Loading
Loading