diff --git a/xrspatial/geotiff/_pam.py b/xrspatial/geotiff/_pam.py index 6b52b9894..0a257b101 100644 --- a/xrspatial/geotiff/_pam.py +++ b/xrspatial/geotiff/_pam.py @@ -105,6 +105,33 @@ def write_pam_sidecar(path, category_names, category_colors=None): return out +def build_stats_pam_xml(stats): + """Build a PAM ``.aux.xml`` document carrying band statistics. + + *stats* maps GDAL ``STATISTICS_*`` keys to numeric values. GDAL tools and + QGIS read these from the band ``<Metadata>`` to drive a default min/max + stretch on a continuous raster, the way :func:`build_pam_xml` drives class + coloring on a categorical one. + """ + lines = ['', ' ', ' '] + for key, value in stats.items(): + text = escape('{:.10g}'.format(float(value))) + lines.append(f' {text}') + lines.append(' ') + lines.append(' ') + lines.append('') + return '\n'.join(lines) + '\n' + + +def write_stats_pam_sidecar(path, stats): + """Write the statistics PAM sidecar for *path* and return its path.""" + xml = build_stats_pam_xml(stats) + out = sidecar_path(path) + with open(out, 'w', encoding='utf-8') as fh: + fh.write(xml) + return out + + def read_pam_sidecar(path): """Read ``category_names`` / ``category_colors`` from *path*'s sidecar. diff --git a/xrspatial/geotiff/_symbology.py b/xrspatial/geotiff/_symbology.py new file mode 100644 index 000000000..9fe11dcaa --- /dev/null +++ b/xrspatial/geotiff/_symbology.py @@ -0,0 +1,272 @@ +"""Symbology sidecars for continuous (non-categorical) rasters. + +A continuous single-band GeoTIFF opens in QGIS as a flat grayscale stretch +unless QGIS finds styling next to the file. Two sidecars give it a sensible +default color ramp on open: + +* a QGIS ``.qml`` style file (``<basename>.qml``) with a ``singlebandpseudocolor`` + renderer -- this is what makes QGIS draw real colors, and +* band statistics (min/max/mean/stddev) in the PAM ``.aux.xml`` sidecar + (written via :mod:`._pam`), which GDAL tools and QGIS use for a min/max + stretch even when the ``.qml`` is ignored. 
+ +This is the continuous-raster counterpart to the categorical RAT sidecar in +:mod:`._pam`. Color ramps are hardcoded control points (no matplotlib runtime +dependency); ``viridis`` is the default. +""" +from __future__ import annotations + +import math +import os + +import numpy as np + +from .. import utils + +# Named color ramps, each sampled at 9 evenly spaced control points +# (matplotlib colormaps, committed as static constants so we carry no +# matplotlib runtime dependency). Each entry is (position 0..1, (r, g, b)). +_COLOR_RAMPS = { + 'viridis': [(0.0, (68, 1, 84)), (0.125, (71, 45, 123)), (0.25, (59, 82, 139)), + (0.375, (44, 114, 142)), (0.5, (33, 145, 140)), + (0.625, (40, 174, 128)), (0.75, (94, 201, 98)), + (0.875, (173, 220, 48)), (1.0, (253, 231, 37))], + 'plasma': [(0.0, (13, 8, 135)), (0.125, (76, 2, 161)), (0.25, (126, 3, 168)), + (0.375, (170, 35, 149)), (0.5, (204, 71, 120)), + (0.625, (230, 108, 92)), (0.75, (248, 149, 64)), + (0.875, (253, 197, 39)), (1.0, (240, 249, 33))], + 'magma': [(0.0, (0, 0, 4)), (0.125, (29, 17, 71)), (0.25, (81, 18, 124)), + (0.375, (131, 38, 129)), (0.5, (183, 55, 121)), + (0.625, (231, 82, 99)), (0.75, (252, 137, 97)), + (0.875, (254, 196, 136)), (1.0, (252, 253, 191))], + 'inferno': [(0.0, (0, 0, 4)), (0.125, (33, 12, 74)), (0.25, (87, 16, 110)), + (0.375, (138, 34, 106)), (0.5, (188, 55, 84)), + (0.625, (228, 90, 49)), (0.75, (249, 142, 9)), + (0.875, (249, 203, 53)), (1.0, (252, 255, 164))], + 'cividis': [(0.0, (0, 34, 78)), (0.125, (26, 56, 111)), (0.25, (67, 78, 108)), + (0.375, (97, 101, 111)), (0.5, (125, 124, 120)), + (0.625, (155, 148, 118)), (0.75, (188, 174, 108)), + (0.875, (222, 201, 88)), (1.0, (254, 232, 56))], + 'greys': [(0.0, (255, 255, 255)), (0.125, (240, 240, 240)), + (0.25, (217, 217, 217)), (0.375, (189, 189, 189)), + (0.5, (149, 149, 149)), (0.625, (114, 114, 114)), + (0.75, (81, 81, 81)), (0.875, (36, 36, 36)), (1.0, (0, 0, 0))], + 'spectral': [(0.0, (158, 1, 66)), (0.125, (221, 74, 76)), 
(0.25, (249, 142, 82)), + (0.375, (254, 212, 129)), (0.5, (255, 255, 190)), + (0.625, (214, 238, 155)), (0.75, (134, 207, 165)), + (0.875, (61, 149, 184)), (1.0, (94, 79, 162))], + 'terrain': [(0.0, (51, 51, 153)), (0.125, (8, 136, 238)), (0.25, (1, 204, 102)), + (0.375, (129, 230, 128)), (0.5, (254, 254, 152)), + (0.625, (190, 172, 118)), (0.75, (129, 94, 86)), + (0.875, (193, 176, 172)), (1.0, (255, 255, 255))], +} + +_DEFAULT_RAMP = 'viridis' + + +def resolve_ramp(name): + """Return the control-point list for color ramp *name*. + + ``True`` selects the default ramp (``viridis``). Unknown names raise + ``ValueError`` listing the available ramps so a typo fails fast, before + any bytes are written. + """ + if name is True: + name = _DEFAULT_RAMP + key = str(name).lower() + try: + return _COLOR_RAMPS[key] + except KeyError: + raise ValueError( + f"unknown color_ramp {name!r}; choose from {sorted(_COLOR_RAMPS)}") + + +def qml_path(path: str) -> str: + """Return the QGIS style sidecar path QGIS auto-loads for *path*. + + QGIS builds its default-style path from ``completeBaseName + '.qml'``: it + *replaces* the extension (``dem.tif`` -> ``dem.qml``), unlike the PAM + ``.aux.xml`` sidecar which is *appended* (``dem.tif.aux.xml``). Writing the + appended form here would leave the style silently unloaded. + """ + return os.path.splitext(path)[0] + '.qml' + + +def _num(value) -> str: + """Format a number for QML/PAM with enough precision to round-trip.""" + return '{:.10g}'.format(float(value)) + + +def build_qml(stops, vmin, vmax) -> str: + """Build a QGIS ``.qml`` singlebandpseudocolor style for one band. + + *stops* is a ``[(pos, (r, g, b)), ...]`` ramp; each stop becomes a + ``<colorrampshader>`` item at ``vmin + pos * (vmax - vmin)`` (the item list + is what QGIS actually renders). The sibling ``<colorramp>`` block lets the + QGIS style dialog show the ramp as an editable gradient. 
+ """ + vmin = float(vmin) + vmax = float(vmax) + span = vmax - vmin + + items = [] + for pos, (r, g, b) in stops: + value = vmin + pos * span + color = '#%02x%02x%02x' % (r, g, b) + items.append( + f' ') + item_xml = '\n'.join(items) + + color1 = '%d,%d,%d,255' % stops[0][1] + color2 = '%d,%d,%d,255' % stops[-1][1] + interior = ':'.join( + '%s;%d,%d,%d,255' % (_num(pos), r, g, b) + for pos, (r, g, b) in stops[1:-1]) + + return f""" + + + + + + + + +{item_xml} + + + + + + + + 0 + +""" + + +def _is_single_band(data) -> bool: + """True when *data* is a single-band raster (2D, or 3D with one band).""" + if data.ndim == 2: + return True + if data.ndim == 3: + from ._coords import _BAND_DIM_NAMES + for dim, size in zip(data.dims, data.shape): + if dim in _BAND_DIM_NAMES: + return size == 1 + return False + + +def _finite_stats(data, nodata=None): + """Return ``(min, max, mean, stddev)`` over finite, non-nodata values. + + Population stddev (``ddof=0``), matching GDAL's ``STATISTICS_STDDEV``. + Returns ``None`` when no finite values remain. Backend-aware (numpy, cupy, + dask+numpy, dask+cupy); the dask path fuses the reductions into a single + ``dask.compute`` so the source graph is read once, not four times. 
+ """ + arr = getattr(data, 'data', data) + has_nodata = nodata is not None and not ( + isinstance(nodata, float) and math.isnan(nodata)) + + if utils.has_dask_array(): + import dask.array as dask_array + if isinstance(arr, dask_array.Array): + return _dask_finite_stats(arr, nodata if has_nodata else None) + return _eager_finite_stats(arr, nodata if has_nodata else None) + + +def _eager_finite_stats(arr, nodata): + """``_finite_stats`` for an in-memory numpy or cupy array.""" + xp = np + if utils.is_cupy_array(arr): + import cupy + xp = cupy + + mask = xp.isfinite(arr) + if nodata is not None: + mask = mask & (arr != nodata) + vals = arr[mask] + if vals.size == 0: + return None + return (utils._to_float_scalar(vals.min()), + utils._to_float_scalar(vals.max()), + utils._to_float_scalar(vals.mean()), + utils._to_float_scalar(vals.std())) + + +def _dask_finite_stats(arr, nodata): + """``_finite_stats`` for a dask array (numpy- or cupy-backed).""" + import dask + import dask.array as dask_array + + mask = dask_array.isfinite(arr) + if nodata is not None: + mask = mask & (arr != nodata) + # ``where`` keeps the array shape and NaNs out excluded cells so the nan* + # reductions ignore them; integer inputs upcast to float, which is fine + # for statistics. + vals = dask_array.where(mask, arr, np.nan) + count, mn, mx, mean, std = dask.compute( + mask.sum(), + dask_array.nanmin(vals), + dask_array.nanmax(vals), + dask_array.nanmean(vals), + dask_array.nanstd(vals), + ) + if utils._to_float_scalar(count) == 0: + return None + return (utils._to_float_scalar(mn), utils._to_float_scalar(mx), + utils._to_float_scalar(mean), utils._to_float_scalar(std)) + + +def write_symbology_sidecars(path, data, *, stops, nodata=None, ramp_range=None): + """Write continuous-raster symbology sidecars next to *path*. + + Writes band statistics into the PAM ``.aux.xml`` and, when the data range + is non-degenerate, a QGIS ``.qml`` color-ramp style. 
No-op for multiband + arrays or arrays with no finite data. + + When *ramp_range* ``(vmin, vmax)`` is given it sets the ramp/stretch bounds + directly and skips the statistics reduction -- the escape hatch for large + dask graphs -- so only ``STATISTICS_MINIMUM`` / ``STATISTICS_MAXIMUM`` are + written (mean/stddev would need the pass it avoids). + """ + from . import _pam + + if not _is_single_band(data): + return + + if ramp_range is not None: + vmin, vmax = float(ramp_range[0]), float(ramp_range[1]) + stats = {'STATISTICS_MINIMUM': vmin, 'STATISTICS_MAXIMUM': vmax} + else: + result = _finite_stats(data, nodata) + if result is None: + return + vmin, vmax, vmean, vstd = result + stats = { + 'STATISTICS_MINIMUM': vmin, + 'STATISTICS_MAXIMUM': vmax, + 'STATISTICS_MEAN': vmean, + 'STATISTICS_STDDEV': vstd, + } + + _pam.write_stats_pam_sidecar(path, stats) + + # A constant raster (vmin == vmax) has no range to ramp across, so the + # stats stretch is the useful part and the QML would be a degenerate + # single-stop ramp -- skip it. + if vmax > vmin: + with open(qml_path(path), 'w', encoding='utf-8') as fh: + fh.write(build_qml(stops, vmin, vmax)) diff --git a/xrspatial/geotiff/_writers/eager.py b/xrspatial/geotiff/_writers/eager.py index 1c3c3a5e6..a2129d412 100644 --- a/xrspatial/geotiff/_writers/eager.py +++ b/xrspatial/geotiff/_writers/eager.py @@ -63,7 +63,10 @@ def to_geotiff(data: xr.DataArray | np.ndarray, allow_experimental_codecs: bool = False, allow_unparseable_crs: bool = False, drop_rotation: bool = False, - pack: bool = False) -> str | BinaryIO: + pack: bool = False, + color_ramp: str | bool | None = None, + color_ramp_range: tuple[float, float] | None = None, + ) -> str | BinaryIO: """Write data as a GeoTIFF or Cloud Optimized GeoTIFF. Release-contract tier (see @@ -398,6 +401,30 @@ def to_geotiff(data: xr.DataArray | np.ndarray, (issue #3260). The write itself stays atomic (temp file plus rename), so no partial output is left at the destination path. 
+ color_ramp : str, bool, or None, default None + [advanced] Write best-practice symbology sidecars so a continuous + single-band raster opens in QGIS with a color ramp instead of a flat + grayscale stretch. Pass a ramp name (``'viridis'`` -- the default -- + ``'plasma'``, ``'magma'``, ``'inferno'``, ``'cividis'``, ``'greys'``, + ``'spectral'``, ``'terrain'``) or ``True`` for viridis; an unknown name + raises ``ValueError``. ``'greys'`` follows matplotlib's light-to-dark + orientation (low values render light). Two sidecars are written: a + QGIS ``.qml`` style (``<basename>.qml``) with a ``singlebandpseudocolor`` + renderer, and ``STATISTICS_MINIMUM/MAXIMUM/MEAN/STDDEV`` in the PAM + ``.aux.xml``. No-op for a categorical raster (one with + ``attrs['category_names']`` -- those get the RAT sidecar instead), a + multiband array, a file-like destination, or data with no finite + values. Computing the statistics is a separate reduction pass over the + data; for a dask source that means reading the graph once more (see + ``color_ramp_range`` to skip it). Ignored when ``pack=True``, whose + on-disk packed values would not match a ramp built from the logical + values. + color_ramp_range : tuple of (float, float) or None, default None + [advanced] Explicit ``(min, max)`` for the ``color_ramp`` stretch. + Skips the statistics reduction -- useful for large dask graphs -- so + only ``STATISTICS_MINIMUM`` / ``STATISTICS_MAXIMUM`` are written + (mean/stddev need the pass it avoids). Ignored when ``color_ramp`` is + not set. Returns ------- @@ -435,21 +462,43 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path = _coerce_path(path) - # Categorical rasters carry their value->label map in attrs. GDAL/QGIS - # only read category names and colors from a PAM ``.aux.xml`` - # sidecar (an embedded RAT is ignored), so capture the labels now and - # emit the sidecar next to the file on the way out. File-like - # destinations have no path to anchor a sidecar, so skip them. 
+ # Categorical rasters carry their value->label map in attrs; continuous + # rasters opt into a color ramp via ``color_ramp``. QGIS/GDAL read both + # kinds of styling from sidecars next to the file (a PAM ``.aux.xml`` RAT + # for categories, a ``.qml`` style plus PAM statistics for a continuous + # ramp), so capture what we need now and emit the sidecars on the way out. + # Categorical wins when both are present, and the two never collide. File- + # like destinations have no path to anchor a sidecar, so skip them. _cat_names = None _cat_colors = None + _sym_data = None + _sym_stops = None + _sym_nodata = None if isinstance(path, str) and isinstance(data, xr.DataArray): _cat_names = data.attrs.get('category_names') _cat_colors = data.attrs.get('category_colors') - - def _write_category_sidecar(): + # ``pack`` rewrites ``data`` to on-disk packed values below, so the + # symbology statistics (taken from the logical values here) would + # describe a different range than the stored pixels. Skip symbology + # when packing rather than emit a mismatched ramp. + if color_ramp and not _cat_names and not pack: + from .._symbology import resolve_ramp + # Validate the ramp name now so a typo fails before any bytes. + _sym_stops = resolve_ramp(color_ramp) + _sym_data = data + _sym_nodata = (nodata if nodata is not None + else _resolve_nodata_attr(data.attrs)) + + def _write_sidecars(): if _cat_names: from .._pam import write_pam_sidecar write_pam_sidecar(path, _cat_names, _cat_colors) + return + if _sym_stops is not None: + from .._symbology import write_symbology_sidecars + write_symbology_sidecars( + path, _sym_data, stops=_sym_stops, + nodata=_sym_nodata, ramp_range=color_ramp_range) # Reject bool / np.bool_ nodata up front. 
``bool`` is a subclass of # ``int`` in Python, so a typo like ``nodata=True`` slips past every @@ -828,7 +877,7 @@ def _write_category_sidecar(): allow_unparseable_crs=allow_unparseable_crs, allow_internal_only_jpeg=allow_internal_only_jpeg, drop_rotation=drop_rotation) - _write_category_sidecar() + _write_sidecars() return path # Dispatch to _write_geotiff_gpu when GPU was selected (explicit @@ -877,7 +926,7 @@ def _write_category_sidecar(): allow_unparseable_crs=allow_unparseable_crs, drop_rotation=drop_rotation, ) - _write_category_sidecar() + _write_sidecars() return path except ImportError as e: # ``_write_geotiff_gpu`` raises ImportError when cupy itself @@ -1060,7 +1109,7 @@ def _write_category_sidecar(): allow_internal_only_jpeg=allow_internal_only_jpeg, allow_unparseable_crs=allow_unparseable_crs, ) - _write_category_sidecar() + _write_sidecars() return path # Eager compute (numpy, CuPy, or dask+COG) @@ -1156,7 +1205,7 @@ def _write_category_sidecar(): allow_internal_only_jpeg=allow_internal_only_jpeg, allow_unparseable_crs=allow_unparseable_crs, ) - _write_category_sidecar() + _write_sidecars() return path diff --git a/xrspatial/geotiff/_writers/gpu.py b/xrspatial/geotiff/_writers/gpu.py index 360c60273..597211563 100644 --- a/xrspatial/geotiff/_writers/gpu.py +++ b/xrspatial/geotiff/_writers/gpu.py @@ -111,6 +111,8 @@ def _write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray, allow_unparseable_crs: bool = False, drop_rotation: bool = False, pack: bool = False, + color_ramp: str | bool | None = None, + color_ramp_range: tuple[float, float] | None = None, ) -> str | BinaryIO: """Write a CuPy-backed DataArray as a GeoTIFF with GPU compression. @@ -310,6 +312,13 @@ def _write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray, parity with ``to_geotiff``, which applies the ``pack`` re-pack transform before dispatching here. See ``to_geotiff`` for the behaviour. 
+ color_ramp : str, bool, or None, default None + [advanced] No-op on the GPU writer: it exists for signature parity + with ``to_geotiff``, which writes the symbology sidecars from the + shared sidecar closure after this writer returns. See ``to_geotiff``. + color_ramp_range : tuple of (float, float) or None, default None + [advanced] No-op on the GPU writer; signature parity with + ``to_geotiff``. See ``to_geotiff``. Returns ------- @@ -512,8 +521,10 @@ def _write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray, f"method, got {type(path).__name__}") # ``pack`` is a no-op here: ``to_geotiff`` applies the # re-pack transform before dispatching to this writer, so the kwarg - # only needs to exist for signature parity (#3064). - del pack + # only needs to exist for signature parity (#3064). ``color_ramp`` / + # ``color_ramp_range`` are likewise wrapper-only: the symbology sidecars + # are written by ``to_geotiff``'s sidecar closure, not here. + del pack, color_ramp, color_ramp_range try: import cupy except ImportError: diff --git a/xrspatial/geotiff/tests/write/test_symbology_sidecar_3537.py b/xrspatial/geotiff/tests/write/test_symbology_sidecar_3537.py new file mode 100644 index 000000000..7aea628f6 --- /dev/null +++ b/xrspatial/geotiff/tests/write/test_symbology_sidecar_3537.py @@ -0,0 +1,286 @@ +"""Continuous-raster symbology sidecars across write backends (#3537). + +``to_geotiff(..., color_ramp=...)`` writes two sidecars next to a continuous +single-band raster so QGIS opens it with a color ramp: a QGIS ``.qml`` +style and ``STATISTICS_*`` in the PAM ``.aux.xml``. The emit is wired into +the shared ``_write_sidecars()`` closure that every write path calls, so these +tests drive the numpy, dask, GPU, and dask+GPU backends to keep the emit (and +the cross-backend statistics) from regressing on any one branch. 
Unit tests +cover the building blocks (``_finite_stats``, ``build_qml``, ``qml_path``, +``resolve_ramp``, ``build_stats_pam_xml``) and the gating rules. +""" +import io +import os +import xml.etree.ElementTree as ET + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import to_geotiff +from xrspatial.geotiff._pam import build_stats_pam_xml +from xrspatial.geotiff._symbology import ( + _finite_stats, build_qml, qml_path, resolve_ramp) + +from .._helpers.markers import requires_gpu + +pytest.importorskip("tifffile") + +# A smooth gradient with one NaN hole, so finite stats must skip the NaN. +_BASE = np.linspace(0.0, 100.0, 16 * 16).reshape(16, 16).astype("float32") +_BASE[0, 0] = np.nan + + +def _continuous_da(values): + n = values.shape[0] + return xr.DataArray( + values, + dims=("y", "x"), + coords={"y": np.arange(n, dtype="float64"), + "x": np.arange(values.shape[1], dtype="float64")}, + attrs={"crs": 4326}, + ) + + +def _qml_renderer(path): + """Parse a written ``.qml`` and return its ``<rasterrenderer>`` element.""" + root = ET.parse(path).getroot() + return root.find(".//rasterrenderer") + + +# -------------------------------------------------------------------------- +# cross-backend write + parity +# -------------------------------------------------------------------------- + +def test_numpy_write_emits_both_sidecars(tmp_path): + path = str(tmp_path / "np.tif") + to_geotiff(_continuous_da(_BASE), path, color_ramp="viridis") + + # QML uses the extension-replacing name; PAM stats use the appended name. + assert os.path.exists(str(tmp_path / "np.qml")) + assert not os.path.exists(path + ".qml") + assert os.path.exists(path + ".aux.xml") + + rr = _qml_renderer(str(tmp_path / "np.qml")) + assert rr.get("type") == "singlebandpseudocolor" + assert rr.get("band") == "1" + # NaN hole excluded -> min is the second sample, not nan. 
+ assert float(rr.get("classificationMin")) == pytest.approx( + float(np.nanmin(_BASE)), abs=1e-5) + assert float(rr.get("classificationMax")) == pytest.approx(100.0, abs=1e-5) + # One ramp item per control point. + items = rr.findall(".//colorrampshader/item") + assert len(items) == len(resolve_ramp("viridis")) + + +def test_dask_write_emits_sidecars(tmp_path): + import dask.array as dsa + + path = str(tmp_path / "dk.tif") + to_geotiff(_continuous_da(dsa.from_array(_BASE, chunks=(8, 8))), + path, color_ramp="plasma") + assert os.path.exists(str(tmp_path / "dk.qml")) + assert os.path.exists(path + ".aux.xml") + + +@requires_gpu +def test_gpu_write_emits_sidecars(tmp_path): + import cupy + + path = str(tmp_path / "gpu.tif") + to_geotiff(_continuous_da(cupy.asarray(_BASE)), path, + gpu=True, color_ramp="magma") + assert os.path.exists(str(tmp_path / "gpu.qml")) + assert os.path.exists(path + ".aux.xml") + + +def test_finite_stats_backend_parity(): + """Stats agree across numpy / dask / cupy / dask+cupy.""" + import dask.array as dsa + + ref = _finite_stats(_continuous_da(_BASE), None) + got = _finite_stats(_continuous_da(dsa.from_array(_BASE, chunks=(8, 8))), None) + assert got == pytest.approx(ref, abs=1e-5) + + pytest.importorskip("cupy") + import cupy + from xrspatial.utils import has_cuda_and_cupy + if not has_cuda_and_cupy(): + pytest.skip("no CUDA device") + cg = _finite_stats(_continuous_da(cupy.asarray(_BASE)), None) + cdg = _finite_stats( + _continuous_da(dsa.from_array(cupy.asarray(_BASE), chunks=(8, 8))), None) + assert cg == pytest.approx(ref, abs=1e-5) + assert cdg == pytest.approx(ref, abs=1e-5) + + +# -------------------------------------------------------------------------- +# _finite_stats unit behaviour +# -------------------------------------------------------------------------- + +def test_finite_stats_values(): + vmin, vmax, vmean, vstd = _finite_stats(_continuous_da(_BASE), None) + finite = _BASE[np.isfinite(_BASE)] + assert vmin == 
pytest.approx(float(finite.min()), abs=1e-5) + assert vmax == pytest.approx(float(finite.max()), abs=1e-5) + assert vmean == pytest.approx(float(finite.mean()), abs=1e-5) + assert vstd == pytest.approx(float(finite.std()), abs=1e-5) # ddof=0 + + +def test_finite_stats_all_nan_returns_none(): + arr = np.full((4, 4), np.nan, dtype="float32") + assert _finite_stats(_continuous_da(arr), None) is None + + +def test_finite_stats_excludes_nodata_sentinel(): + arr = np.array([[1.0, 2.0], [3.0, -9999.0]], dtype="float32") + vmin, vmax, _, _ = _finite_stats(_continuous_da(arr), nodata=-9999.0) + assert vmin == pytest.approx(1.0) + assert vmax == pytest.approx(3.0) + + +def test_finite_stats_constant_raster(): + arr = np.full((4, 4), 7.0, dtype="float32") + vmin, vmax, vmean, vstd = _finite_stats(_continuous_da(arr), None) + assert vmin == vmax == 7.0 + assert vstd == 0.0 + + +# -------------------------------------------------------------------------- +# QML / PAM building blocks +# -------------------------------------------------------------------------- + +def test_build_qml_structure(): + stops = resolve_ramp("viridis") + xml = build_qml(stops, 10.0, 20.0) + rr = ET.fromstring(xml).find(".//rasterrenderer") + assert rr.get("type") == "singlebandpseudocolor" + assert float(rr.get("classificationMin")) == pytest.approx(10.0) + assert float(rr.get("classificationMax")) == pytest.approx(20.0) + items = rr.findall(".//colorrampshader/item") + assert len(items) == len(stops) + # First/last item span the data range and carry the ramp endpoints. 
+ assert float(items[0].get("value")) == pytest.approx(10.0) + assert float(items[-1].get("value")) == pytest.approx(20.0) + assert items[0].get("color") == "#%02x%02x%02x" % stops[0][1] + assert items[-1].get("color") == "#%02x%02x%02x" % stops[-1][1] + + +def test_qml_path_replaces_extension(): + assert qml_path("/a/b/dem.tif") == "/a/b/dem.qml" + assert qml_path("/a/b/dem.tiff") == "/a/b/dem.qml" + + +def test_resolve_ramp_true_is_viridis(): + assert resolve_ramp(True) == resolve_ramp("viridis") + assert resolve_ramp("VIRIDIS") == resolve_ramp("viridis") + + +def test_resolve_ramp_unknown_raises(): + with pytest.raises(ValueError, match="unknown color_ramp"): + resolve_ramp("not-a-ramp") + + +def test_build_stats_pam_xml_keys(): + stats = { + "STATISTICS_MINIMUM": 1.5, + "STATISTICS_MAXIMUM": 9.0, + "STATISTICS_MEAN": 5.25, + "STATISTICS_STDDEV": 2.0, + } + band = ET.fromstring(build_stats_pam_xml(stats)).find(".//PAMRasterBand") + assert band.get("band") == "1" + found = {mdi.get("key"): float(mdi.text) + for mdi in band.findall("./Metadata/MDI")} + assert found == pytest.approx(stats) + + +# -------------------------------------------------------------------------- +# gating +# -------------------------------------------------------------------------- + +def test_categorical_wins_no_qml(tmp_path): + """A categorical raster keeps the RAT sidecar; color_ramp is ignored.""" + cat = xr.DataArray( + np.array([[0, 1], [1, 0]], dtype="int32"), + dims=("y", "x"), + coords={"y": [1.0, 0.0], "x": [0.0, 1.0]}, + attrs={"crs": 4326, "category_names": ["a", "b"]}, + ) + path = str(tmp_path / "cat.tif") + to_geotiff(cat, path, color_ramp="viridis") + assert not os.path.exists(str(tmp_path / "cat.qml")) + assert "GDALRasterAttributeTable" in open(path + ".aux.xml").read() + + +def test_multiband_skipped(tmp_path): + rgb = xr.DataArray( + np.zeros((3, 4, 4), dtype="float32"), + dims=("band", "y", "x"), + coords={"band": [1, 2, 3], "y": np.arange(4.0), "x": np.arange(4.0)}, + 
attrs={"crs": 4326}, + ) + path = str(tmp_path / "rgb.tif") + to_geotiff(rgb, path, color_ramp="viridis") + assert not os.path.exists(str(tmp_path / "rgb.qml")) + assert not os.path.exists(path + ".aux.xml") + + +def test_filelike_destination_ignored(): + buf = io.BytesIO() + # No path to anchor sidecars; must not raise. + to_geotiff(_continuous_da(_BASE), buf, color_ramp="viridis") + assert buf.getbuffer().nbytes > 0 + + +def test_unknown_ramp_raises_before_write(tmp_path): + path = str(tmp_path / "bad.tif") + with pytest.raises(ValueError, match="unknown color_ramp"): + to_geotiff(_continuous_da(_BASE), path, color_ramp="nope") + assert not os.path.exists(path) + + +def test_constant_raster_writes_stats_skips_qml(tmp_path): + arr = np.full((8, 8), 42.0, dtype="float32") + path = str(tmp_path / "flat.tif") + to_geotiff(_continuous_da(arr), path, color_ramp="viridis") + assert os.path.exists(path + ".aux.xml") # stats still useful + assert not os.path.exists(str(tmp_path / "flat.qml")) # degenerate ramp + + +def test_pack_skips_symbology(tmp_path): + """``pack=True`` rewrites pixels, so symbology is skipped (would mismatch).""" + arr = np.linspace(0.0, 10.0, 8 * 8).reshape(8, 8).astype("float32") + da = xr.DataArray( + arr, dims=("y", "x"), + coords={"y": np.arange(8.0), "x": np.arange(8.0)}, + attrs={"crs": 4326, "scale_factor": 0.1, "add_offset": 0.0, + "mask_and_scale_dtype": "int16", "nodata": -1}, + ) + path = str(tmp_path / "packed.tif") + to_geotiff(da, path, pack=True, color_ramp="viridis") + assert not os.path.exists(str(tmp_path / "packed.qml")) + assert not os.path.exists(path + ".aux.xml") + + +def test_vrt_write_emits_sidecars(tmp_path): + """The .vrt write path also emits symbology via the shared closure.""" + import dask.array as dsa + + path = str(tmp_path / "mosaic.vrt") + to_geotiff(_continuous_da(dsa.from_array(_BASE, chunks=(8, 8))), + path, color_ramp="cividis") + assert os.path.exists(str(tmp_path / "mosaic.qml")) + assert os.path.exists(path 
+ ".aux.xml") + + +def test_color_ramp_range_sets_bounds(tmp_path): + path = str(tmp_path / "rng.tif") + to_geotiff(_continuous_da(_BASE), path, + color_ramp="viridis", color_ramp_range=(0.0, 50.0)) + rr = _qml_renderer(str(tmp_path / "rng.qml")) + assert float(rr.get("classificationMin")) == pytest.approx(0.0) + assert float(rr.get("classificationMax")) == pytest.approx(50.0) + # range escape hatch writes only min/max stats (no mean/stddev pass). + aux = open(path + ".aux.xml").read() + assert "STATISTICS_MINIMUM" in aux and "STATISTICS_MEAN" not in aux