Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .claude/sweep-performance-state.csv

Large diffs are not rendered by default.

85 changes: 80 additions & 5 deletions xrspatial/geotiff/_symbology.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,76 @@ def _dask_finite_stats(arr, nodata):
utils._to_float_scalar(mean), utils._to_float_scalar(std))


def write_symbology_sidecars(path, data, *, stops, nodata=None, ramp_range=None):
class StreamingStats:
    """One-pass ``(min, max, mean, stddev)`` accumulator over streamed buffers.

    The streaming dask writer already materialises every pixel exactly once
    (issue #3597); feeding those buffers through :meth:`update` yields the
    same statistics as :func:`_finite_stats` without a second execution of
    the source graph. Matches ``_finite_stats`` semantics: non-finite values
    and the *nodata* sentinel (unless it is NaN) are excluded, and the
    stddev is the population stddev (``ddof=0``).

    Mean and variance combine across buffers with Chan's parallel formula,
    accumulated in float64, so the result is stable regardless of how the
    writer bands / segments the raster.

    Parameters
    ----------
    nodata : scalar, optional
        Sentinel value to exclude from the statistics. ``None`` (the
        default) or any NaN disables the sentinel comparison; NaN pixels
        are already dropped by the finite-value filter.
    """

    def __init__(self, nodata=None):
        # A NaN sentinel never needs an explicit comparison: float buffers
        # drop NaN through ``np.isfinite`` and integer buffers cannot hold
        # it. ``math.isnan`` accepts any real scalar, so NaNs that are not
        # Python ``float`` instances (e.g. ``np.float32('nan')``) are
        # normalised as well instead of costing a no-op mask per update.
        if nodata is not None:
            try:
                if math.isnan(nodata):
                    nodata = None
            except (TypeError, OverflowError):
                # Not convertible to a float (or an int too large for
                # one), hence not NaN: keep the sentinel as given.
                pass
        self._nodata = nodata
        self._count = 0
        self._mean = 0.0
        self._m2 = 0.0
        self._min = math.inf
        self._max = -math.inf

    def update(self, buf):
        """Fold one materialised buffer into the running statistics.

        Parameters
        ----------
        buf : array-like
            Values to accumulate. A numpy array is used as-is (no copy);
            any other array-like (list, scalar) is converted first.
            Buffers holding no valid value leave the state untouched.
        """
        # No-op for an ndarray; lets lists / scalars through instead of
        # failing on the ``.dtype`` access below.
        buf = np.asarray(buf)
        if buf.dtype.kind == 'f':
            mask = np.isfinite(buf)
            if self._nodata is not None:
                mask &= (buf != self._nodata)
            # ``buf[mask]`` copies; skip it for the common all-valid buffer.
            vals = buf.ravel() if bool(mask.all()) else buf[mask]
        elif self._nodata is not None:
            mask = buf != self._nodata
            vals = buf.ravel() if bool(mask.all()) else buf[mask]
        else:
            vals = buf.ravel()
        # Plain ``int`` so the running count (and the moment arithmetic
        # built on it) stays in unbounded Python ints, not numpy scalars.
        n_b = int(vals.size)
        if n_b == 0:
            return
        # Reductions with float64 accumulators on the original buffer:
        # an ``astype(float64)`` copy here would double the memory
        # traffic of every update for no precision gain.
        mean_b = float(vals.mean(dtype=np.float64))
        # ``var`` is the population variance (ddof=0); times the count it
        # is this buffer's sum of squared deviations (M2).
        m2_b = float(vals.var(dtype=np.float64)) * n_b
        self._min = min(self._min, float(vals.min()))
        self._max = max(self._max, float(vals.max()))
        if self._count == 0:
            self._count, self._mean, self._m2 = n_b, mean_b, m2_b
        else:
            # Chan's pairwise merge of (count, mean, M2) for two batches.
            total = self._count + n_b
            delta = mean_b - self._mean
            self._m2 += m2_b + delta * delta * self._count * n_b / total
            self._mean += delta * n_b / total
            self._count = total

    def result(self):
        """Return ``(min, max, mean, stddev)``, or ``None`` if nothing valid
        was accumulated -- the same contract as :func:`_finite_stats`."""
        if self._count == 0:
            return None
        # Clamp at zero: rounding in the merge can leave M2 marginally
        # negative, which ``math.sqrt`` would reject.
        std = math.sqrt(max(self._m2 / self._count, 0.0))
        return (self._min, self._max, self._mean, std)


