Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 61 additions & 39 deletions skexplain/plot/base_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,34 @@ def __init__(self, BASE_FONT_SIZE=12, seaborn_kws=None):
plt.rc("legend", fontsize=self.FONT_SIZES["teensie"]) # legend fontsize
plt.rc("figure", titlesize=self.FONT_SIZES["big"]) # fontsize of the figure title

def get_fig_props(self, n_panels, **kwargs):
"""Determine appropriate figure properties"""
width_slope = 0.875
height_slope = 0.45
intercept = 3.0 - width_slope
figsize = (
min((n_panels * width_slope) + intercept, 19),
min((n_panels * height_slope) + intercept, 12),
)
# Golden ratio for aesthetically pleasing proportions
GOLDEN_RATIO = 1.618

wspace = (-0.03 * n_panels) + 0.85
hspace = (0.0175 * n_panels) + 0.3
def get_fig_props(self, n_panels, **kwargs):
"""Determine appropriate figure properties using the golden ratio.

Each panel targets a golden-ratio aspect (width slightly > height).
Total figure size scales with panel count while maintaining
pleasing proportions.
"""
n_columns = kwargs.get("n_columns", 3)
n_columns = min(n_columns, n_panels)
n_rows = max(1, int(np.ceil(n_panels / n_columns)))

# Per-panel dimensions based on golden ratio
panel_width = 3.5
panel_height = panel_width / self.GOLDEN_RATIO # ~2.16

# Total figure size with padding for axis labels and legend
total_width = max(6, min(n_columns * panel_width + 1.0, 19))
total_height = max(4.5, min(n_rows * (panel_height + 1.0) + 1.5, 14))

figsize = (total_width, total_height)

# Spacing scales inversely with panel count
wspace = max(0.3, 0.85 - 0.03 * n_panels)
hspace = max(0.3, 0.30 + 0.02 * n_panels)

wspace = wspace + 0.25 if n_columns > 3 else wspace

