From f99016fbac751029b089731c02995e2066924805 Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 13:57:07 +0200 Subject: [PATCH 01/12] feat: Add global configuration system (PlotDefaults, set_defaults, plot_context) Centralized plotting defaults so researchers can set preferences once instead of passing cmap/figsize/dpi to every function call. Supports context manager for temporary overrides and optional TOML config files. --- cfd_viz/__init__.py | 17 +++ cfd_viz/defaults.py | 205 ++++++++++++++++++++++++++++++++++++ tests/test_defaults.py | 228 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 450 insertions(+) create mode 100644 cfd_viz/defaults.py create mode 100644 tests/test_defaults.py diff --git a/cfd_viz/__init__.py b/cfd_viz/__init__.py index 53a9ebd..567e779 100644 --- a/cfd_viz/__init__.py +++ b/cfd_viz/__init__.py @@ -57,6 +57,15 @@ ) from .common import VTKData, ensure_dirs, find_vtk_files, read_vtk_file from .convert import from_cfd_python, from_simulation_result, to_cfd_python +from .defaults import ( + UNSET, + PlotDefaults, + get_defaults, + load_config_file, + plot_context, + reset_defaults, + set_defaults, +) from .info import get_recommended_settings, get_system_info, print_system_info from .quick import quick_plot, quick_plot_data, quick_plot_result from .stats import ( @@ -67,6 +76,14 @@ __all__ = [ "__version__", + # Configuration + "UNSET", + "PlotDefaults", + "get_defaults", + "set_defaults", + "reset_defaults", + "plot_context", + "load_config_file", # I/O "VTKData", "read_vtk_file", diff --git a/cfd_viz/defaults.py b/cfd_viz/defaults.py new file mode 100644 index 0000000..54799ee --- /dev/null +++ b/cfd_viz/defaults.py @@ -0,0 +1,205 @@ +"""Global plotting defaults for cfd-visualization. + +Provides a centralized PlotDefaults dataclass so researchers can +set preferences once instead of passing them to every function. + +Usage: + >>> import cfd_viz + >>> cfd_viz.set_defaults(cmap="coolwarm", dpi=200) + >>> cfd_viz.get_defaults().cmap + 'coolwarm' + + >>> with cfd_viz.plot_context(cmap="hot", dpi=300): + ... # all plots inside use "hot" colormap at 300 dpi + ... pass + >>> # defaults restored automatically +""" + +from __future__ import annotations + +import copy +import threading +from contextlib import contextmanager +from dataclasses import dataclass, fields as dc_fields +from pathlib import Path +from typing import Any, Iterator, Tuple + + +class _UnsetType: + """Sentinel for 'no value provided' (distinct from None).""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __repr__(self): + return "UNSET" + + def __bool__(self): + return False + + +UNSET = _UnsetType() + + +@dataclass +class PlotDefaults: + """Global plotting default values. + + Attributes: + cmap: Default matplotlib colormap for sequential data. + diverging_cmap: Default matplotlib colormap for diverging data + (e.g., vorticity, field differences). + sequential_cmap: Secondary sequential colormap (e.g., pressure). + figsize: Default figure size (width, height) in inches. + dpi: Default dots per inch for saved figures. + levels: Default number of contour levels. + font_size: Default font size for titles and labels. + colorscale: Default Plotly colorscale for sequential data. + diverging_colorscale: Default Plotly colorscale for diverging data. + """ + + cmap: str = "viridis" + diverging_cmap: str = "RdBu_r" + sequential_cmap: str = "plasma" + figsize: Tuple[float, float] = (10, 8) + dpi: int = 150 + levels: int = 20 + font_size: int = 12 + colorscale: str = "Viridis" + diverging_colorscale: str = "RdBu" + + +_lock = threading.Lock() +_defaults = PlotDefaults() + + +def get_defaults() -> PlotDefaults: + """Return a copy of the current global defaults.""" + with _lock: + return copy.copy(_defaults) + + +def set_defaults(**kwargs: Any) -> None: + """Update global defaults. + + Args: + **kwargs: Field names from PlotDefaults and their new values. + + Raises: + TypeError: If an unknown field name is passed. + + Example: + >>> set_defaults(cmap="coolwarm", dpi=200) + """ + valid_fields = {f.name for f in dc_fields(PlotDefaults)} + unknown = set(kwargs) - valid_fields + if unknown: + raise TypeError( + f"Unknown defaults: {', '.join(sorted(unknown))}. " + f"Valid fields: {', '.join(sorted(valid_fields))}" + ) + with _lock: + for key, value in kwargs.items(): + setattr(_defaults, key, value) + + +def reset_defaults() -> None: + """Reset all defaults to their original values.""" + global _defaults # noqa: PLW0603 + with _lock: + _defaults = PlotDefaults() + + +def resolve(value: Any, field_name: str) -> Any: + """Resolve a parameter: return value if set, else the global default. + + Used inside plotting functions to resolve UNSET sentinels:: + + actual_cmap = resolve(cmap, "cmap") + + Args: + value: The caller-provided value (or UNSET). + field_name: The PlotDefaults field to fall back to. + + Returns: + The resolved value. + """ + if isinstance(value, _UnsetType): + return getattr(get_defaults(), field_name) + return value + + +@contextmanager +def plot_context(**kwargs: Any) -> Iterator[PlotDefaults]: + """Temporarily override defaults within a with-block. + + Args: + **kwargs: Fields to override temporarily. + + Yields: + The temporary PlotDefaults object. + + Example: + >>> with plot_context(cmap="coolwarm", dpi=300): + ... quick_plot(u, v, nx, ny) # uses coolwarm + >>> # defaults restored automatically + """ + global _defaults # noqa: PLW0603 + with _lock: + saved = copy.copy(_defaults) + try: + set_defaults(**kwargs) + yield get_defaults() + finally: + with _lock: + _defaults = saved + + +def load_config_file(path: str | None = None) -> bool: + """Load defaults from a TOML config file. + + Search order (when *path* is None): + 1. ``cfd_viz.toml`` in the current directory + 2. ``pyproject.toml`` ``[tool.cfd_viz.defaults]`` in the current directory + + Args: + path: Explicit path to a TOML file. + + Returns: + True if config was loaded, False otherwise. + """ + try: + import tomllib # type: ignore[import-not-found] + except ModuleNotFoundError: + try: + import tomli as tomllib # type: ignore[import-not-found,no-redef] + except ModuleNotFoundError: + return False + + if path is not None: + candidates = [Path(path)] + else: + candidates = [Path("cfd_viz.toml"), Path("pyproject.toml")] + + for filepath in candidates: + if not filepath.exists(): + continue + with open(filepath, "rb") as f: + data = tomllib.load(f) + + if filepath.name == "pyproject.toml": + defaults_data = data.get("tool", {}).get("cfd_viz", {}).get("defaults", {}) + else: + defaults_data = data.get("defaults", data) + + if defaults_data: + if "figsize" in defaults_data: + defaults_data["figsize"] = tuple(defaults_data["figsize"]) + set_defaults(**defaults_data) + return True + + return False diff --git a/tests/test_defaults.py b/tests/test_defaults.py new file mode 100644 index 0000000..290cd10 --- /dev/null +++ b/tests/test_defaults.py @@ -0,0 +1,228 @@ +"""Tests for cfd_viz.defaults module.""" + +import pytest + +from cfd_viz.defaults import ( + UNSET, + PlotDefaults, + _UnsetType, + get_defaults, + load_config_file, + plot_context, + reset_defaults, + resolve, + set_defaults, +) + + +@pytest.fixture(autouse=True) +def _clean_defaults(): + """Reset global defaults after every test.""" + yield + reset_defaults() + + +class TestPlotDefaults: + """Tests for the PlotDefaults dataclass.""" + + def test_default_values(self): + d = PlotDefaults() + assert d.cmap == "viridis" + assert d.diverging_cmap == "RdBu_r" + assert d.sequential_cmap == "plasma" + assert d.figsize == (10, 8) + assert d.dpi == 150 + assert d.levels == 20 + assert d.font_size == 12 + assert d.colorscale == "Viridis" + assert d.diverging_colorscale == "RdBu" + + def test_custom_values(self): + d = PlotDefaults(cmap="coolwarm", dpi=300) + assert d.cmap == "coolwarm" + assert d.dpi == 300 + assert d.levels == 20 # unchanged + + +class TestGetDefaults: + """Tests for get_defaults().""" + + def test_returns_plot_defaults_instance(self): + assert isinstance(get_defaults(), PlotDefaults) + + def test_returns_copy(self): + d = get_defaults() + d.cmap = "modified" + assert get_defaults().cmap == "viridis" + + +class TestSetDefaults: + """Tests for set_defaults().""" + + def test_updates_values(self): + set_defaults(cmap="coolwarm") + assert get_defaults().cmap == "coolwarm" + + def test_partial_update(self): + set_defaults(cmap="coolwarm") + assert get_defaults().dpi == 150 # unchanged + assert get_defaults().levels == 20 # unchanged + + def test_multiple_fields(self): + set_defaults(cmap="hot", dpi=300, levels=30) + d = get_defaults() + assert d.cmap == "hot" + assert d.dpi == 300 + assert d.levels == 30 + + def test_rejects_unknown_fields(self): + with pytest.raises(TypeError, match="Unknown defaults.*bogus"): + set_defaults(bogus="x") + + def test_rejects_multiple_unknown_fields(self): + with pytest.raises(TypeError, match="Unknown defaults"): + set_defaults(bogus="x", also_bogus="y") + + +class TestResetDefaults: + """Tests for reset_defaults().""" + + def test_restores_originals(self): + set_defaults(cmap="hot", dpi=300) + reset_defaults() + d = get_defaults() + assert d.cmap == "viridis" + assert d.dpi == 150 + + +class TestResolve: + """Tests for the resolve() helper.""" + + def test_returns_explicit_value(self): + assert resolve("coolwarm", "cmap") == "coolwarm" + + def test_returns_default_for_unset(self): + assert resolve(UNSET, "cmap") == "viridis" + + def test_respects_set_defaults(self): + set_defaults(cmap="hot") + assert resolve(UNSET, "cmap") == "hot" + + def test_explicit_overrides_set_defaults(self): + set_defaults(cmap="hot") + assert resolve("coolwarm", "cmap") == "coolwarm" + + def test_resolves_different_fields(self): + assert resolve(UNSET, "dpi") == 150 + assert resolve(UNSET, "levels") == 20 + assert resolve(UNSET, "figsize") == (10, 8) + + +class TestUnset: + """Tests for the UNSET sentinel.""" + + def test_is_singleton(self): + assert _UnsetType() is UNSET + + def test_is_falsy(self): + assert not UNSET + + def test_repr(self): + assert repr(UNSET) == "UNSET" + + +class TestPlotContext: + """Tests for the plot_context() context manager.""" + + def test_overrides_inside_context(self): + with plot_context(cmap="hot"): + assert get_defaults().cmap == "hot" + + def test_restores_on_exit(self): + set_defaults(cmap="coolwarm") + with plot_context(cmap="hot"): + pass + assert get_defaults().cmap == "coolwarm" + + def test_restores_on_exception(self): + set_defaults(cmap="coolwarm") + with pytest.raises(ValueError), plot_context(cmap="hot"): + raise ValueError("boom") + assert get_defaults().cmap == "coolwarm" + + def test_nested_contexts(self): + assert get_defaults().cmap == "viridis" + with plot_context(cmap="hot"): + assert get_defaults().cmap == "hot" + with plot_context(cmap="coolwarm"): + assert get_defaults().cmap == "coolwarm" + assert get_defaults().cmap == "hot" + assert get_defaults().cmap == "viridis" + + def test_yields_defaults(self): + with plot_context(cmap="hot", dpi=300) as d: + assert d.cmap == "hot" + assert d.dpi == 300 + + def test_rejects_unknown_fields(self): + with ( + pytest.raises(TypeError, match="Unknown defaults"), + plot_context(bogus="x"), + ): + pass + + +class TestLoadConfigFile: + """Tests for load_config_file().""" + + def test_nonexistent_returns_false(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + assert load_config_file() is False + + def test_load_cfd_viz_toml(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + (tmp_path / "cfd_viz.toml").write_text( + '[defaults]\ncmap = "coolwarm"\ndpi = 200\n' + ) + assert load_config_file() is True + d = get_defaults() + assert d.cmap == "coolwarm" + assert d.dpi == 200 + + def test_load_pyproject_toml(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + (tmp_path / "pyproject.toml").write_text( + '[tool.cfd_viz.defaults]\ncmap = "hot"\nlevels = 30\n' + ) + assert load_config_file() is True + d = get_defaults() + assert d.cmap == "hot" + assert d.levels == 30 + + def test_cfd_viz_toml_takes_priority(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + (tmp_path / "cfd_viz.toml").write_text('[defaults]\ncmap = "coolwarm"\n') + (tmp_path / "pyproject.toml").write_text( + '[tool.cfd_viz.defaults]\ncmap = "hot"\n' + ) + load_config_file() + assert get_defaults().cmap == "coolwarm" + + def test_figsize_tuple_conversion(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + (tmp_path / "cfd_viz.toml").write_text("[defaults]\nfigsize = [12, 8]\n") + load_config_file() + assert get_defaults().figsize == (12, 8) + + def test_explicit_path(self, tmp_path): + config = tmp_path / "custom.toml" + config.write_text('cmap = "magma"\ndpi = 72\n') + assert load_config_file(str(config)) is True + d = get_defaults() + assert d.cmap == "magma" + assert d.dpi == 72 + + def test_pyproject_without_section_returns_false(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + (tmp_path / "pyproject.toml").write_text("[tool.other]\nfoo = 1\n") + assert load_config_file() is False From e61d9004e26ad200b312c7510c96ae3893942c69 Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 14:00:14 +0200 Subject: [PATCH 02/12] refactor: Migrate plotting/fields.py to use global defaults Replace hardcoded cmap/levels defaults with UNSET sentinel and resolve() calls so they respect set_defaults() and plot_context(). --- cfd_viz/plotting/fields.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/cfd_viz/plotting/fields.py b/cfd_viz/plotting/fields.py index 4d2164b..1505e6d 100644 --- a/cfd_viz/plotting/fields.py +++ b/cfd_viz/plotting/fields.py @@ -12,14 +12,16 @@ from matplotlib.axes import Axes from numpy.typing import NDArray +from cfd_viz.defaults import UNSET, resolve + def plot_contour_field( X: NDArray, Y: NDArray, field: NDArray, ax: Optional[Axes] = None, - levels: int = 20, - cmap: str = "viridis", + levels: int = UNSET, + cmap: str = UNSET, colorbar: bool = True, colorbar_label: str = "", title: str = "", @@ -48,6 +50,9 @@ def plot_contour_field( Returns: The matplotlib axes object. """ + levels = resolve(levels, "levels") + cmap = resolve(cmap, "cmap") + if ax is None: _, ax = plt.subplots() @@ -71,8 +76,8 @@ def plot_velocity_field( u: NDArray, v: NDArray, ax: Optional[Axes] = None, - levels: int = 20, - cmap: str = "viridis", + levels: int = UNSET, + cmap: str = UNSET, colorbar: bool = True, title: str = "Velocity Magnitude", **kwargs, @@ -94,6 +99,9 @@ def plot_velocity_field( Returns: The matplotlib axes object. """ + levels = resolve(levels, "levels") + cmap = resolve(cmap, "cmap") + velocity_mag = np.sqrt(u**2 + v**2) return plot_contour_field( X, @@ -114,8 +122,8 @@ def plot_pressure_field( Y: NDArray, pressure: NDArray, ax: Optional[Axes] = None, - levels: int = 20, - cmap: str = "plasma", + levels: int = UNSET, + cmap: str = UNSET, colorbar: bool = True, title: str = "Pressure Field", **kwargs, @@ -136,6 +144,9 @@ def plot_pressure_field( Returns: The matplotlib axes object. """ + levels = resolve(levels, "levels") + cmap = resolve(cmap, "sequential_cmap") + return plot_contour_field( X, Y, @@ -155,8 +166,8 @@ def plot_vorticity_field( Y: NDArray, omega: NDArray, ax: Optional[Axes] = None, - levels: int = 20, - cmap: str = "RdBu_r", + levels: int = UNSET, + cmap: str = UNSET, colorbar: bool = True, title: str = "Vorticity Field", symmetric: bool = True, @@ -179,6 +190,9 @@ def plot_vorticity_field( Returns: The matplotlib axes object. """ + levels = resolve(levels, "levels") + cmap = resolve(cmap, "diverging_cmap") + if ax is None: _, ax = plt.subplots() @@ -314,8 +328,8 @@ def plot_vorticity_with_streamlines( u: NDArray, v: NDArray, ax: Optional[Axes] = None, - vort_levels: int = 20, - vort_cmap: str = "RdBu_r", + vort_levels: int = UNSET, + vort_cmap: str = UNSET, vort_alpha: float = 0.8, stream_density: float = 1.5, stream_color: str = "black", @@ -344,6 +358,9 @@ def plot_vorticity_with_streamlines( Returns: The matplotlib axes object. """ + vort_levels = resolve(vort_levels, "levels") + vort_cmap = resolve(vort_cmap, "diverging_cmap") + if ax is None: _, ax = plt.subplots() From a4c9bb5ca214575b74cbbcf1838fea3f2093eb90 Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 14:07:55 +0200 Subject: [PATCH 03/12] refactor: Migrate plotting/analysis.py and line_plots.py to global defaults Replace hardcoded cmap/levels/figsize/fontsize/colormap defaults with UNSET sentinel and resolve() calls. --- cfd_viz/plotting/analysis.py | 27 ++++++++++++++++++++------- cfd_viz/plotting/line_plots.py | 13 ++++++++++--- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/cfd_viz/plotting/analysis.py b/cfd_viz/plotting/analysis.py index 529506f..cd7c12c 100644 --- a/cfd_viz/plotting/analysis.py +++ b/cfd_viz/plotting/analysis.py @@ -13,6 +13,7 @@ from cfd_viz.analysis.case_comparison import CaseComparison, FieldDifference from cfd_viz.analysis.flow_features import SpatialFluctuations, WakeRegion +from cfd_viz.defaults import UNSET, resolve def plot_field_difference( @@ -20,8 +21,8 @@ def plot_field_difference( X: NDArray, Y: NDArray, ax: Optional[Axes] = None, - cmap: str = "RdBu_r", - levels: int = 20, + cmap: str = UNSET, + levels: int = UNSET, colorbar: bool = True, title: Optional[str] = None, **kwargs, @@ -42,6 +43,9 @@ def plot_field_difference( Returns: The matplotlib axes object. """ + cmap = resolve(cmap, "diverging_cmap") + levels = resolve(levels, "levels") + if ax is None: _, ax = plt.subplots() @@ -67,7 +71,7 @@ def plot_case_comparison( comparison: CaseComparison, X: NDArray, Y: NDArray, - figsize: tuple = (15, 10), + figsize: tuple = UNSET, **kwargs, ) -> plt.Figure: """Create a multi-panel comparison plot for two CFD cases. @@ -82,6 +86,8 @@ def plot_case_comparison( Returns: The matplotlib figure object. """ + figsize = resolve(figsize, "figsize") + fig, axes = plt.subplots(2, 3, figsize=figsize) axes = axes.flatten() @@ -182,7 +188,7 @@ def plot_wake_region( Y: NDArray, velocity_mag: NDArray, ax: Optional[Axes] = None, - field_cmap: str = "viridis", + field_cmap: str = UNSET, field_alpha: float = 0.7, wake_color: str = "red", wake_linewidth: float = 2, @@ -207,6 +213,8 @@ def plot_wake_region( Returns: The matplotlib axes object. """ + field_cmap = resolve(field_cmap, "cmap") + if ax is None: _, ax = plt.subplots() @@ -246,8 +254,8 @@ def plot_spatial_fluctuations( X: NDArray, Y: NDArray, ax: Optional[Axes] = None, - cmap: str = "hot", - levels: int = 15, + cmap: str = UNSET, + levels: int = UNSET, colorbar: bool = True, title: Optional[str] = None, **kwargs, @@ -268,6 +276,9 @@ def plot_spatial_fluctuations( Returns: The matplotlib axes object. """ + cmap = resolve(cmap, "cmap") + levels = resolve(levels, "levels") + if ax is None: _, ax = plt.subplots() @@ -290,7 +301,7 @@ def plot_flow_statistics( stats: Dict[str, float], ax: Optional[Axes] = None, title: str = "Flow Statistics", - fontsize: int = 10, + fontsize: int = UNSET, **kwargs, ) -> Axes: """Plot a text panel with flow statistics. @@ -305,6 +316,8 @@ def plot_flow_statistics( Returns: The matplotlib axes object. """ + fontsize = resolve(fontsize, "font_size") + if ax is None: _, ax = plt.subplots() diff --git a/cfd_viz/plotting/line_plots.py b/cfd_viz/plotting/line_plots.py index 5c5527f..50cd24c 100644 --- a/cfd_viz/plotting/line_plots.py +++ b/cfd_viz/plotting/line_plots.py @@ -18,6 +18,7 @@ ) from cfd_viz.analysis.flow_features import CrossSectionalAverages from cfd_viz.analysis.line_extraction import LineProfile, MultipleProfiles +from cfd_viz.defaults import UNSET, resolve def plot_line_profile( @@ -96,7 +97,7 @@ def plot_multiple_profiles( profiles: MultipleProfiles, ax: Optional[Axes] = None, plot_type: str = "magnitude", - colormap: str = "viridis", + colormap: str = UNSET, title: str = "Multiple Line Profiles", **kwargs, ) -> Axes: @@ -113,6 +114,8 @@ def plot_multiple_profiles( Returns: The matplotlib axes object. """ + colormap = resolve(colormap, "cmap") + if ax is None: _, ax = plt.subplots() @@ -155,7 +158,7 @@ def plot_velocity_profiles( y: NDArray, u_profiles: Sequence[NDArray], ax: Optional[Axes] = None, - colormap: str = "viridis", + colormap: str = UNSET, title: str = "Velocity Profiles at Different Stations", **kwargs, ) -> Axes: @@ -173,6 +176,8 @@ def plot_velocity_profiles( Returns: The matplotlib axes object. """ + colormap = resolve(colormap, "cmap") + if ax is None: _, ax = plt.subplots() @@ -266,7 +271,7 @@ def plot_boundary_layer_profiles( profiles: List[BoundaryLayerProfile], ax: Optional[Axes] = None, normalized: bool = True, - colormap: str = "plasma", + colormap: str = UNSET, title: str = "Boundary Layer Profiles", **kwargs, ) -> Axes: @@ -283,6 +288,8 @@ def plot_boundary_layer_profiles( Returns: The matplotlib axes object. """ + colormap = resolve(colormap, "sequential_cmap") + if ax is None: _, ax = plt.subplots() From a21845d46b0e9c74f4e54f81db0a9891cbde474d Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 14:18:42 +0200 Subject: [PATCH 04/12] refactor: Migrate plotting/time_series.py and quick.py to global defaults Replace hardcoded figsize/cmap/levels with UNSET+resolve() or get_defaults() for inline values. --- cfd_viz/plotting/time_series.py | 31 ++++++++++++++++++++++--------- cfd_viz/quick.py | 9 +++++++-- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/cfd_viz/plotting/time_series.py b/cfd_viz/plotting/time_series.py index ef4850e..65525dc 100644 --- a/cfd_viz/plotting/time_series.py +++ b/cfd_viz/plotting/time_series.py @@ -19,6 +19,7 @@ FlowMetricsTimeSeries, TemporalStatistics, ) +from cfd_viz.defaults import UNSET, get_defaults, resolve def plot_metric_time_series( @@ -84,7 +85,7 @@ def plot_metric_time_series( def plot_convergence_history( history: FlowMetricsTimeSeries, metrics: Sequence[str] = ("max_velocity", "mean_velocity", "total_kinetic_energy"), - figsize: tuple = (15, 4), + figsize: tuple = UNSET, **kwargs, ) -> Figure: """Plot convergence history for multiple metrics. @@ -98,6 +99,8 @@ def plot_convergence_history( Returns: The matplotlib figure object. """ + figsize = resolve(figsize, "figsize") + fig, axes = plt.subplots(1, len(metrics), figsize=figsize) if len(metrics) == 1: axes = [axes] @@ -138,7 +141,7 @@ def plot_monitoring_dashboard( Y: NDArray, velocity_mag: NDArray, pressure: NDArray, - figsize: tuple = (18, 10), + figsize: tuple = UNSET, **kwargs, ) -> Figure: """Create a real-time monitoring dashboard. @@ -156,12 +159,17 @@ def plot_monitoring_dashboard( Returns: The matplotlib figure object. """ + figsize = resolve(figsize, "figsize") + defaults = get_defaults() + fig, axes = plt.subplots(2, 3, figsize=figsize) fig.suptitle("CFD Real-time Monitoring Dashboard", fontsize=16) axes = axes.flatten() # 1. Velocity magnitude field - cs1 = axes[0].contourf(X, Y, velocity_mag, levels=20, cmap="viridis") + cs1 = axes[0].contourf( + X, Y, velocity_mag, levels=defaults.levels, cmap=defaults.cmap + ) plt.colorbar(cs1, ax=axes[0], label="Velocity (m/s)") axes[0].set_title("Velocity Field") axes[0].set_xlabel("x (m)") @@ -169,7 +177,9 @@ def plot_monitoring_dashboard( axes[0].set_aspect("equal") # 2. Pressure field - cs2 = axes[1].contourf(X, Y, pressure, levels=20, cmap="plasma") + cs2 = axes[1].contourf( + X, Y, pressure, levels=defaults.levels, cmap=defaults.sequential_cmap + ) plt.colorbar(cs2, ax=axes[1], label="Pressure") axes[1].set_title("Pressure Field") axes[1].set_xlabel("x (m)") @@ -299,7 +309,7 @@ def plot_temporal_statistics( stats: TemporalStatistics, X: NDArray, Y: NDArray, - figsize: tuple = (15, 5), + figsize: tuple = UNSET, **kwargs, ) -> Figure: """Plot temporal statistics (mean, std, min, max fields). @@ -314,28 +324,31 @@ def plot_temporal_statistics( Returns: The matplotlib figure object. """ + figsize = resolve(figsize, "figsize") + defaults = get_defaults() + fig, axes = plt.subplots(1, 4, figsize=figsize) # Mean field - cs1 = axes[0].contourf(X, Y, stats.mean, levels=20, cmap="viridis") + cs1 = axes[0].contourf(X, Y, stats.mean, levels=defaults.levels, cmap=defaults.cmap) plt.colorbar(cs1, ax=axes[0], label="Mean") axes[0].set_title("Time-Averaged Field") axes[0].set_aspect("equal") # Standard deviation - cs2 = axes[1].contourf(X, Y, stats.std, levels=20, cmap="hot") + cs2 = axes[1].contourf(X, Y, stats.std, levels=defaults.levels, cmap="hot") plt.colorbar(cs2, ax=axes[1], label="Std Dev") axes[1].set_title("Standard Deviation") axes[1].set_aspect("equal") # Minimum - cs3 = axes[2].contourf(X, Y, stats.min, levels=20, cmap="viridis") + cs3 = axes[2].contourf(X, Y, stats.min, levels=defaults.levels, cmap=defaults.cmap) plt.colorbar(cs3, ax=axes[2], label="Min") axes[2].set_title("Minimum Values") axes[2].set_aspect("equal") # Maximum - cs4 = axes[3].contourf(X, Y, stats.max, levels=20, cmap="viridis") + cs4 = axes[3].contourf(X, Y, stats.max, levels=defaults.levels, cmap=defaults.cmap) plt.colorbar(cs4, ax=axes[3], label="Max") axes[3].set_title("Maximum Values") axes[3].set_aspect("equal") diff --git a/cfd_viz/quick.py b/cfd_viz/quick.py index 84ed562..5bc171e 100644 --- a/cfd_viz/quick.py +++ b/cfd_viz/quick.py @@ -10,6 +10,7 @@ from matplotlib.figure import Figure from .convert import from_cfd_python +from .defaults import UNSET, resolve from .fields import magnitude, vorticity from .plotting import plot_contour_field @@ -28,7 +29,7 @@ def quick_plot( ymin: float = 0.0, ymax: float = 1.0, ax: Optional[Axes] = None, - figsize: tuple[float, float] = (8, 6), + figsize: tuple[float, float] = UNSET, **kwargs: Any, ) -> tuple[Figure, Axes]: """Quick visualization of cfd-python simulation results. @@ -58,6 +59,8 @@ def quick_plot( >>> fig, ax = quick_plot(result['u'], result['v'], result['nx'], result['ny']) >>> plt.show() """ + figsize = resolve(figsize, "figsize") + data = from_cfd_python( u, v, nx=nx, ny=ny, p=p, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax ) @@ -149,7 +152,7 @@ def quick_plot_data( data: Any, field: FieldType = "velocity_magnitude", ax: Optional[Axes] = None, - figsize: tuple[float, float] = (8, 6), + figsize: tuple[float, float] = UNSET, **kwargs: Any, ) -> tuple[Figure, Axes]: """Quick visualization of VTKData object. @@ -170,6 +173,8 @@ def quick_plot_data( >>> fig, ax = quick_plot_data(data, field="velocity_magnitude") >>> plt.show() """ + figsize = resolve(figsize, "figsize") + if ax is None: fig, ax = plt.subplots(figsize=figsize) else: From c5fa5d50370400e90936e5e53598835f2e58e2a4 Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 15:13:21 +0200 Subject: [PATCH 05/12] refactor: Migrate animation/renderers.py and export.py to global defaults Replace hardcoded cmap/levels/figsize/dpi defaults with UNSET+resolve() for parameters and get_defaults() for inline values. --- cfd_viz/animation/export.py | 38 ++++++++++++------- cfd_viz/animation/renderers.py | 68 +++++++++++++++++++++------------- 2 files changed, 68 insertions(+), 38 deletions(-) diff --git a/cfd_viz/animation/export.py b/cfd_viz/animation/export.py index 20821ef..416a8d9 100644 --- a/cfd_viz/animation/export.py +++ b/cfd_viz/animation/export.py @@ -11,6 +11,8 @@ from matplotlib import animation from matplotlib.figure import Figure +from cfd_viz.defaults import UNSET, get_defaults, resolve + from .frames import AnimationFrames, FrameData from .renderers import create_velocity_colormap @@ -18,8 +20,8 @@ def export_frame_to_image( frame: FrameData, output_path: Union[str, Path], - figsize: Tuple[int, int] = (16, 12), - dpi: int = 150, + figsize: Tuple[int, int] = UNSET, + dpi: int = UNSET, include_vectors: bool = True, include_streamlines: bool = True, include_pressure: bool = True, @@ -41,6 +43,9 @@ def export_frame_to_image( include_streamlines: Whether to include streamlines subplot. include_pressure: Whether to include pressure subplot. """ + figsize = resolve(figsize, "figsize") + dpi = resolve(dpi, "dpi") + defaults = get_defaults() fig, axes = plt.subplots(2, 2, figsize=figsize) axes = axes.flatten() @@ -55,7 +60,7 @@ def export_frame_to_image( # Panel 1: Velocity magnitude contours if velocity_mag is not None: - im1 = axes[0].contourf(X, Y, velocity_mag, levels=20, cmap="viridis") + im1 = axes[0].contourf(X, Y, velocity_mag, levels=20, cmap=defaults.cmap) axes[0].set_title("Velocity Magnitude") axes[0].axis("equal") plt.colorbar(im1, ax=axes[0]) @@ -77,7 +82,7 @@ def export_frame_to_image( u_sub, v_sub, np.sqrt(u_sub**2 + v_sub**2), - cmap="plasma", + cmap=defaults.sequential_cmap, scale_units="xy", angles="xy", ) @@ -91,14 +96,14 @@ def export_frame_to_image( try: if velocity_mag is not None: axes[2].streamplot( - X, Y, u, v, color=velocity_mag, cmap="viridis", density=2 + X, Y, u, v, color=velocity_mag, cmap=defaults.cmap, density=2 ) else: axes[2].streamplot(X, Y, u, v, density=2) except ValueError: # Streamplot can fail with certain grid configurations if velocity_mag is not None: - axes[2].contourf(X, Y, velocity_mag, levels=20, cmap="viridis") + axes[2].contourf(X, Y, velocity_mag, levels=20, cmap=defaults.cmap) axes[2].set_title("Flow Streamlines") axes[2].axis("equal") else: @@ -106,12 +111,14 @@ def export_frame_to_image( # Panel 4: Pressure or combined view if include_pressure and p is not None: - im4 = axes[3].contourf(X, Y, p, levels=20, cmap="RdBu_r") + im4 = axes[3].contourf(X, Y, p, levels=20, cmap=defaults.diverging_cmap) axes[3].set_title("Pressure Field") axes[3].axis("equal") plt.colorbar(im4, ax=axes[3]) elif velocity_mag is not None and u is not None and v is not None: - im4 = axes[3].contourf(X, Y, velocity_mag, levels=20, cmap="viridis", alpha=0.7) + im4 = axes[3].contourf( + X, Y, velocity_mag, levels=20, cmap=defaults.cmap, alpha=0.7 + ) axes[3].streamplot(X, Y, u, v, color="white", density=1, linewidth=0.8) axes[3].set_title("Combined: Magnitude + Streamlines") plt.colorbar(im4, ax=axes[3]) @@ -129,8 +136,8 @@ def export_animation_frames( animation_frames: AnimationFrames, output_dir: Union[str, Path], prefix: str = "frame", - figsize: Tuple[int, int] = (16, 12), - dpi: int = 150, + figsize: Tuple[int, int] = UNSET, + dpi: int = UNSET, ) -> List[str]: """Export all frames from AnimationFrames to image files. @@ -144,6 +151,8 @@ def export_animation_frames( Returns: List of exported file paths. """ + figsize = resolve(figsize, "figsize") + dpi = resolve(dpi, "dpi") output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -164,7 +173,7 @@ def save_animation( output_path: Union[str, Path], writer: str = "pillow", fps: int = 5, - dpi: int = 100, + dpi: int = UNSET, ) -> None: """Save a matplotlib animation to file. @@ -175,6 +184,7 @@ def save_animation( fps: Frames per second. dpi: Resolution. """ + dpi = resolve(dpi, "dpi") output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) @@ -183,7 +193,7 @@ def save_animation( def create_comprehensive_frame_figure( frame: FrameData, - figsize: Tuple[int, int] = (18, 10), + figsize: Tuple[int, int] = UNSET, ) -> Figure: """Create a comprehensive 2x3 figure for a single frame. @@ -194,6 +204,8 @@ def create_comprehensive_frame_figure( Returns: Matplotlib Figure object. """ + figsize = resolve(figsize, "figsize") + defaults = get_defaults() fig, axes = plt.subplots(2, 3, figsize=figsize) axes = axes.flatten() @@ -215,7 +227,7 @@ def create_comprehensive_frame_figure( plt.colorbar(im0, ax=axes[0]) # Pressure - im1 = axes[1].contourf(X, Y, p, levels=20, cmap="RdBu_r") + im1 = axes[1].contourf(X, Y, p, levels=20, cmap=defaults.diverging_cmap) contours = axes[1].contour(X, Y, p, levels=8, colors="black", linewidths=0.5) axes[1].clabel(contours, inline=True, fontsize=8) axes[1].set_title("Pressure Field", fontweight="bold") diff --git a/cfd_viz/animation/renderers.py b/cfd_viz/animation/renderers.py index b69c82a..d3ca403 100644 --- a/cfd_viz/animation/renderers.py +++ b/cfd_viz/animation/renderers.py @@ -14,6 +14,8 @@ from matplotlib.colors import LinearSegmentedColormap from matplotlib.figure import Figure +from cfd_viz.defaults import UNSET, get_defaults, resolve + from .frames import AnimationFrames, FrameData, ParticleTraces # ============================================================================= @@ -50,8 +52,8 @@ def render_contour_frame( ax: Axes, frame: FrameData, field_name: str = "velocity_mag", - levels: int = 20, - cmap: str = "viridis", + levels: int = UNSET, + cmap: str = UNSET, vmin: Optional[float] = None, vmax: Optional[float] = None, colorbar: bool = True, @@ -70,6 +72,8 @@ def render_contour_frame( colorbar: Whether to add colorbar. title: Plot title. """ + levels = resolve(levels, "levels") + cmap = resolve(cmap, "cmap") ax.clear() field = frame.fields.get(field_name) @@ -107,7 +111,7 @@ def render_vector_frame( frame: FrameData, subsample: int = 5, color_by: str = "velocity_mag", - cmap: str = "viridis", + cmap: str = UNSET, scale: float = 15, width: float = 0.003, title: Optional[str] = None, @@ -124,6 +128,7 @@ def render_vector_frame( width: Arrow width. title: Plot title. """ + cmap = resolve(cmap, "cmap") ax.clear() u = frame.fields.get("u") @@ -170,7 +175,7 @@ def render_streamline_frame( ax: Axes, frame: FrameData, color_by: str = "velocity_mag", - cmap: str = "viridis", + cmap: str = UNSET, density: float = 2.0, linewidth: float = 1.0, title: Optional[str] = None, @@ -186,6 +191,7 @@ def render_streamline_frame( linewidth: Line width. title: Plot title. """ + cmap = resolve(cmap, "cmap") ax.clear() u = frame.fields.get("u") @@ -229,9 +235,9 @@ def render_streamline_frame( def create_field_animation( animation_frames: AnimationFrames, field_name: str = "velocity_mag", - figsize: Tuple[int, int] = (12, 6), - cmap: str = "viridis", - levels: int = 20, + figsize: Tuple[int, int] = UNSET, + cmap: str = UNSET, + levels: int = UNSET, interval: int = 200, title_prefix: str = "CFD Simulation", ) -> Tuple[Figure, animation.FuncAnimation]: @@ -249,6 +255,9 @@ def create_field_animation( Returns: Tuple of (figure, animation) objects. """ + figsize = resolve(figsize, "figsize") + cmap = resolve(cmap, "cmap") + levels = resolve(levels, "levels") if not animation_frames.frames: raise ValueError("No frames to animate") @@ -315,8 +324,8 @@ def animate(frame_idx: int) -> List: def create_streamline_animation( animation_frames: AnimationFrames, - figsize: Tuple[int, int] = (12, 6), - cmap: str = "viridis", + figsize: Tuple[int, int] = UNSET, + cmap: str = UNSET, density: float = 1.5, interval: int = 200, title_prefix: str = "CFD Streamlines", @@ -334,6 +343,8 @@ def create_streamline_animation( Returns: Tuple of (figure, animation) objects. """ + figsize = resolve(figsize, "figsize") + cmap = resolve(cmap, "cmap") if not animation_frames.frames: raise ValueError("No frames to animate") @@ -394,8 +405,8 @@ def animate(frame_idx: int) -> List: def create_vector_animation( animation_frames: AnimationFrames, - figsize: Tuple[int, int] = (12, 6), - cmap: str = "viridis", + figsize: Tuple[int, int] = UNSET, + cmap: str = UNSET, subsample: int = 5, scale: float = 15, interval: int = 200, @@ -415,6 +426,8 @@ def create_vector_animation( Returns: Tuple of (figure, animation) objects. """ + figsize = resolve(figsize, "figsize") + cmap = resolve(cmap, "cmap") if not animation_frames.frames: raise ValueError("No frames to animate") @@ -470,7 +483,7 @@ def animate(frame_idx: int) -> List: def create_multi_panel_animation( animation_frames: AnimationFrames, - figsize: Tuple[int, int] = (18, 10), + figsize: Tuple[int, int] = UNSET, interval: int = 500, title: str = "CFD Flow Analysis Dashboard", ) -> Tuple[Figure, animation.FuncAnimation]: @@ -489,6 +502,8 @@ def create_multi_panel_animation( Returns: Tuple of (figure, animation) objects. """ + figsize = resolve(figsize, "figsize") + defaults = get_defaults() if not animation_frames.frames: raise ValueError("No frames to animate") @@ -501,8 +516,6 @@ def create_multi_panel_animation( p_range = animation_frames.get_field_range("p") vort_range = animation_frames.get_field_range("vorticity") - velocity_cmap = create_velocity_colormap() - first_frame = animation_frames.frames[0] X, Y = first_frame.X, first_frame.Y @@ -525,7 +538,7 @@ def animate(frame_idx: int) -> List: aspect="auto", vmin=vel_range[0], vmax=vel_range[1], - cmap=velocity_cmap, + cmap=defaults.cmap, ) axes[0].set_title("Velocity Magnitude", fontweight="bold") axes[0].set_xlabel("X") @@ -575,7 +588,7 @@ def animate(frame_idx: int) -> List: u_sub, v_sub, vel_sub, - cmap=velocity_cmap, + cmap=defaults.cmap, scale=15, width=0.003, ) @@ -587,7 +600,7 @@ def animate(frame_idx: int) -> List: # Streamlines axes[4].streamplot( - X, Y, u, v, color=vel_mag, cmap=velocity_cmap, density=2, linewidth=1.5 + X, Y, u, v, color=vel_mag, cmap=defaults.cmap, density=2, linewidth=1.5 ) axes[4].set_xlim(X.min(), X.max()) axes[4].set_ylim(Y.min(), Y.max()) @@ -603,7 +616,7 @@ def animate(frame_idx: int) -> List: aspect="auto", vmin=vel_range[0], vmax=vel_range[1], - cmap=velocity_cmap, + cmap=defaults.cmap, alpha=0.8, ) axes[5].contour(X, Y, p, levels=6, colors="white", linewidths=1.5) @@ -639,7 +652,7 @@ def animate(frame_idx: int) -> List: def create_particle_trace_animation( animation_frames: AnimationFrames, particle_traces: ParticleTraces, - figsize: Tuple[int, int] = (15, 8), + figsize: Tuple[int, int] = UNSET, interval: int = 100, title_prefix: str = "Particle Traces", ) -> Tuple[Figure, animation.FuncAnimation]: @@ -655,6 +668,7 @@ def create_particle_trace_animation( Returns: Tuple of (figure, animation) objects. """ + figsize = resolve(figsize, "figsize") if not animation_frames.frames: raise ValueError("No frames to animate") @@ -753,7 +767,7 @@ def animate(frame_idx: int) -> List: def create_vorticity_analysis_animation( animation_frames: AnimationFrames, - figsize: Tuple[int, int] = (15, 6), + figsize: Tuple[int, int] = UNSET, interval: int = 400, title_prefix: str = "Vorticity Analysis", ) -> Tuple[Figure, animation.FuncAnimation]: @@ -768,6 +782,8 @@ def create_vorticity_analysis_animation( Returns: Tuple of (figure, animation) objects. """ + figsize = resolve(figsize, "figsize") + defaults = get_defaults() if not animation_frames.frames: raise ValueError("No frames to animate") @@ -798,7 +814,7 @@ def animate(frame_idx: int) -> List: aspect="auto", vmin=vort_range[0], vmax=vort_range[1], - cmap="RdBu", + cmap=defaults.diverging_cmap, ) ax1.set_title("Vorticity Field", fontweight="bold") ax1.set_xlabel("X") @@ -819,7 +835,7 @@ def animate(frame_idx: int) -> List: v, color=vorticity, linewidth=lw, - cmap="RdBu", + cmap=defaults.diverging_cmap, density=2, ) ax2.set_title("Streamlines Colored by Vorticity", fontweight="bold") @@ -850,7 +866,7 @@ def animate(frame_idx: int) -> List: def create_3d_surface_animation( animation_frames: AnimationFrames, field_name: str = "velocity_mag", - figsize: Tuple[int, int] = (12, 8), + figsize: Tuple[int, int] = UNSET, interval: int = 200, rotate_camera: bool = True, title_prefix: str = "3D Surface", @@ -868,6 +884,8 @@ def create_3d_surface_animation( Returns: Tuple of (figure, animation) objects. """ + figsize = resolve(figsize, "figsize") + defaults = get_defaults() if not animation_frames.frames: raise ValueError("No frames to animate") @@ -890,7 +908,7 @@ def animate(frame_idx: int) -> List: X, Y, field, - cmap="viridis", + cmap=defaults.cmap, vmin=vmin, vmax=vmax, alpha=0.8, @@ -905,7 +923,7 @@ def animate(frame_idx: int) -> List: field, zdir="z", offset=vmin - 0.1 * (vmax - vmin), - cmap="viridis", + cmap=defaults.cmap, alpha=0.5, ) From 26f432ec50c8430a4b9579e647528a51005f7c61 Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 15:24:07 +0200 Subject: [PATCH 06/12] refactor: Migrate interactive/plotly.py to global defaults Replace hardcoded Plotly colorscale defaults with UNSET+resolve() for parameters and get_defaults() for inline colorscale values. --- cfd_viz/interactive/plotly.py | 103 ++++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 23 deletions(-) diff --git a/cfd_viz/interactive/plotly.py b/cfd_viz/interactive/plotly.py index 68e47de..1444449 100644 --- a/cfd_viz/interactive/plotly.py +++ b/cfd_viz/interactive/plotly.py @@ -12,6 +12,8 @@ import plotly.subplots as sp from numpy.typing import NDArray +from cfd_viz.defaults import UNSET, get_defaults, resolve + @dataclass class InteractiveFrameData: @@ -137,7 +139,7 @@ def create_heatmap_figure( y: NDArray, field: NDArray, title: str = "Field", - colorscale: str = "Viridis", + colorscale: str = UNSET, height: int = 500, width: int = 700, ) -> go.Figure: @@ -155,6 +157,8 @@ def create_heatmap_figure( Returns: Plotly Figure object. """ + colorscale = resolve(colorscale, "colorscale") + fig = go.Figure( data=go.Heatmap( z=field, @@ -182,7 +186,7 @@ def create_contour_figure( y: NDArray, field: NDArray, title: str = "Contour", - colorscale: str = "Viridis", + colorscale: str = UNSET, ncontours: int = 20, height: int = 500, width: int = 700, @@ -202,6 +206,8 @@ def create_contour_figure( Returns: Plotly Figure object. """ + colorscale = resolve(colorscale, "colorscale") + fig = go.Figure( data=go.Contour( z=field, @@ -232,7 +238,7 @@ def create_vector_figure( v: NDArray, subsample: int = 5, title: str = "Vector Field", - colorscale: str = "Viridis", + colorscale: str = UNSET, height: int = 500, width: int = 700, ) -> go.Figure: @@ -252,6 +258,8 @@ def create_vector_figure( Returns: Plotly Figure object. """ + colorscale = resolve(colorscale, "colorscale") + X, Y = np.meshgrid(x, y) X_sub = X[::subsample, ::subsample] Y_sub = Y[::subsample, ::subsample] @@ -318,7 +326,7 @@ def create_surface_figure( y: NDArray, field: NDArray, title: str = "3D Surface", - colorscale: str = "Viridis", + colorscale: str = UNSET, height: int = 600, width: int = 800, ) -> go.Figure: @@ -336,6 +344,8 @@ def create_surface_figure( Returns: Plotly Figure object. """ + colorscale = resolve(colorscale, "colorscale") + fig = go.Figure( data=go.Surface( z=field, @@ -533,6 +543,8 @@ def create_dashboard_figure( Returns: Plotly Figure object with 6 panels. """ + defaults = get_defaults() + x, y = frame.x, frame.y u = frame.fields.get("u", np.zeros((len(y), len(x)))) v = frame.fields.get("v", np.zeros((len(y), len(x)))) @@ -562,21 +574,29 @@ def create_dashboard_figure( # Row 1: Velocity magnitude fig.add_trace( - go.Heatmap(z=vel_mag, x=x, y=y, colorscale="Viridis", showscale=True), + go.Heatmap(z=vel_mag, x=x, y=y, colorscale=defaults.colorscale, showscale=True), row=1, col=1, ) # Row 1: Pressure fig.add_trace( - go.Heatmap(z=p, x=x, y=y, colorscale="RdBu", showscale=True), + go.Heatmap( + z=p, x=x, y=y, colorscale=defaults.diverging_colorscale, showscale=True + ), row=1, col=2, ) # Row 1: Vorticity fig.add_trace( - go.Heatmap(z=vorticity, x=x, y=y, colorscale="RdBu", showscale=True), + go.Heatmap( + z=vorticity, + x=x, + y=y, + colorscale=defaults.diverging_colorscale, + showscale=True, + ), row=1, col=3, ) @@ -596,7 +616,7 @@ def create_dashboard_figure( marker=dict( size=8, color=speed_sub.flatten(), - colorscale="Viridis", + colorscale=defaults.colorscale, showscale=False, ), ), @@ -613,7 +633,9 @@ def create_dashboard_figure( # Row 2: 3D Surface fig.add_trace( - go.Surface(z=vel_mag, x=x, y=y, colorscale="Viridis", showscale=False), + go.Surface( + z=vel_mag, x=x, y=y, colorscale=defaults.colorscale, showscale=False + ), row=2, col=3, ) @@ -648,6 +670,8 @@ def create_animated_dashboard( if not frames.frames: raise ValueError("No frames provided") + defaults = get_defaults() + # Create subplot structure fig = sp.make_subplots( rows=3, @@ -693,15 +717,27 @@ def create_animated_dashboard( # 1. Velocity Magnitude frame_data.append( - go.Heatmap(z=vel_mag, x=x, y=y, colorscale="Viridis", showscale=False) + go.Heatmap( + z=vel_mag, x=x, y=y, colorscale=defaults.colorscale, showscale=False + ) ) # 2. Pressure - frame_data.append(go.Heatmap(z=p, x=x, y=y, colorscale="RdBu", showscale=False)) + frame_data.append( + go.Heatmap( + z=p, x=x, y=y, colorscale=defaults.diverging_colorscale, showscale=False + ) + ) # 3. Vorticity frame_data.append( - go.Heatmap(z=vorticity, x=x, y=y, colorscale="RdBu", showscale=False) + go.Heatmap( + z=vorticity, + x=x, + y=y, + colorscale=defaults.diverging_colorscale, + showscale=False, + ) ) # 4. Vector Field @@ -718,7 +754,7 @@ def create_animated_dashboard( marker=dict( size=8, color=speed_sub.flatten(), - colorscale="Viridis", + colorscale=defaults.colorscale, showscale=False, ), showlegend=False, @@ -735,7 +771,7 @@ def create_animated_dashboard( marker=dict( size=2, color=vel_mag.flatten(), - colorscale="Viridis", + colorscale=defaults.colorscale, showscale=False, ), showlegend=False, @@ -750,7 +786,12 @@ def create_animated_dashboard( # 7. Combined frame_data.append( go.Heatmap( - z=vel_mag, x=x, y=y, colorscale="Viridis", opacity=0.7, showscale=False + z=vel_mag, + x=x, + y=y, + colorscale=defaults.colorscale, + opacity=0.7, + showscale=False, ) ) @@ -766,7 +807,9 @@ def create_animated_dashboard( # 9. 3D Surface frame_data.append( - go.Surface(z=vel_mag, x=x, y=y, colorscale="Viridis", showscale=False) + go.Surface( + z=vel_mag, x=x, y=y, colorscale=defaults.colorscale, showscale=False + ) ) plotly_frames.append(go.Frame(data=frame_data, name=str(frame_idx))) @@ -784,9 +827,19 @@ def create_animated_dashboard( T = first_frame.fields.get("T", np.ones_like(X) * 300) # Row 1 - fig.add_trace(go.Heatmap(z=vel_mag, x=x, y=y, colorscale="Viridis"), row=1, col=1) - fig.add_trace(go.Heatmap(z=p, x=x, y=y, colorscale="RdBu"), row=1, col=2) - fig.add_trace(go.Heatmap(z=vorticity, x=x, y=y, colorscale="RdBu"), row=1, col=3) + fig.add_trace( + go.Heatmap(z=vel_mag, x=x, y=y, colorscale=defaults.colorscale), row=1, col=1 + ) + fig.add_trace( + go.Heatmap(z=p, x=x, y=y, colorscale=defaults.diverging_colorscale), + row=1, + col=2, + ) + fig.add_trace( + go.Heatmap(z=vorticity, x=x, y=y, colorscale=defaults.diverging_colorscale), + row=1, + col=3, + ) # Row 2 skip = max(1, min(len(x), len(y)) // 10) @@ -799,7 +852,9 @@ def create_animated_dashboard( x=X_sub.flatten(), y=Y_sub.flatten(), mode="markers", - marker=dict(size=8, color=speed_sub.flatten(), colorscale="Viridis"), + marker=dict( + size=8, color=speed_sub.flatten(), colorscale=defaults.colorscale + ), ), row=2, col=1, @@ -811,7 +866,9 @@ def create_animated_dashboard( y=Y.flatten(), z=vel_mag.flatten(), mode="markers", - marker=dict(size=2, color=vel_mag.flatten(), colorscale="Viridis"), + marker=dict( + size=2, color=vel_mag.flatten(), colorscale=defaults.colorscale + ), ), row=2, col=2, @@ -821,7 +878,7 @@ def create_animated_dashboard( # Row 3 fig.add_trace( - go.Heatmap(z=vel_mag, x=x, y=y, colorscale="Viridis", opacity=0.7), + go.Heatmap(z=vel_mag, x=x, y=y, colorscale=defaults.colorscale, opacity=0.7), row=3, col=1, ) @@ -833,7 +890,7 @@ def create_animated_dashboard( ) fig.add_trace( - go.Surface(z=vel_mag, x=x, y=y, colorscale="Viridis"), + go.Surface(z=vel_mag, x=x, y=y, colorscale=defaults.colorscale), row=3, col=3, ) From 054633e80e8bfcfcf4dcd9e2eb823fa9d35621ae Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 15:26:37 +0200 Subject: [PATCH 07/12] test: Add integration tests verifying plotting respects global defaults Tests that set_defaults(), plot_context(), and per-field cmap resolution (cmap, diverging_cmap, sequential_cmap) are correctly picked up by plotting functions. --- tests/test_defaults.py | 67 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/test_defaults.py b/tests/test_defaults.py index 290cd10..7d8f368 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -1,7 +1,12 @@ """Tests for cfd_viz.defaults module.""" +import matplotlib +import matplotlib.pyplot as plt +import numpy as np import pytest +matplotlib.use("Agg") + from cfd_viz.defaults import ( UNSET, PlotDefaults, @@ -226,3 +231,65 @@ def test_pyproject_without_section_returns_false(self, tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) (tmp_path / "pyproject.toml").write_text("[tool.other]\nfoo = 1\n") assert load_config_file() is False + + +class TestPlottingIntegration: + """Integration tests: plotting functions respect global defaults.""" + + def _make_grid(self): + x = np.linspace(0, 1, 10) + y = np.linspace(0, 1, 10) + X, Y = np.meshgrid(x, y) + field = np.sin(X) * np.cos(Y) + return X, Y, field + + def test_contour_uses_global_cmap(self): + from cfd_viz.plotting.fields import plot_contour_field + + X, Y, field = self._make_grid() + set_defaults(cmap="hot") + ax = plot_contour_field(X, Y, field) + # contourf stores cmap on the QuadContourSet + cmap_name = ax.collections[0].get_cmap().name + assert cmap_name == "hot" + plt.close("all") + + def test_explicit_cmap_overrides_global(self): + from cfd_viz.plotting.fields import plot_contour_field + + X, Y, field = self._make_grid() + set_defaults(cmap="hot") + ax = plot_contour_field(X, Y, field, cmap="coolwarm") + cmap_name = ax.collections[0].get_cmap().name + assert cmap_name == "coolwarm" + plt.close("all") + + def test_plot_context_affects_plotting(self): + from cfd_viz.plotting.fields import plot_contour_field + + X, Y, field = self._make_grid() + with plot_context(cmap="magma"): + ax = plot_contour_field(X, Y, field) + cmap_name = ax.collections[0].get_cmap().name + assert cmap_name == "magma" + plt.close("all") + + def test_vorticity_uses_diverging_cmap(self): + from cfd_viz.plotting.fields import plot_vorticity_field + + X, Y, field = self._make_grid() + set_defaults(diverging_cmap="PuOr") + ax = plot_vorticity_field(X, Y, field) + cmap_name = ax.collections[0].get_cmap().name + assert cmap_name == "PuOr" + plt.close("all") + + def test_pressure_uses_sequential_cmap(self): + from cfd_viz.plotting.fields import plot_pressure_field + + X, Y, field = self._make_grid() + set_defaults(sequential_cmap="inferno") + ax = plot_pressure_field(X, Y, field) + cmap_name = ax.collections[0].get_cmap().name + assert cmap_name == "inferno" + plt.close("all") From 183e56eb5a8b1cceaf639a9a59fc60ab86e268f1 Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 16:22:31 +0200 Subject: [PATCH 08/12] fix: Set matplotlib backend before importing pyplot in test_defaults Calling matplotlib.use("Agg") after importing pyplot can be ineffective or raise warnings depending on the backend state. --- tests/test_defaults.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_defaults.py b/tests/test_defaults.py index 7d8f368..75f853c 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -1,12 +1,13 @@ """Tests for cfd_viz.defaults module.""" import matplotlib + +matplotlib.use("Agg") + import matplotlib.pyplot as plt import numpy as np import pytest -matplotlib.use("Agg") - from cfd_viz.defaults import ( UNSET, PlotDefaults, From f93c089c25dbb1102b973b95a4575ce377872175 Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 16:29:40 +0200 Subject: [PATCH 09/12] fix: Use contextvars for thread-safe plot_context() plot_context() now uses a contextvars.ContextVar for context-local overrides instead of mutating and restoring the global _defaults. This prevents concurrent set_defaults() calls from being silently overwritten when the context exits. --- cfd_viz/defaults.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/cfd_viz/defaults.py b/cfd_viz/defaults.py index 54799ee..f7bdaa0 100644 --- a/cfd_viz/defaults.py +++ b/cfd_viz/defaults.py @@ -17,6 +17,7 @@ from __future__ import annotations +import contextvars import copy import threading from contextlib import contextmanager @@ -73,12 +74,22 @@ class PlotDefaults: diverging_colorscale: str = "RdBu" -_lock = threading.Lock() +_lock = threading.RLock() _defaults = PlotDefaults() +_context_override: contextvars.ContextVar[PlotDefaults | None] = contextvars.ContextVar( + "_context_override", default=None +) def get_defaults() -> PlotDefaults: - """Return a copy of the current global defaults.""" + """Return a copy of the current global defaults. + + If called inside a ``plot_context()`` block, returns the + context-local overrides instead of the global defaults. + """ + override = _context_override.get() + if override is not None: + return copy.copy(override) with _lock: return copy.copy(_defaults) @@ -148,15 +159,22 @@ def plot_context(**kwargs: Any) -> Iterator[PlotDefaults]: ... quick_plot(u, v, nx, ny) # uses coolwarm >>> # defaults restored automatically """ - global _defaults # noqa: PLW0603 - with _lock: - saved = copy.copy(_defaults) + valid_fields = {f.name for f in dc_fields(PlotDefaults)} + unknown = set(kwargs) - valid_fields + if unknown: + raise TypeError( + f"Unknown defaults: {', '.join(sorted(unknown))}. " + f"Valid fields: {', '.join(sorted(valid_fields))}" + ) + # Start from current effective defaults (context-local or global) + base = get_defaults() + for key, value in kwargs.items(): + setattr(base, key, value) + token = _context_override.set(base) try: - set_defaults(**kwargs) - yield get_defaults() + yield copy.copy(base) finally: - with _lock: - _defaults = saved + _context_override.reset(token) def load_config_file(path: str | None = None) -> bool: From 3a73994c6635f5d8078910e25e23b353eb829596 Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 16:38:00 +0200 Subject: [PATCH 10/12] fix: Replace remaining hardcoded levels=20 with defaults.levels in export.py The initial migration missed inline contourf levels in export_frame_to_image and export_analysis_frame, so set_defaults(levels=...) had no effect on exported frames. --- cfd_viz/animation/export.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/cfd_viz/animation/export.py b/cfd_viz/animation/export.py index 416a8d9..addffc1 100644 --- a/cfd_viz/animation/export.py +++ b/cfd_viz/animation/export.py @@ -60,7 +60,9 @@ def export_frame_to_image( # Panel 1: Velocity magnitude contours if velocity_mag is not None: - im1 = axes[0].contourf(X, Y, velocity_mag, levels=20, cmap=defaults.cmap) + im1 = axes[0].contourf( + X, Y, velocity_mag, levels=defaults.levels, cmap=defaults.cmap + ) axes[0].set_title("Velocity Magnitude") axes[0].axis("equal") plt.colorbar(im1, ax=axes[0]) @@ -103,7 +105,9 @@ def export_frame_to_image( except ValueError: # Streamplot can fail with certain grid configurations if velocity_mag is not None: - axes[2].contourf(X, Y, velocity_mag, levels=20, cmap=defaults.cmap) + axes[2].contourf( + X, Y, velocity_mag, levels=defaults.levels, cmap=defaults.cmap + ) axes[2].set_title("Flow Streamlines") axes[2].axis("equal") else: @@ -111,13 +115,15 @@ def export_frame_to_image( # Panel 4: Pressure or combined view if include_pressure and p is not None: - im4 = axes[3].contourf(X, Y, p, levels=20, cmap=defaults.diverging_cmap) + im4 = axes[3].contourf( + X, Y, p, levels=defaults.levels, cmap=defaults.diverging_cmap + ) axes[3].set_title("Pressure Field") axes[3].axis("equal") plt.colorbar(im4, ax=axes[3]) elif velocity_mag is not None and u is not None and v is not None: im4 = axes[3].contourf( - X, Y, velocity_mag, levels=20, cmap=defaults.cmap, alpha=0.7 + X, Y, velocity_mag, levels=defaults.levels, cmap=defaults.cmap, alpha=0.7 ) axes[3].streamplot(X, Y, u, v, color="white", density=1, linewidth=0.8) axes[3].set_title("Combined: Magnitude + Streamlines") @@ -219,7 +225,9 @@ def create_comprehensive_frame_figure( velocity_cmap = create_velocity_colormap() # Velocity magnitude - im0 = axes[0].contourf(X, Y, velocity_mag, levels=20, cmap=velocity_cmap) + im0 = axes[0].contourf( + X, Y, velocity_mag, levels=defaults.levels, cmap=velocity_cmap + ) axes[0].set_title("Velocity Magnitude", fontweight="bold") axes[0].set_xlabel("X") axes[0].set_ylabel("Y") @@ -227,7 +235,9 @@ def create_comprehensive_frame_figure( plt.colorbar(im0, ax=axes[0]) # Pressure - im1 = axes[1].contourf(X, Y, p, levels=20, cmap=defaults.diverging_cmap) + im1 = axes[1].contourf( + X, Y, p, levels=defaults.levels, cmap=defaults.diverging_cmap + ) contours = axes[1].contour(X, Y, p, levels=8, colors="black", linewidths=0.5) axes[1].clabel(contours, inline=True, fontsize=8) axes[1].set_title("Pressure Field", fontweight="bold") @@ -237,7 +247,7 @@ def create_comprehensive_frame_figure( plt.colorbar(im1, ax=axes[1]) # Vorticity - im2 = axes[2].contourf(X, Y, vorticity, levels=20, cmap="seismic") + im2 = axes[2].contourf(X, Y, vorticity, levels=defaults.levels, cmap="seismic") axes[2].set_title("Vorticity", fontweight="bold") axes[2].set_xlabel("X") axes[2].set_ylabel("Y") From bc5e20c4fe82489f9db104b8aa8ebf1bbc61cab7 Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 17:09:20 +0200 Subject: [PATCH 11/12] fix: Add tomli conditional dependency for Python <3.11 load_config_file() needs a TOML parser but tomllib is only available in Python 3.11+. Add tomli as a fallback for older Python versions. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 35a2f2d..48427bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "watchdog", "packaging", "plotly", + "tomli>=1.0; python_version < '3.11'", ] [project.optional-dependencies] From d0b64f03d61ac056894fda6791aab39e92b346ff Mon Sep 17 00:00:00 2001 From: shaia Date: Sat, 7 Mar 2026 17:12:57 +0200 Subject: [PATCH 12/12] fix: Use defaults.levels instead of hard-coded levels=20 in combined analysis panel --- cfd_viz/animation/export.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cfd_viz/animation/export.py b/cfd_viz/animation/export.py index addffc1..531a47d 100644 --- a/cfd_viz/animation/export.py +++ b/cfd_viz/animation/export.py @@ -284,7 +284,9 @@ def create_comprehensive_frame_figure( axes[4].set_aspect("equal") # Combined analysis - axes[5].contourf(X, Y, velocity_mag, levels=20, cmap=velocity_cmap, alpha=0.8) + axes[5].contourf( + X, Y, velocity_mag, levels=defaults.levels, cmap=velocity_cmap, alpha=0.8 + ) axes[5].contour(X, Y, p, levels=6, colors="white", linewidths=1.5) if np.any(vorticity != 0): vort_threshold = np.percentile(np.abs(vorticity), 85)