From 4653dc4e5bd5925c77c7ad3c58c175d4aba4e316 Mon Sep 17 00:00:00 2001 From: monte-flora Date: Thu, 2 Apr 2026 11:48:16 +0000 Subject: [PATCH] Fix categorical ALE crash with high-cardinality features (#86) When bootstrapping categorical ALE, different bootstrap samples may contain different subsets of categories, producing ragged arrays that numpy cannot stack. Fix by: - Computing the full category set from the original data upfront - Reindexing each bootstrap's ALE values to the full category set (missing categories filled with NaN) - This ensures all bootstrap iterations have consistent array shapes Also add 3 unit tests for low/medium/high cardinality categorical ALE. Closes #86 Co-Authored-By: Claude Opus 4.6 (1M context) --- skexplain/main/global_explainer.py | 16 ++++++++++- tests/test_interpret_curves.py | 46 +++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/skexplain/main/global_explainer.py b/skexplain/main/global_explainer.py index ef8afef..d30a716 100644 --- a/skexplain/main/global_explainer.py +++ b/skexplain/main/global_explainer.py @@ -1228,6 +1228,16 @@ def feature_encoder_func(data): xdata = np.array([np.unique(original_feature_values)]) xdata.sort() + # Determine the full set of categories from the original data + # so all bootstrap iterations produce ALE values aligned to the same categories. + X_full = self.X.copy() + if (X_full[feature].dtype.name != "category") or (not X_full[feature].cat.ordered): + X_full[feature] = X_full[feature].astype("string") + full_groups_order = order_groups(X_full, feature) + all_categories = full_groups_order.index.values + else: + all_categories = X_full[feature].cat.categories.values + # Initialize an empty ale array ale = [] @@ -1355,7 +1365,11 @@ def feature_encoder_func(data): # Subtract the mean value to get the centered value. ale_temp = res_df["ale"] - sum(res_df["ale"] * groups_props) - ale.append(ale_temp) + + # Reindex to the full set of categories so all bootstrap + # iterations have the same length. Missing categories get NaN. + ale_temp = ale_temp.reindex(all_categories, fill_value=np.nan) + ale.append(ale_temp.values) ale = np.array(ale, dtype=float) diff --git a/tests/test_interpret_curves.py b/tests/test_interpret_curves.py index 41074b8..83a0abc 100644 --- a/tests/test_interpret_curves.py +++ b/tests/test_interpret_curves.py @@ -2,9 +2,10 @@ # Unit test for the ALE and PD # code in Scikit-Explain. #=================================================== -import sys, os +import sys, os sys.path.insert(0, os.path.dirname(os.getcwd())) +import unittest import shap import numpy as np import skexplain @@ -239,5 +240,48 @@ def test_mec(self): self.assertAlmostEqual(mec, 1., 4) +class TestCategoricalALEHighCardinality(unittest.TestCase): + """Test categorical ALE with high-cardinality features (issue #86).""" + + def setUp(self): + import pandas as pd + self.y = pd.DataFrame({"target": range(0, 100)}) + self.X = pd.DataFrame({ + "continuous": range(0, 100), + "low_card": pd.Categorical([1.1]*25 + [1.2]*25 + [1.3]*25 + [1.4]*25), + "high_card": pd.Categorical(np.linspace(0, 10, num=100)), + "med_card": pd.Categorical(list(np.linspace(0, 10, num=50)) * 2), + }) + self.rf = RandomForestRegressor(n_estimators=10, random_state=42) + self.rf.fit(self.X, self.y) + + def test_high_cardinality_categorical_ale(self): + """High cardinality categorical feature should not crash (issue #86).""" + explainer = skexplain.ExplainToolkit( + estimators=("rf", self.rf), X=self.X, y=self.y, + ) + ale = explainer.ale(features="high_card", n_bootstrap=2, n_bins=5) + self.assertIn("high_card__rf__ale", ale.data_vars) + + def test_medium_cardinality_categorical_ale(self): + """Medium cardinality categorical with duplicates should work.""" + explainer = skexplain.ExplainToolkit( + estimators=("rf", self.rf), X=self.X, y=self.y, + ) + ale = explainer.ale(features="med_card", n_bootstrap=2, n_bins=5) + self.assertIn("med_card__rf__ale", ale.data_vars) + + def test_low_cardinality_categorical_ale(self): + """Low cardinality categorical should work as before.""" + explainer = skexplain.ExplainToolkit( + estimators=("rf", self.rf), X=self.X, y=self.y, + ) + ale = explainer.ale(features="low_card", n_bootstrap=2, n_bins=5) + self.assertIn("low_card__rf__ale", ale.data_vars) + # Should have shape (n_bootstrap, n_categories) + self.assertEqual(ale["low_card__rf__ale"].shape[0], 2) + self.assertEqual(ale["low_card__rf__ale"].shape[1], 4) + + if __name__ == "__main__": unittest.main()