From baa00d997bde13d849454c7d19410001a92c419b Mon Sep 17 00:00:00 2001 From: monte-flora Date: Wed, 1 Apr 2026 18:41:09 +0000 Subject: [PATCH] Improve plot aesthetics: golden ratio, legends, labels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use golden ratio (1.618) for auto-computed figure proportions in get_fig_props — panels are 3.5" wide by ~2.16" tall with padding - Adaptive legend placement: query actual subplot positions and place legend just below the lowest panel (no more overlap) - Permutation importance: add rotated "Original Score" label along reference dashed line, nudged right proportionally to data range; optional method subtitle via show_method_subtitle kwarg; visible x-axis bottom spine and tick marks - ALE/PD 1D curves: bold x-axis labels via fontweight support in set_axis_label - ALE/PD 2D plots: improved 2-panel figsize (10x5), bold axis labels - Sobol plot: explicit fontsize parameter with scaled tick labels - Minimum figure dimensions (6" wide, 4.5" tall) prevent squished aspect ratios on small panel counts Co-Authored-By: Claude Opus 4.6 (1M context) --- skexplain/plot/base_plotting.py | 100 +++++++++++------- skexplain/plot/plot_2D.py | 6 +- skexplain/plot/plot_interpret_curves.py | 2 +- skexplain/plot/plot_permutation_importance.py | 38 +++++-- skexplain/plot/plot_sobol.py | 14 ++- 5 files changed, 106 insertions(+), 54 deletions(-) diff --git a/skexplain/plot/base_plotting.py b/skexplain/plot/base_plotting.py index 70450e8..328b1b9 100644 --- a/skexplain/plot/base_plotting.py +++ b/skexplain/plot/base_plotting.py @@ -59,20 +59,34 @@ def __init__(self, BASE_FONT_SIZE=12, seaborn_kws=None): plt.rc("legend", fontsize=self.FONT_SIZES["teensie"]) # legend fontsize plt.rc("figure", titlesize=self.FONT_SIZES["big"]) # fontsize of the figure title - def get_fig_props(self, n_panels, **kwargs): - """Determine appropriate figure properties""" - width_slope = 0.875 - height_slope = 0.45 - intercept = 3.0 - width_slope - figsize = ( - min((n_panels * width_slope) + intercept, 19), - min((n_panels * height_slope) + intercept, 12), - ) + # Golden ratio for aesthetically pleasing proportions + GOLDEN_RATIO = 1.618 - wspace = (-0.03 * n_panels) + 0.85 - hspace = (0.0175 * n_panels) + 0.3 + def get_fig_props(self, n_panels, **kwargs): + """Determine appropriate figure properties using the golden ratio. + Each panel targets a golden-ratio aspect (width slightly > height). + Total figure size scales with panel count while maintaining + pleasing proportions. + """ n_columns = kwargs.get("n_columns", 3) + n_columns = min(n_columns, n_panels) + n_rows = max(1, int(np.ceil(n_panels / n_columns))) + + # Per-panel dimensions based on golden ratio + panel_width = 3.5 + panel_height = panel_width / self.GOLDEN_RATIO # ~2.16 + + # Total figure size with padding for axis labels and legend + total_width = max(6, min(n_columns * panel_width + 1.0, 19)) + total_height = max(4.5, min(n_rows * (panel_height + 1.0) + 1.5, 14)) + + figsize = (total_width, total_height) + + # Spacing scales inversely with panel count + wspace = max(0.3, 0.85 - 0.03 * n_panels) + hspace = max(0.3, 0.30 + 0.02 * n_panels) + wspace = wspace + 0.25 if n_columns > 3 else wspace kwargs["figsize"] = kwargs.get("figsize", figsize) @@ -424,6 +438,7 @@ def set_axis_label(self, ax, xaxis_label=None, yaxis_label=None, **kwargs): physical units) """ fontsize = kwargs.get("fontsize", self.FONT_SIZES["tiny"]) + fontweight = kwargs.get("fontweight", "normal") if xaxis_label is not None: xaxis_label_pretty = self.display_feature_names.get(xaxis_label, xaxis_label) units = self.display_units.get(xaxis_label, "") @@ -432,7 +447,7 @@ def set_axis_label(self, ax, xaxis_label=None, yaxis_label=None, **kwargs): else: xaxis_label_with_units = f"{xaxis_label_pretty} ({units})" - ax.set_xlabel(xaxis_label_with_units, fontsize=fontsize) + ax.set_xlabel(xaxis_label_with_units, fontsize=fontsize, fontweight=fontweight) if yaxis_label is not None: yaxis_label_pretty = self.display_feature_names.get(yaxis_label, yaxis_label) @@ -442,16 +457,13 @@ def set_axis_label(self, ax, xaxis_label=None, yaxis_label=None, **kwargs): else: yaxis_label_with_units = f"{yaxis_label_pretty} ({units})" - ax.set_ylabel(yaxis_label_with_units, fontsize=fontsize) + ax.set_ylabel(yaxis_label_with_units, fontsize=fontsize, fontweight=fontweight) def set_legend(self, n_panels, fig, ax, major_ax=None, **kwargs): """ - Set a single legend on the bottom of a figure - for a set of subplots. + Set a single legend at the bottom of a figure, + outside the subplot area so it never overlaps panels. """ - if major_ax is None: - major_ax = self.set_major_axis_labels(fig) - fontsize = kwargs.get("fontsize", "medium") ncol = kwargs.get("ncol", 3) handles = kwargs.get("handles", None) @@ -463,28 +475,38 @@ def set_legend(self, n_panels, fig, ax, major_ax=None, **kwargs): if labels is None: _, labels = ax.get_legend_handles_labels() - if n_panels > 3: - bbox_to_anchor = (0.5, -0.35) + bbox_to_anchor = kwargs.get("bbox_to_anchor", None) + + if bbox_to_anchor is not None: + # User override — use major_ax approach for backward compat + if major_ax is None: + major_ax = self.set_major_axis_labels(fig) + major_ax.legend( + handles, labels, + loc="lower center", + bbox_to_anchor=bbox_to_anchor, + fancybox=True, shadow=True, + ncol=ncol, fontsize=fontsize, + ) else: - bbox_to_anchor = (0.5, -0.5) - - bbox_to_anchor = kwargs.get("bbox_to_anchor", bbox_to_anchor) - - # Shrink current axis's height by 10% on the bottom - box = major_ax.get_position() - major_ax.set_position([box.x0, box.y0 + box.height * 0.1, box.width, box.height * 0.9]) - - # Put a legend below current axis - major_ax.legend( - handles, - labels, - loc="lower center", - bbox_to_anchor=bbox_to_anchor, - fancybox=True, - shadow=True, - ncol=ncol, - fontsize=fontsize, - ) + # Render layout first so we can query actual subplot positions + fig.canvas.draw() + # Find the lowest subplot bottom edge + bottoms = [ax_i.get_position().y0 for ax_i in fig.get_axes() + if ax_i.get_visible() and ax_i.get_label() != ''] + if bottoms: + lowest = min(bottoms) + else: + lowest = 0.1 + # Place legend centered just below the lowest subplot + legend_y = max(0.0, lowest - 0.08) + fig.legend( + handles, labels, + loc="upper center", + bbox_to_anchor=(0.5, legend_y), + fancybox=True, shadow=True, + ncol=ncol, fontsize=fontsize, + ) def set_minor_ticks(self, ax): """ diff --git a/skexplain/plot/plot_2D.py b/skexplain/plot/plot_2D.py index ebbc110..b514c91 100644 --- a/skexplain/plot/plot_2D.py +++ b/skexplain/plot/plot_2D.py @@ -169,8 +169,11 @@ def plot_contours( n_columns = len(estimator_names) if n_panels == 1: - figsize = (6, 3) + figsize = (6, 5) fontsize = 8 + elif n_panels == 2: + figsize = (10, 5) + fontsize = 9 else: figsize = (10, 8) fontsize = 10 @@ -353,6 +356,7 @@ def plot_contours( xaxis_label=feature_set[0], yaxis_label=feature_set[1], fontsize=fontsize, + fontweight="bold", ) # Add a colorbar if ( diff --git a/skexplain/plot/plot_interpret_curves.py b/skexplain/plot/plot_interpret_curves.py index 8c1cbe6..77a0a71 100644 --- a/skexplain/plot/plot_interpret_curves.py +++ b/skexplain/plot/plot_interpret_curves.py @@ -174,7 +174,7 @@ def plot_1d_curve( if n_panels < 10: self.set_minor_ticks(lineplt_ax) - self.set_axis_label(lineplt_ax, xaxis_label="".join(feature)) + self.set_axis_label(lineplt_ax, xaxis_label="".join(feature), fontweight="bold") lineplt_ax.axhline(y=0.0, color="k", alpha=0.8, linewidth=0.8, linestyle="dashed") # nticks = 5 if n_panels < 10 else 3 diff --git a/skexplain/plot/plot_permutation_importance.py b/skexplain/plot/plot_permutation_importance.py index f997773..e391b15 100644 --- a/skexplain/plot/plot_permutation_importance.py +++ b/skexplain/plot/plot_permutation_importance.py @@ -144,6 +144,7 @@ def plot_variable_importance( rho_threshold = kwargs.get("rho_threshold", 0.8) plot_reference_score = kwargs.get("plot_reference_score", True) plot_error = kwargs.get("plot_error", True) + show_method_subtitle = kwargs.get("show_method_subtitle", True) only_one_method = all([m[0] == panels[0][0] for m in panels]) only_one_estimator = all([m[1] == panels[0][1] for m in panels]) @@ -234,17 +235,33 @@ def plot_variable_importance( ) if plot_reference_score: + ref_label_fontsize = self.FONT_SIZES["teensie"] + # Small nudge to the right so text doesn't sit on the dashed line + x_range = np.ptp(scores_to_plot) if len(scores_to_plot) > 0 else 1.0 + nudge = max(0.01, x_range * 0.02) if "forward" in method: - ax.axvline( - results[f"all_permuted_score__{estimator_name}"].mean(), color="k", ls=":" + ref_score = results[f"all_permuted_score__{estimator_name}"].mean() + ax.axvline(ref_score, color="k", ls=":", alpha=0.7, zorder=1) + ax.text( + ref_score + nudge, len(scores_to_plot) / 2, "Original Score", + fontsize=ref_label_fontsize, + va="center", ha="left", color="0.3", + rotation=90, ) elif "backward" in method: - ax.axvline( - results[f"original_score__{estimator_name}"].mean(), color="k", ls="--" + ref_score = results[f"original_score__{estimator_name}"].mean() + ax.axvline(ref_score, color="k", ls="--", alpha=0.7, zorder=1) + ax.text( + ref_score + nudge, len(scores_to_plot) / 2, "Original Score", + fontsize=ref_label_fontsize, + va="center", ha="left", color="0.3", + rotation=90, ) - # Despine + # Despine — keep bottom spine and ticks for x-axis clarity self.despine_plt(ax) + ax.spines["bottom"].set_visible(True) + ax.tick_params(axis="x", which="both", length=3) elinewidth = 0.9 if n_panels <= 3 else 0.5 @@ -333,11 +350,12 @@ def plot_variable_importance( else: self.set_n_ticks(ax, option="x") - xlabel = ( - self.DISPLAY_NAMES_DICT.get(method, method) - if (only_one_method and xlabels is None) - else "" - ) + if show_method_subtitle and only_one_method and xlabels is None: + xlabel = self.DISPLAY_NAMES_DICT.get(method, method) + elif xlabels is not None: + xlabel = "" + else: + xlabel = "" major_ax = self.set_major_axis_labels( fig, diff --git a/skexplain/plot/plot_sobol.py b/skexplain/plot/plot_sobol.py index 9c934d4..44f66c7 100644 --- a/skexplain/plot/plot_sobol.py +++ b/skexplain/plot/plot_sobol.py @@ -3,7 +3,8 @@ def sobol_plot( - results, est_name=None, ax=None, display_feature_names=None, n_features=None, kind="bar" + results, est_name=None, ax=None, display_feature_names=None, n_features=None, + kind="bar", fontsize=None, ): """Plot Sobol sensitivity indices (1st order and higher order) as a stacked bar or barh chart. @@ -59,12 +60,19 @@ def sobol_plot( ax = df_result.plot(ax=ax, x="variable", kind=kind, stacked=True, rot=rot) + # Scale tick and label font sizes + if fontsize is None: + fontsize = 11 + tick_fontsize = max(8, fontsize - 2) + if kind == "bar": ax.set_xlabel("") - ax.set_ylabel("Total Sobol Index\n(1st order + higher order)") + ax.set_ylabel("Total Sobol Index\n(1st order + higher order)", fontsize=fontsize) else: ax.set_ylabel("") - ax.set_xlabel("Total Sobol Index\n(1st order + higher order)") + ax.set_xlabel("Total Sobol Index\n(1st order + higher order)", fontsize=fontsize) ax.invert_yaxis() + ax.tick_params(axis="both", labelsize=tick_fontsize) + return ax