kwargs["figsize"] = kwargs.get("figsize", figsize)
Expand Down Expand Up @@ -424,6 +438,7 @@ def set_axis_label(self, ax, xaxis_label=None, yaxis_label=None, **kwargs):
physical units)
"""
fontsize = kwargs.get("fontsize", self.FONT_SIZES["tiny"])
fontweight = kwargs.get("fontweight", "normal")
if xaxis_label is not None:
xaxis_label_pretty = self.display_feature_names.get(xaxis_label, xaxis_label)
units = self.display_units.get(xaxis_label, "")
Expand All @@ -432,7 +447,7 @@ def set_axis_label(self, ax, xaxis_label=None, yaxis_label=None, **kwargs):
else:
xaxis_label_with_units = f"{xaxis_label_pretty} ({units})"

ax.set_xlabel(xaxis_label_with_units, fontsize=fontsize)
ax.set_xlabel(xaxis_label_with_units, fontsize=fontsize, fontweight=fontweight)

if yaxis_label is not None:
yaxis_label_pretty = self.display_feature_names.get(yaxis_label, yaxis_label)
Expand All @@ -442,16 +457,13 @@ def set_axis_label(self, ax, xaxis_label=None, yaxis_label=None, **kwargs):
else:
yaxis_label_with_units = f"{yaxis_label_pretty} ({units})"

ax.set_ylabel(yaxis_label_with_units, fontsize=fontsize)
ax.set_ylabel(yaxis_label_with_units, fontsize=fontsize, fontweight=fontweight)

def set_legend(self, n_panels, fig, ax, major_ax=None, **kwargs):
"""
Set a single legend on the bottom of a figure
for a set of subplots.
Set a single legend at the bottom of a figure,
outside the subplot area so it never overlaps panels.
"""
if major_ax is None:
major_ax = self.set_major_axis_labels(fig)

fontsize = kwargs.get("fontsize", "medium")
ncol = kwargs.get("ncol", 3)
handles = kwargs.get("handles", None)
Expand All @@ -463,28 +475,38 @@ def set_legend(self, n_panels, fig, ax, major_ax=None, **kwargs):
if labels is None:
_, labels = ax.get_legend_handles_labels()

if n_panels > 3:
bbox_to_anchor = (0.5, -0.35)
bbox_to_anchor = kwargs.get("bbox_to_anchor", None)

if bbox_to_anchor is not None:
# User override — use major_ax approach for backward compat
if major_ax is None:
major_ax = self.set_major_axis_labels(fig)
major_ax.legend(
handles, labels,
loc="lower center",
bbox_to_anchor=bbox_to_anchor,
fancybox=True, shadow=True,
ncol=ncol, fontsize=fontsize,
)
else:
bbox_to_anchor = (0.5, -0.5)

bbox_to_anchor = kwargs.get("bbox_to_anchor", bbox_to_anchor)

# Shrink current axis's height by 10% on the bottom
box = major_ax.get_position()
major_ax.set_position([box.x0, box.y0 + box.height * 0.1, box.width, box.height * 0.9])

# Put a legend below current axis
major_ax.legend(
handles,
labels,
loc="lower center",
bbox_to_anchor=bbox_to_anchor,
fancybox=True,
shadow=True,
ncol=ncol,
fontsize=fontsize,
)
# Render layout first so we can query actual subplot positions
fig.canvas.draw()
# Find the lowest subplot bottom edge
bottoms = [ax_i.get_position().y0 for ax_i in fig.get_axes()
if ax_i.get_visible() and ax_i.get_label() != '<colorbar>']
if bottoms:
lowest = min(bottoms)
else:
lowest = 0.1
# Place legend centered just below the lowest subplot
legend_y = max(0.0, lowest - 0.08)
fig.legend(
handles, labels,
loc="upper center",
bbox_to_anchor=(0.5, legend_y),
fancybox=True, shadow=True,
ncol=ncol, fontsize=fontsize,
)

def set_minor_ticks(self, ax):
"""
Expand Down
6 changes: 5 additions & 1 deletion skexplain/plot/plot_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,11 @@ def plot_contours(
n_columns = len(estimator_names)

if n_panels == 1:
figsize = (6, 3)
figsize = (6, 5)
fontsize = 8
elif n_panels == 2:
figsize = (10, 5)
fontsize = 9
else:
figsize = (10, 8)
fontsize = 10
Expand Down Expand Up @@ -353,6 +356,7 @@ def plot_contours(
xaxis_label=feature_set[0],
yaxis_label=feature_set[1],
fontsize=fontsize,
fontweight="bold",
)
# Add a colorbar
if (
Expand Down
2 changes: 1 addition & 1 deletion skexplain/plot/plot_interpret_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def plot_1d_curve(
if n_panels < 10:
self.set_minor_ticks(lineplt_ax)

self.set_axis_label(lineplt_ax, xaxis_label="".join(feature))
self.set_axis_label(lineplt_ax, xaxis_label="".join(feature), fontweight="bold")
lineplt_ax.axhline(y=0.0, color="k", alpha=0.8, linewidth=0.8, linestyle="dashed")

# nticks = 5 if n_panels < 10 else 3
Expand Down
38 changes: 28 additions & 10 deletions skexplain/plot/plot_permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def plot_variable_importance(
rho_threshold = kwargs.get("rho_threshold", 0.8)
plot_reference_score = kwargs.get("plot_reference_score", True)
plot_error = kwargs.get("plot_error", True)
show_method_subtitle = kwargs.get("show_method_subtitle", True)

only_one_method = all([m[0] == panels[0][0] for m in panels])
only_one_estimator = all([m[1] == panels[0][1] for m in panels])
Expand Down Expand Up @@ -234,17 +235,33 @@ def plot_variable_importance(
)

if plot_reference_score:
ref_label_fontsize = self.FONT_SIZES["teensie"]
# Small nudge to the right so text doesn't sit on the dashed line
x_range = np.ptp(scores_to_plot) if len(scores_to_plot) > 0 else 1.0
nudge = max(0.01, x_range * 0.02)
if "forward" in method:
ax.axvline(
results[f"all_permuted_score__{estimator_name}"].mean(), color="k", ls=":"
ref_score = results[f"all_permuted_score__{estimator_name}"].mean()
ax.axvline(ref_score, color="k", ls=":", alpha=0.7, zorder=1)
ax.text(
ref_score + nudge, len(scores_to_plot) / 2, "Original Score",
fontsize=ref_label_fontsize,
va="center", ha="left", color="0.3",
rotation=90,
)
elif "backward" in method:
ax.axvline(
results[f"original_score__{estimator_name}"].mean(), color="k", ls="--"
ref_score = results[f"original_score__{estimator_name}"].mean()
ax.axvline(ref_score, color="k", ls="--", alpha=0.7, zorder=1)
ax.text(
ref_score + nudge, len(scores_to_plot) / 2, "Original Score",
fontsize=ref_label_fontsize,
va="center", ha="left", color="0.3",
rotation=90,
)

# Despine
# Despine — keep bottom spine and ticks for x-axis clarity
self.despine_plt(ax)
ax.spines["bottom"].set_visible(True)
ax.tick_params(axis="x", which="both", length=3)

elinewidth = 0.9 if n_panels <= 3 else 0.5

Expand Down Expand Up @@ -333,11 +350,12 @@ def plot_variable_importance(
else:
self.set_n_ticks(ax, option="x")

xlabel = (
self.DISPLAY_NAMES_DICT.get(method, method)
if (only_one_method and xlabels is None)
else ""
)
if show_method_subtitle and only_one_method and xlabels is None:
xlabel = self.DISPLAY_NAMES_DICT.get(method, method)
elif xlabels is not None:
xlabel = ""
else:
xlabel = ""

major_ax = self.set_major_axis_labels(
fig,
Expand Down
14 changes: 11 additions & 3 deletions skexplain/plot/plot_sobol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@


def sobol_plot(
results, est_name=None, ax=None, display_feature_names=None, n_features=None, kind="bar"
results, est_name=None, ax=None, display_feature_names=None, n_features=None,
kind="bar", fontsize=None,
):
"""Plot Sobol sensitivity indices (1st order and higher order) as a stacked bar or barh chart.

Expand Down Expand Up @@ -59,12 +60,19 @@ def sobol_plot(

ax = df_result.plot(ax=ax, x="variable", kind=kind, stacked=True, rot=rot)

# Scale tick and label font sizes
if fontsize is None:
fontsize = 11
tick_fontsize = max(8, fontsize - 2)

if kind == "bar":
ax.set_xlabel("")
ax.set_ylabel("Total Sobol Index\n(1st order + higher order)")
ax.set_ylabel("Total Sobol Index\n(1st order + higher order)", fontsize=fontsize)
else:
ax.set_ylabel("")
ax.set_xlabel("Total Sobol Index\n(1st order + higher order)")
ax.set_xlabel("Total Sobol Index\n(1st order + higher order)", fontsize=fontsize)
ax.invert_yaxis()

ax.tick_params(axis="both", labelsize=tick_fontsize)

return ax
Loading