Merged
320 changes: 320 additions & 0 deletions docs/interpret/ebm-internals-quantile-regression.ipynb
@@ -0,0 +1,320 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# EBM Internals - Quantile Regression\n",
"\n",
"This notebook covers quantile regression using pinball loss with Explainable Boosting Machines. For standard regression internals, see [Part 1](./ebm-internals-regression.ipynb). For classification, see [Part 2](./ebm-internals-classification.ipynb). For multiclass, see [Part 3](./ebm-internals-multiclass.ipynb).\n",
"\n",
"Standard regression models (e.g. with RMSE) predict the conditional mean of the target. Quantile regression instead predicts a specific quantile (e.g. median, 10th percentile, 90th percentile). This is useful for:\n",
"\n",
"- **Prediction intervals**: Fit models at the 10th and 90th percentiles to get an 80% prediction interval.\n",
"- **Asymmetric risk**: When over-predicting and under-predicting have different costs.\n",
"- **Robustness**: Median regression (alpha=0.5) is more robust to outliers than mean regression.\n",
"\n",
"EBMs support quantile regression via the `\"quantile\"` objective, which uses the pinball loss (also called the quantile loss). The pinball loss for quantile alpha is:\n",
"\n",
"$$L(y, \\hat{y}) = \\begin{cases} \\alpha \\cdot (y - \\hat{y}) & \\text{if } y \\geq \\hat{y} \\\\ (1 - \\alpha) \\cdot (\\hat{y} - y) & \\text{if } y < \\hat{y} \\end{cases}$$\n",
"\n",
"This loss penalizes under-predictions by a factor of alpha and over-predictions by a factor of (1 - alpha), causing the model to learn the alpha-quantile of the conditional distribution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# boilerplate\n",
"from interpret import show\n",
"from interpret.glassbox import ExplainableBoostingRegressor\n",
"import numpy as np"
]
},
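{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before fitting anything, we can sanity-check the pinball loss numerically: over a fixed sample, the constant prediction that minimizes the average pinball loss at level alpha is the empirical alpha-quantile. The `pinball_loss` helper below is a small illustrative sketch, not part of the interpret API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative helper (not part of interpret): mean pinball loss at level alpha\n",
"def pinball_loss(y, y_hat, alpha):\n",
"    diff = np.asarray(y) - y_hat\n",
"    return np.mean(np.where(diff >= 0, alpha * diff, (alpha - 1.0) * diff))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"sample = rng.normal(size=10001)\n",
"\n",
"# Scan constant predictions; the minimizer should land near the 90th percentile\n",
"candidates = np.linspace(-3.0, 3.0, 601)\n",
"losses = [pinball_loss(sample, c, 0.9) for c in candidates]\n",
"print(\"Pinball-loss minimizer:   \", candidates[np.argmin(losses)])\n",
"print(\"Empirical 90th percentile:\", np.quantile(sample, 0.9))"
]
},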
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Median Regression (alpha=0.5)\n",
"\n",
"Let's start with median regression. With alpha=0.5, the pinball loss penalizes over- and under-predictions equally, so the model learns the conditional median rather than the conditional mean."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make a dataset composed of a nominal categorical, and a continuous feature\n",
"X = [[\"Peru\", 7.0], [\"Fiji\", 8.0], [\"Peru\", 9.0]]\n",
"y = [450.0, 550.0, 350.0]\n",
"\n",
"# Fit a quantile EBM for the median (alpha=0.5)\n",
"# Eliminate the validation set to handle the small dataset\n",
"ebm_median = ExplainableBoostingRegressor(\n",
" objective=\"quantile:alpha=0.5\",\n",
" interactions=0,\n",
" validation_size=0, outer_bags=1, min_samples_leaf=1, min_hessian=1e-9)\n",
"ebm_median.fit(X, y)\n",
"show(ebm_median.explain_global())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model structure is identical to a standard regression EBM: an intercept plus additive score contributions from each feature, looked up via binning. The only difference is the loss function used during training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Intercept:\", ebm_median.intercept_)\n",
"print(\"Feature types:\", ebm_median.feature_types_in_)\n",
"print(\"Bins:\", ebm_median.bins_)\n",
"print(\"Categorical scores:\", ebm_median.term_scores_[0])\n",
"print(\"Continuous scores:\", ebm_median.term_scores_[1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predictions are computed identically to standard regression EBMs: start from the intercept and add lookup table scores for each feature."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Median predictions:\", ebm_median.predict(X))\n",
"print(\"Original y values: \", y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prediction Intervals\n",
"\n",
"A key use case for quantile regression is constructing prediction intervals. By fitting separate models at different quantiles, we can estimate the range within which future observations are likely to fall.\n",
"\n",
"Let's use a larger, noisier dataset to demonstrate this. We'll fit models at the 10th, 50th, and 90th percentiles to get an 80% prediction interval."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_regression\n",
"\n",
"X_train, y_train = make_regression(\n",
" n_samples=1000, n_features=5, noise=20.0, random_state=42)\n",
"\n",
"X_test, y_test = make_regression(\n",
" n_samples=200, n_features=5, noise=20.0, random_state=123)\n",
"\n",
"# Fit quantile models at the 10th, 50th, and 90th percentiles\n",
"ebm_10 = ExplainableBoostingRegressor(objective=\"quantile:alpha=0.1\")\n",
"ebm_50 = ExplainableBoostingRegressor(objective=\"quantile:alpha=0.5\")\n",
"ebm_90 = ExplainableBoostingRegressor(objective=\"quantile:alpha=0.9\")\n",
"\n",
"ebm_10.fit(X_train, y_train)\n",
"ebm_50.fit(X_train, y_train)\n",
"ebm_90.fit(X_train, y_train)\n",
"\n",
"pred_10 = ebm_10.predict(X_test)\n",
"pred_50 = ebm_50.predict(X_test)\n",
"pred_90 = ebm_90.predict(X_test)\n",
"\n",
"print(\"First 5 test samples:\")\n",
"print(\" 10th percentile:\", np.round(pred_10[:5], 2))\n",
"print(\" 50th percentile:\", np.round(pred_50[:5], 2))\n",
"print(\" 90th percentile:\", np.round(pred_90[:5], 2))\n",
"print(\" Actual y: \", np.round(y_test[:5], 2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verify quantile ordering: q10 < q50 < q90 for most predictions\n",
"print(\"Fraction where q10 < q50:\", np.mean(pred_10 < pred_50))\n",
"print(\"Fraction where q50 < q90:\", np.mean(pred_50 < pred_90))\n",
"\n",
"# Check empirical coverage of the 80% prediction interval [q10, q90]\n",
"coverage = np.mean((y_test >= pred_10) & (y_test <= pred_90))\n",
"print(f\"Empirical coverage of [q10, q90] interval: {coverage:.1%} (target: ~80%)\")"
]
},
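{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also score each model with the pinball loss itself. Assuming your scikit-learn version provides `sklearn.metrics.mean_pinball_loss`, each model should achieve its lowest loss at the quantile it was trained for (this cell is a sketch under that assumption):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_pinball_loss\n",
"\n",
"# Evaluate each model at its own training quantile (lower is better)\n",
"for alpha, pred in [(0.1, pred_10), (0.5, pred_50), (0.9, pred_90)]:\n",
"    loss = mean_pinball_loss(y_test, pred, alpha=alpha)\n",
"    print(f\"alpha={alpha}: pinball loss = {loss:.2f}\")"
]
},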
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing the Prediction Interval\n",
"\n",
"Let's visualize the prediction interval on a sorted subset of test samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import matplotlib\n",
" import matplotlib.pyplot as plt\n",
"\n",
" # Sort by predicted median for a cleaner plot\n",
" sort_idx = np.argsort(pred_50)\n",
" x_axis = np.arange(len(sort_idx))\n",
"\n",
" fig, ax = plt.subplots(figsize=(12, 5))\n",
" ax.fill_between(x_axis, pred_10[sort_idx], pred_90[sort_idx],\n",
" alpha=0.3, label=\"80% prediction interval (q10-q90)\")\n",
" ax.plot(x_axis, pred_50[sort_idx], label=\"Median prediction (q50)\", linewidth=1.5)\n",
" ax.scatter(x_axis, y_test[sort_idx], s=8, color=\"red\", alpha=0.6, label=\"Actual values\")\n",
" ax.set_xlabel(\"Test samples (sorted by predicted median)\")\n",
" ax.set_ylabel(\"Target value\")\n",
" ax.set_title(\"EBM Quantile Regression: 80% Prediction Interval\")\n",
" ax.legend()\n",
" plt.tight_layout()\n",
" plt.show()\n",
"except ImportError:\n",
" print(\"matplotlib not installed, skipping plot\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interpretability\n",
"\n",
"One of the key advantages of quantile EBMs is that they remain fully interpretable. The global explanations show how each feature contributes to the predicted quantile, and local explanations show the additive score breakdown for individual predictions.\n",
"\n",
"Let's compare the shape functions for the same feature across different quantiles."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show global explanations for each quantile model\n",
"print(\"=== 10th Percentile Model ===\")\n",
"show(ebm_10.explain_global())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"=== 90th Percentile Model ===\")\n",
"show(ebm_90.explain_global())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Local explanation for a single test sample\n",
"show(ebm_50.explain_local(X_test[:5], y_test[:5]), 0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Making Predictions Manually\n",
"\n",
"Just like standard regression EBMs, quantile EBM predictions are computed by summing the intercept with lookup table scores for each feature. The prediction logic is identical; only the training loss differs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use the small dataset to demonstrate manual predictions\n",
"X_small = [[\"Peru\", 7.0], [\"Fiji\", 8.0], [\"Peru\", 9.0]]\n",
"y_small = [450.0, 550.0, 350.0]\n",
"\n",
"sample_scores = []\n",
"for sample in X_small:\n",
" score = ebm_median.intercept_\n",
" print(\"intercept: \" + str(score))\n",
"\n",
" for feature_idx, feature_val in enumerate(sample):\n",
" bins = ebm_median.bins_[feature_idx][0]\n",
" if isinstance(bins, dict):\n",
" bin_idx = bins[feature_val]\n",
" else:\n",
" bin_idx = np.digitize(feature_val, bins) + 1\n",
"\n",
" local_score = ebm_median.term_scores_[feature_idx][bin_idx]\n",
" print(ebm_median.feature_names_in_[feature_idx] + \": \" + str(local_score))\n",
" score += local_score\n",
" sample_scores.append(score)\n",
" print()\n",
"\n",
"print(\"PREDICTIONS (manual):\")\n",
"print(np.array(sample_scores))\n",
"print(\"PREDICTIONS (ebm.predict):\")\n",
"print(ebm_median.predict(X_small))"
]
},
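{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final sanity check (an assertion added here for illustration), the manually computed scores should match `predict` to floating-point precision:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: manual score accumulation equals ebm.predict\n",
"assert np.allclose(sample_scores, ebm_median.predict(X_small))\n",
"print(\"Manual predictions match ebm.predict\")"
]
},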
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"- Use `objective=\"quantile:alpha=0.5\"` for median regression, or any alpha in (0, 1) for other quantiles.\n",
"- The prediction mechanism is identical to standard regression EBMs (intercept + additive score lookups). Only the training loss function changes.\n",
"- Fitting multiple quantile models (e.g. alpha=0.1 and alpha=0.9) provides interpretable prediction intervals.\n",
"- All EBM interpretability tools (global/local explanations) work with quantile models.\n",
"- For the complete prediction code that handles interactions, missing values, and all model types, see [Part 3](./ebm-internals-multiclass.ipynb)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "3.10.13",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
1 change: 1 addition & 0 deletions docs/interpret/ebm-internals.md
@@ -6,3 +6,4 @@ This section is divided into 3 parts that build upon each other:
[Part 1](./ebm-internals-regression.ipynb) Covers regression for pure GAM models (no interactions).
[Part 2](./ebm-internals-classification.ipynb) Covers binary classification with interactions, ordinals, and missing values.
[Part 3](./ebm-internals-multiclass.ipynb) Covers multiclass, and unseen values.
[Quantile Regression](./ebm-internals-quantile-regression.ipynb) Covers quantile regression with pinball loss and prediction intervals.
3 changes: 2 additions & 1 deletion python/interpret-core/interpret/glassbox/_ebm.py
@@ -3940,7 +3940,8 @@ class EBMRegressor(EBMRegressorMixin, EBMModel):
objective : str, default="rmse"
The objective to optimize. Options include: "rmse",
"poisson_deviance", "tweedie_deviance:variance_power=1.5", "gamma_deviance",
"pseudo_huber:delta=1.0", "rmse_log" (rmse with a log link function)
"pseudo_huber:delta=1.0", "rmse_log" (rmse with a log link function),
"quantile:alpha=0.5" (quantile regression with pinball loss)
n_jobs : int, default=-2
Number of jobs to run in parallel. Negative integers are interpreted as following joblib's formula
(n_cpus + 1 + n_jobs), just like scikit-learn. Eg: -2 means using all threads except 1.