def write_symbology_sidecars(path, data, *, stops, nodata=None,
ramp_range=None, stats=None):
"""Write continuous-raster symbology sidecars next to *path*.

Writes band statistics into the PAM ``.aux.xml`` and, when the data range
Expand All @@ -241,6 +310,11 @@ def write_symbology_sidecars(path, data, *, stops, nodata=None, ramp_range=None)
directly and skips the statistics reduction -- the escape hatch for large
dask graphs -- so only ``STATISTICS_MINIMUM`` / ``STATISTICS_MAXIMUM`` are
written (mean/stddev would need the pass it avoids).

*stats* is an optional :class:`StreamingStats` that already accumulated
the statistics during the write (the streaming dask path, issue #3597);
when given, its result replaces the :func:`_finite_stats` reduction so
the source graph is not executed a second time.
"""
from . import _pam

Expand All @@ -249,20 +323,21 @@ def write_symbology_sidecars(path, data, *, stops, nodata=None, ramp_range=None)

if ramp_range is not None:
vmin, vmax = float(ramp_range[0]), float(ramp_range[1])
stats = {'STATISTICS_MINIMUM': vmin, 'STATISTICS_MAXIMUM': vmax}
stats_dict = {'STATISTICS_MINIMUM': vmin, 'STATISTICS_MAXIMUM': vmax}
else:
result = _finite_stats(data, nodata)
result = (stats.result() if stats is not None
else _finite_stats(data, nodata))
if result is None:
return
vmin, vmax, vmean, vstd = result
stats = {
stats_dict = {
'STATISTICS_MINIMUM': vmin,
'STATISTICS_MAXIMUM': vmax,
'STATISTICS_MEAN': vmean,
'STATISTICS_STDDEV': vstd,
}

_pam.write_stats_pam_sidecar(path, stats)
_pam.write_stats_pam_sidecar(path, stats_dict)

# A constant raster (vmin == vmax) has no range to ramp across, so the
# stats stretch is the useful part and the QML would be a degenerate
Expand Down
18 changes: 17 additions & 1 deletion xrspatial/geotiff/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,8 @@ def _write_streaming(dask_data, path: str, *,
photometric='auto',
restore_sentinel: bool = True,
allow_internal_only_jpeg: bool = False,
allow_unparseable_crs: bool = False) -> None:
allow_unparseable_crs: bool = False,
chunk_observer=None) -> None:
"""Write a dask array as a GeoTIFF by streaming pixel data.

For tiled output, each tile-row is computed in horizontal segments
Expand Down Expand Up @@ -818,6 +819,15 @@ def _write_streaming(dask_data, path: str, *,
allow_unparseable_crs : bool
Opt in to writing an unparseable ``crs_wkt`` string into
``GTCitationGeoKey``. Default ``False``.
chunk_observer : callable, optional
Called once with every buffer the streaming write materialises
(a numpy array of logical values: after the dask compute and any
CuPy transfer, before the output-dtype cast and the
NaN-to-sentinel restore). The buffers partition the raster, so
an observer sees each pixel exactly once -- ``to_geotiff`` uses
this to accumulate the ``color_ramp`` symbology statistics
during the write instead of executing the source graph a second
time (issue #3597).

Notes
-----
Expand Down Expand Up @@ -1215,6 +1225,8 @@ def _write_streaming(dask_data, path: str, *,
if hasattr(band_np, 'get'):
band_np = band_np.get() # CuPy -> numpy
band_np = np.asarray(band_np)
if chunk_observer is not None:
chunk_observer(band_np)

if band_np.dtype != out_dtype:
band_np = band_np.astype(out_dtype)
Expand Down Expand Up @@ -1263,6 +1275,8 @@ def _write_streaming(dask_data, path: str, *,
if hasattr(seg_np, 'get'):
seg_np = seg_np.get() # CuPy -> numpy
seg_np = np.asarray(seg_np)
if chunk_observer is not None:
chunk_observer(seg_np)

if seg_np.dtype != out_dtype:
seg_np = seg_np.astype(out_dtype)
Expand Down Expand Up @@ -1381,6 +1395,8 @@ def _write_streaming(dask_data, path: str, *,
if hasattr(band_np, 'get'):
band_np = band_np.get() # CuPy -> numpy
band_np = np.asarray(band_np)
if chunk_observer is not None:
chunk_observer(band_np)

if band_np.dtype != out_dtype:
band_np = band_np.astype(out_dtype)
Expand Down
38 changes: 28 additions & 10 deletions xrspatial/geotiff/_writers/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,17 +414,20 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
``<file>.aux.xml``. No-op for a categorical raster (one with
``attrs['category_names']`` -- those get the RAT sidecar instead), a
multiband array, a file-like destination, or data with no finite
values. Computing the statistics is a separate reduction pass over the
data; for a dask source that means reading the graph once more (see
``color_ramp_range`` to skip it). Ignored when ``pack=True``, whose
on-disk packed values would not match a ramp built from the logical
values.
values. Computing the statistics is an extra reduction pass over the
data. The streaming dask write accumulates them from the buffers it
materialises anyway, so the source graph still runs once; the GPU
(``gpu=True``) and VRT (``.vrt``) write paths execute a dask source
a second time for the statistics (see ``color_ramp_range`` to skip
that). Ignored when ``pack=True``, whose on-disk packed values would
not match a ramp built from the logical values.
color_ramp_range : tuple of (float, float) or None, default None
[advanced] Explicit ``(min, max)`` for the ``color_ramp`` stretch.
Skips the statistics reduction -- useful for large dask graphs -- so
only ``STATISTICS_MINIMUM`` / ``STATISTICS_MAXIMUM`` are written
(mean/stddev need the pass it avoids). Ignored when ``color_ramp`` is
not set.
Skips the statistics reduction -- useful for a dask source on the
GPU or VRT write paths, which would otherwise read the graph once
more -- so only ``STATISTICS_MINIMUM`` / ``STATISTICS_MAXIMUM`` are
written (mean/stddev need the pass it avoids). Ignored when
``color_ramp`` is not set.

Returns
-------
Expand Down Expand Up @@ -491,6 +494,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
_sym_data = None
_sym_stops = None
_sym_nodata = None
_sym_stream_stats = None
if isinstance(path, str) and isinstance(data, xr.DataArray):
_cat_names = data.attrs.get('category_names')
_cat_colors = data.attrs.get('category_colors')
Expand All @@ -515,7 +519,8 @@ def _write_sidecars():
from .._symbology import write_symbology_sidecars
write_symbology_sidecars(
path, _sym_data, stops=_sym_stops,
nodata=_sym_nodata, ramp_range=color_ramp_range)
nodata=_sym_nodata, ramp_range=color_ramp_range,
stats=_sym_stream_stats)

# Reject bool / np.bool_ nodata up front. ``bool`` is a subclass of
# ``int`` in Python, so a typo like ``nodata=True`` slips past every
Expand Down Expand Up @@ -1097,6 +1102,18 @@ def _write_sidecars():
# branch needs the guard.
if epsg is None:
_validate_crs_fallback(wkt_fallback, allow_unparseable_crs)
# ``color_ramp`` statistics: rather than re-executing the
# source graph after the write (a second full read of the
# data, issue #3597), accumulate them from the buffers the
# streaming write materialises anyway. ``color_ramp_range``
# already skips statistics, and multiband data never gets
# symbology, so neither pays the accumulation cost.
_stream_chunk_observer = None
if _sym_stops is not None and color_ramp_range is None:
from .._symbology import StreamingStats, _is_single_band
if _is_single_band(data):
_sym_stream_stats = StreamingStats(nodata=_sym_nodata)
_stream_chunk_observer = _sym_stream_stats.update
write_streaming(
dask_arr, path,
geo_transform=geo_transform,
Expand Down Expand Up @@ -1125,6 +1142,7 @@ def _write_sidecars():
# rather than rejecting input the wrapper accepted.
allow_internal_only_jpeg=allow_internal_only_jpeg,
allow_unparseable_crs=allow_unparseable_crs,
chunk_observer=_stream_chunk_observer,
)
_write_sidecars()
return path
Expand Down
Loading
Loading