From 4582360c6a02d7577eedeaad33643bdca13b6513 Mon Sep 17 00:00:00 2001 From: vhertel Date: Wed, 10 Jun 2026 18:45:59 +0200 Subject: [PATCH 1/4] Add custom cmap support --- .../evaluate/src/weathergen/evaluate/plotting/plotter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 34c7a8cd3..a52e5c1ac 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -594,7 +594,7 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict: ---------- map_kwargs : dict or None Raw keyword arguments from the caller. Known keys (``marker_size``, - ``scale_marker_size``, ``marker``, ``vmin``, ``vmax``, ``colormap``, + ``scale_marker_size``, ``marker``, ``vmin``, ``vmax``, ``colormap``,``colors``, ``use_datashader``, ``levels``, and HEALPix-related keys) are extracted; remaining keys are collected under ``"extra"``. stream : str or None @@ -610,6 +610,7 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict: - marker (str) - vmin, vmax (float or None) - cmap (matplotlib.colors.Colormap) + - colors (list or None) - use_datashader (bool) - norm (matplotlib.colors.Normalize or BoundaryNorm) - add_healpix_grid (bool) and related healpix_* keys @@ -626,6 +627,7 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict: "vmin": kw.pop("vmin", None), "vmax": kw.pop("vmax", None), "cmap": plt.get_cmap(kw.pop("colormap", "coolwarm")), + "colors": kw.pop("colors", None), "use_datashader": kw.pop("use_datashader", False), "levels": kw.pop("levels", None), # HEALPix grid @@ -909,6 +911,9 @@ def scatter_plot( if opts["vmax"] is None: opts["vmax"] = float(p_hi) + if isinstance(opts["colors"], oc.listconfig.ListConfig): + opts["cmap"] = mpl.colors.ListedColormap(list(opts["colors"])) + if isinstance(opts["levels"], oc.listconfig.ListConfig): opts["norm"] = mpl.colors.BoundaryNorm(opts["levels"], opts["cmap"].N, extend="both") elif self.log_colorbar and opts["vmin"] is not None and opts["vmin"] > 0: From 0c0c7d454153600fb18a799c1407a61dc76c1548 Mon Sep 17 00:00:00 2001 From: vhertel Date: Wed, 10 Jun 2026 20:55:16 +0200 Subject: [PATCH 2/4] Add glob support --- .../weathergen/evaluate/plotting/plotter.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index a52e5c1ac..4e2de577e 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. import datetime +import fnmatch import logging import os import warnings @@ -574,7 +575,7 @@ def create_maps_per_sample( var, region, tag=tag, - map_kwargs=dict(map_kwargs.get(var, {})) | map_kwargs_global, + map_kwargs=self._match_glob_kwargs(map_kwargs, var) | map_kwargs_global, title=self.get_map_title(var, valid_time, da_t), ) plot_names.append(name) @@ -583,6 +584,39 @@ def create_maps_per_sample( return plot_names + # glob pattern resolution + @staticmethod + def _match_glob_kwargs(map_kwargs: dict | None, var: str) -> dict: + """Resolve variable-specific plotting options, supporting exact and glob keys. + + Keys under a stream section may be either: + - an exact channel name, e.g. "tp_imerg_0" + - a glob pattern containing one of ``* ? [ ]``, e.g. "tp_*" + + All matching glob patterns are merged first (in config order), then the + exact channel key is merged on top, so an exact key overrides glob + values on key conflicts while non-conflicting keys from both are kept. + """ + if map_kwargs is None: + return {} + + resolved: dict = {} + + # merge glob/pattern keys + for key, value in map_kwargs.items(): + if not isinstance(value, oc.DictConfig | dict): + continue + k = str(key) + if any(ch in k for ch in "*?[]") and fnmatch.fnmatch(var, k): + resolved |= dict(value) + + # override with exact key if present + exact = map_kwargs.get(var) + if isinstance(exact, oc.DictConfig | dict): + resolved |= dict(exact) + + return resolved + # map_kwargs parsing @staticmethod def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict: From 48f040cc4bfddad495b9f5d5594c6be27acd26eb Mon Sep 17 00:00:00 2001 From: vhertel Date: Wed, 17 Jun 2026 17:11:17 +0200 Subject: [PATCH 3/4] Update range calculation with glob --- .../weathergen/evaluate/utils/array_utils.py | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/utils/array_utils.py b/packages/evaluate/src/weathergen/evaluate/utils/array_utils.py index 34bd3965d..785070125 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/array_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/array_utils.py @@ -9,6 +9,8 @@ """Array / DataArray utility functions: range computation, coordinate helpers.""" +import fnmatch + import numpy as np import omegaconf as oc import xarray as xr @@ -88,21 +90,27 @@ def common_ranges( """ maps_config = global_plotting_opts_stream.copy() for var in plot_chs: - if var in maps_config: - if not isinstance(maps_config[var].get("vmax"), (int | float)): - list_max = calc_bounds(data_tars, data_preds, var, "max") - list_max = np.concatenate([arr.flatten() for arr in list_max]).tolist() - maps_config[var].update({"vmax": float(max(list_max))}) - if not isinstance(maps_config[var].get("vmin"), (int | float)): - list_min = calc_bounds(data_tars, data_preds, var, "min") - list_min = np.concatenate([arr.flatten() for arr in list_min]).tolist() - maps_config[var].update({"vmin": float(min(list_min))}) - else: + if var not in maps_config: + maps_config[var] = {} + # override empty bounds with matching glob bounds from config + for key, value in maps_config.items(): + if not isinstance(value, oc.DictConfig | dict): + continue + k = str(key) + if any(c in k for c in "*?[]") and fnmatch.fnmatch(var, k): + for bound in ("vmin", "vmax"): + if isinstance(value.get(bound), int | float): + maps_config[var].setdefault(bound, value[bound]) + # if vmax still missing, compute bound from data + if not isinstance(maps_config[var].get("vmax"), (int | float)): list_max = calc_bounds(data_tars, data_preds, var, "max") list_max = np.concatenate([arr.flatten() for arr in list_max]).tolist() + maps_config[var].update({"vmax": float(max(list_max))}) + # if vmin still missing, compute bound from data + if not isinstance(maps_config[var].get("vmin"), (int | float)): list_min = calc_bounds(data_tars, data_preds, var, "min") list_min = np.concatenate([arr.flatten() for arr in list_min]).tolist() - maps_config.update({var: {"vmax": float(max(list_max)), "vmin": float(min(list_min))}}) + maps_config[var].update({"vmin": float(min(list_min))}) return maps_config From 169aad829cc4bf4ce51b16a1022ccf124afcd2aa Mon Sep 17 00:00:00 2001 From: vhertel Date: Wed, 17 Jun 2026 17:49:14 +0200 Subject: [PATCH 4/4] Add colorbar scale options --- .../evaluate/plotting/plot_orchestration.py | 1 - .../weathergen/evaluate/plotting/plotter.py | 107 ++++++++++++++---- 2 files changed, 86 insertions(+), 22 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 2a44c1993..f59a709ab 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -402,7 +402,6 @@ def run_score_map_pipeline( "fig_size": cfg.get("fig_size", None), "animation_format": cfg.get("animation_format", "gif"), "fps": cfg.get("fps", 2), - "log_colorbar": cfg.get("log_colorbar", False), } output_basedir = str(reader.runplot_dir) run_id = reader.run_id diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 4e2de577e..192bf1327 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -138,7 +138,6 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | self.dpi_val = plotter_cfg.get("dpi_val") self.fig_size = plotter_cfg.get("fig_size") self.fps = plotter_cfg.get("fps") - self.log_colorbar = plotter_cfg.get("log_colorbar", False) self.regions = plotter_cfg.get("regions") self.log_x = plotter_cfg.get("log_x", False) self.log_y = plotter_cfg.get("log_y", False) @@ -518,8 +517,8 @@ def create_maps_per_sample( """ self.update_data_selection(select) - # copy global plotting options, not specific to any variable - map_kwargs_global = { + # copy stream plotting options, not specific to any variable + map_kwargs_stream = { key: value for key, value in (map_kwargs or {}).items() if not isinstance(value, oc.DictConfig) @@ -575,7 +574,7 @@ def create_maps_per_sample( var, region, tag=tag, - map_kwargs=self._match_glob_kwargs(map_kwargs, var) | map_kwargs_global, + map_kwargs=self._match_glob_kwargs(map_kwargs, var) | map_kwargs_stream, title=self.get_map_title(var, valid_time, da_t), ) plot_names.append(name) @@ -629,8 +628,8 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict: map_kwargs : dict or None Raw keyword arguments from the caller. Known keys (``marker_size``, ``scale_marker_size``, ``marker``, ``vmin``, ``vmax``, ``colormap``,``colors``, - ``use_datashader``, ``levels``, and HEALPix-related keys) are extracted; - remaining keys are collected under ``"extra"``. + ``use_datashader``, ``levels``, ``colorbar_scale`` and HEALPix-related keys) are + extracted; remaining keys are collected under ``"extra"``. stream : str or None Stream name used to look up the default marker size when ``marker_size`` is not provided in *map_kwargs*. @@ -660,10 +659,11 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict: "marker": kw.pop("marker", "o"), "vmin": kw.pop("vmin", None), "vmax": kw.pop("vmax", None), - "cmap": plt.get_cmap(kw.pop("colormap", "coolwarm")), + "cmap": kw.pop("colormap", "coolwarm"), "colors": kw.pop("colors", None), "use_datashader": kw.pop("use_datashader", False), "levels": kw.pop("levels", None), + "colorbar_scale": kw.pop("colorbar_scale", "linear"), # HEALPix grid "add_healpix_grid": kw.pop("add_healpix_grid", False), "healpix_nside": kw.pop("healpix_nside", 4), @@ -676,6 +676,82 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict: parsed["extra"] = kw # remaining kwargs forwarded to scatter return parsed + @staticmethod + def _resolve_cmap(opts: dict, tag: str) -> mpl.colors.Colormap: + """Resolve the colormap for the plot. + + Parameters + ---------- + opts : dict + Parsed map kwargs from ``_parse_map_kwargs``. + tag : str + Plot tag (e.g. ``'targets'``, ``'preds'``, ``'bias'``). + + Returns + ------- + matplotlib.colors.Colormap + The resolved colormap. + """ + + # Bias maps always use coolwarm for visual consistency (overrides config) + if str(tag).startswith("bias"): + return plt.get_cmap("coolwarm") + # Explicit colors take precedence over colormap + elif isinstance(opts["colors"], oc.listconfig.ListConfig): + return mpl.colors.ListedColormap(list(opts["colors"])) + # Otherwise use the specified colormap (default "coolwarm") + else: + return plt.get_cmap(opts["cmap"]) + + @staticmethod + def _resolve_norm(opts: dict, tag: str) -> mpl.colors.Normalize: + """Resolve the colorbar scale and build color normalisation. + + Bias maps (``tag`` starting with ``"bias"``) are signed, so a non-linear + scale is coerced to ``"symlog"``. Explicit ``levels`` always take + precedence and yield a BoundaryNorm. + + Parameters + ---------- + opts : dict + Parsed map kwargs from ``_parse_map_kwargs``. + tag : str + Plot tag (e.g. ``'targets'``, ``'preds'``, ``'bias'``). + + Returns + ------- + matplotlib.colors.Normalize + LogNorm, SymLogNorm, BoundaryNorm or linear Normalize. + """ + + scale = opts["colorbar_scale"] + if scale not in {"linear", "log", "symlog"}: + _logger.warning("Unknown colorbar_scale=%r. Falling back to linear.", scale) + scale = "linear" + + # Bias maps are always signed, use symlog instead of plain log + if str(tag).startswith("bias") and scale != "linear": + scale = "symlog" + + vmin, vmax = opts["vmin"], opts["vmax"] + + # Explicit levels override continuous norm for preds and targets + if isinstance(opts["levels"], oc.listconfig.ListConfig) and not str(tag).startswith("bias"): + return mpl.colors.BoundaryNorm(opts["levels"], opts["cmap"].N, extend="both") + # Log only if vmin > 0; otherwise fall back to linear with warning + if scale == "log" and vmin is not None and vmin > 0: + return mpl.colors.LogNorm(vmin=vmin, vmax=vmax) + elif scale == "log" and vmin is not None and vmin <= 0: + _logger.warning( + "colorbar_scale='log' but vmin=%.3g <= 0; falling back to linear norm.", + vmin, + ) + # Symlog (default for bias maps with non-linear scale) + if scale == "symlog" and vmin is not None and vmax is not None: + vmax_abs = max(abs(float(vmin)), abs(float(vmax))) + return mpl.colors.SymLogNorm(linthresh=max(vmax_abs * 1e-3, 1e-8), vmin=vmin, vmax=vmax) + return mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=False) + # rendering backends @staticmethod def _render_datashader(ax, proj, data, norm, cmap, marker_size_base): @@ -945,20 +1021,9 @@ def scatter_plot( if opts["vmax"] is None: opts["vmax"] = float(p_hi) - if isinstance(opts["colors"], oc.listconfig.ListConfig): - opts["cmap"] = mpl.colors.ListedColormap(list(opts["colors"])) - - if isinstance(opts["levels"], oc.listconfig.ListConfig): - opts["norm"] = mpl.colors.BoundaryNorm(opts["levels"], opts["cmap"].N, extend="both") - elif self.log_colorbar and opts["vmin"] is not None and opts["vmin"] > 0: - opts["norm"] = mpl.colors.LogNorm(vmin=opts["vmin"], vmax=opts["vmax"]) - else: - if self.log_colorbar: - _logger.warning( - "log_colorbar=True but vmin=%.3g <= 0; falling back to linear norm.", - opts["vmin"], - ) - opts["norm"] = mpl.colors.Normalize(vmin=opts["vmin"], vmax=opts["vmax"], clip=False) + # resolve cmap and norm based on options and tag + opts["cmap"] = self._resolve_cmap(opts, tag) + opts["norm"] = self._resolve_norm(opts, tag) if regionname == "global": ax.set_global()