Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ def run_score_map_pipeline(
"fig_size": cfg.get("fig_size", None),
"animation_format": cfg.get("animation_format", "gif"),
"fps": cfg.get("fps", 2),
"log_colorbar": cfg.get("log_colorbar", False),
}
output_basedir = str(reader.runplot_dir)
run_id = reader.run_id
Expand Down
142 changes: 123 additions & 19 deletions packages/evaluate/src/weathergen/evaluate/plotting/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# nor does it submit to any jurisdiction.

import datetime
import fnmatch
import logging
import os
import warnings
Expand Down Expand Up @@ -137,7 +138,6 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str |
self.dpi_val = plotter_cfg.get("dpi_val")
self.fig_size = plotter_cfg.get("fig_size")
self.fps = plotter_cfg.get("fps")
self.log_colorbar = plotter_cfg.get("log_colorbar", False)
self.regions = plotter_cfg.get("regions")
self.log_x = plotter_cfg.get("log_x", False)
self.log_y = plotter_cfg.get("log_y", False)
Expand Down Expand Up @@ -517,8 +517,8 @@ def create_maps_per_sample(
"""
self.update_data_selection(select)

# copy global plotting options, not specific to any variable
map_kwargs_global = {
# copy stream plotting options, not specific to any variable
map_kwargs_stream = {
key: value
for key, value in (map_kwargs or {}).items()
if not isinstance(value, oc.DictConfig)
Expand Down Expand Up @@ -574,7 +574,7 @@ def create_maps_per_sample(
var,
region,
tag=tag,
map_kwargs=dict(map_kwargs.get(var, {})) | map_kwargs_global,
map_kwargs=self._match_glob_kwargs(map_kwargs, var) | map_kwargs_stream,
title=self.get_map_title(var, valid_time, da_t),
)
plot_names.append(name)
Expand All @@ -583,6 +583,39 @@ def create_maps_per_sample(

return plot_names

# glob pattern resolution
@staticmethod
def _match_glob_kwargs(map_kwargs: dict | None, var: str) -> dict:
"""Resolve variable-specific plotting options, supporting exact and glob keys.

Keys under a stream section may be either:
- an exact channel name, e.g. "tp_imerg_0"
- a glob pattern containing one of ``* ? [ ]``, e.g. "tp_*"

All matching glob patterns are merged first (in config order), then the
exact channel key is merged on top, so an exact key overrides glob
values on key conflicts while non-conflicting keys from both are kept.
"""
if map_kwargs is None:
return {}

resolved: dict = {}

# merge glob/pattern keys
for key, value in map_kwargs.items():
if not isinstance(value, oc.DictConfig | dict):
continue
k = str(key)
if any(ch in k for ch in "*?[]") and fnmatch.fnmatch(var, k):
resolved |= dict(value)

# override with exact key if present
exact = map_kwargs.get(var)
if isinstance(exact, oc.DictConfig | dict):
resolved |= dict(exact)

return resolved

# map_kwargs parsing
@staticmethod
def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict:
Expand All @@ -594,9 +627,9 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict:
----------
map_kwargs : dict or None
Raw keyword arguments from the caller. Known keys (``marker_size``,
``scale_marker_size``, ``marker``, ``vmin``, ``vmax``, ``colormap``,
``use_datashader``, ``levels``, and HEALPix-related keys) are extracted;
remaining keys are collected under ``"extra"``.
``scale_marker_size``, ``marker``, ``vmin``, ``vmax``, ``colormap``,``colors``,
``use_datashader``, ``levels``, ``colorbar_scale`` and HEALPix-related keys) are
extracted; remaining keys are collected under ``"extra"``.
stream : str or None
Stream name used to look up the default marker size when
``marker_size`` is not provided in *map_kwargs*.
Expand All @@ -610,6 +643,7 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict:
- marker (str)
- vmin, vmax (float or None)
- cmap (matplotlib.colors.Colormap)
- colors (list or None)
- use_datashader (bool)
- norm (matplotlib.colors.Normalize or BoundaryNorm)
- add_healpix_grid (bool) and related healpix_* keys
Expand All @@ -625,9 +659,11 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict:
"marker": kw.pop("marker", "o"),
"vmin": kw.pop("vmin", None),
"vmax": kw.pop("vmax", None),
"cmap": plt.get_cmap(kw.pop("colormap", "coolwarm")),
"cmap": kw.pop("colormap", "coolwarm"),
"colors": kw.pop("colors", None),
"use_datashader": kw.pop("use_datashader", False),
"levels": kw.pop("levels", None),
"colorbar_scale": kw.pop("colorbar_scale", "linear"),
# HEALPix grid
"add_healpix_grid": kw.pop("add_healpix_grid", False),
"healpix_nside": kw.pop("healpix_nside", 4),
Expand All @@ -640,6 +676,82 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict:
parsed["extra"] = kw # remaining kwargs forwarded to scatter
return parsed

@staticmethod
def _resolve_cmap(opts: dict, tag: str) -> mpl.colors.Colormap:
"""Resolve the colormap for the plot.

Parameters
----------
opts : dict
Parsed map kwargs from ``_parse_map_kwargs``.
tag : str
Plot tag (e.g. ``'targets'``, ``'preds'``, ``'bias'``).

Returns
-------
matplotlib.colors.Colormap
The resolved colormap.
"""

# Bias maps always use coolwarm for visual consistency (overrides config)
if str(tag).startswith("bias"):
return plt.get_cmap("coolwarm")
# Explicit colors take precedence over colormap
elif isinstance(opts["colors"], oc.listconfig.ListConfig):
return mpl.colors.ListedColormap(list(opts["colors"]))
# Otherwise use the specified colormap (default "coolwarm")
else:
return plt.get_cmap(opts["cmap"])

@staticmethod
def _resolve_norm(opts: dict, tag: str) -> mpl.colors.Normalize:
"""Resolve the colorbar scale and build color normalisation.

Bias maps (``tag`` starting with ``"bias"``) are signed, so a non-linear
scale is coerced to ``"symlog"``. Explicit ``levels`` always take
precedence and yield a BoundaryNorm.

Parameters
----------
opts : dict
Parsed map kwargs from ``_parse_map_kwargs``.
tag : str
Plot tag (e.g. ``'targets'``, ``'preds'``, ``'bias'``).

Returns
-------
matplotlib.colors.Normalize
LogNorm, SymLogNorm, BoundaryNorm or linear Normalize.
"""

scale = opts["colorbar_scale"]
if scale not in {"linear", "log", "symlog"}:
_logger.warning("Unknown colorbar_scale=%r. Falling back to linear.", scale)
scale = "linear"

# Bias maps are always signed, use symlog instead of plain log
if str(tag).startswith("bias") and scale != "linear":
scale = "symlog"

vmin, vmax = opts["vmin"], opts["vmax"]

# Explicit levels override continuous norm for preds and targets
if isinstance(opts["levels"], oc.listconfig.ListConfig) and not str(tag).startswith("bias"):
return mpl.colors.BoundaryNorm(opts["levels"], opts["cmap"].N, extend="both")
# Log only if vmin > 0; otherwise fall back to linear with warning
if scale == "log" and vmin is not None and vmin > 0:
return mpl.colors.LogNorm(vmin=vmin, vmax=vmax)
elif scale == "log" and vmin is not None and vmin <= 0:
_logger.warning(
"colorbar_scale='log' but vmin=%.3g <= 0; falling back to linear norm.",
vmin,
)
# Symlog (default for bias maps with non-linear scale)
if scale == "symlog" and vmin is not None and vmax is not None:
vmax_abs = max(abs(float(vmin)), abs(float(vmax)))
return mpl.colors.SymLogNorm(linthresh=max(vmax_abs * 1e-3, 1e-8), vmin=vmin, vmax=vmax)
return mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

# rendering backends
@staticmethod
def _render_datashader(ax, proj, data, norm, cmap, marker_size_base):
Expand Down Expand Up @@ -909,17 +1021,9 @@ def scatter_plot(
if opts["vmax"] is None:
opts["vmax"] = float(p_hi)

if isinstance(opts["levels"], oc.listconfig.ListConfig):
opts["norm"] = mpl.colors.BoundaryNorm(opts["levels"], opts["cmap"].N, extend="both")
elif self.log_colorbar and opts["vmin"] is not None and opts["vmin"] > 0:
opts["norm"] = mpl.colors.LogNorm(vmin=opts["vmin"], vmax=opts["vmax"])
else:
if self.log_colorbar:
_logger.warning(
"log_colorbar=True but vmin=%.3g <= 0; falling back to linear norm.",
opts["vmin"],
)
opts["norm"] = mpl.colors.Normalize(vmin=opts["vmin"], vmax=opts["vmax"], clip=False)
# resolve cmap and norm based on options and tag
opts["cmap"] = self._resolve_cmap(opts, tag)
opts["norm"] = self._resolve_norm(opts, tag)

if regionname == "global":
ax.set_global()
Expand Down
30 changes: 19 additions & 11 deletions packages/evaluate/src/weathergen/evaluate/utils/array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

"""Array / DataArray utility functions: range computation, coordinate helpers."""

import fnmatch

import numpy as np
import omegaconf as oc
import xarray as xr
Expand Down Expand Up @@ -88,21 +90,27 @@ def common_ranges(
"""
maps_config = global_plotting_opts_stream.copy()
for var in plot_chs:
if var in maps_config:
if not isinstance(maps_config[var].get("vmax"), (int | float)):
list_max = calc_bounds(data_tars, data_preds, var, "max")
list_max = np.concatenate([arr.flatten() for arr in list_max]).tolist()
maps_config[var].update({"vmax": float(max(list_max))})
if not isinstance(maps_config[var].get("vmin"), (int | float)):
list_min = calc_bounds(data_tars, data_preds, var, "min")
list_min = np.concatenate([arr.flatten() for arr in list_min]).tolist()
maps_config[var].update({"vmin": float(min(list_min))})
else:
if var not in maps_config:
maps_config[var] = {}
# override empty bounds with matching glob bounds from config
for key, value in maps_config.items():
if not isinstance(value, oc.DictConfig | dict):
continue
k = str(key)
if any(c in k for c in "*?[]") and fnmatch.fnmatch(var, k):
for bound in ("vmin", "vmax"):
if isinstance(value.get(bound), int | float):
maps_config[var].setdefault(bound, value[bound])
# if vmax still missing, compute bound from data
if not isinstance(maps_config[var].get("vmax"), (int | float)):
list_max = calc_bounds(data_tars, data_preds, var, "max")
list_max = np.concatenate([arr.flatten() for arr in list_max]).tolist()
maps_config[var].update({"vmax": float(max(list_max))})
# if vmin still missing, compute bound from data
if not isinstance(maps_config[var].get("vmin"), (int | float)):
list_min = calc_bounds(data_tars, data_preds, var, "min")
list_min = np.concatenate([arr.flatten() for arr in list_min]).tolist()
maps_config.update({var: {"vmax": float(max(list_max)), "vmin": float(min(list_min))}})
maps_config[var].update({"vmin": float(min(list_min))})
return maps_config


Expand Down
Loading