Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions xrspatial/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@
# there -- a large chunked study area is a legitimate workflow.
_MAX_CELLS = 500_000_000

# A lazy dask grid never materializes, but its task graph does: with a fixed
# chunk size the block count grows with the cell count, and a typo-level
# resolution can build a graph large enough to bog down the client during
# construction. Guard on the estimated block count (not cells) so a legitimate
# large grid with sensible chunks still passes -- the new_england@10m repro is
# ~26k blocks, so 1e6 leaves wide headroom while catching the runaway case.
_MAX_CHUNKS = 1_000_000


def _resolve(name):
"""Resolve a name to a spec dict.
Expand Down Expand Up @@ -180,6 +188,23 @@ def _make_data(shape, fill, backend, chunks):
)


def _estimate_n_chunks(shape, chunks):
"""Number of blocks ``da.full`` would build for ``shape`` and ``chunks``.

Uses dask's own ``normalize_chunks`` so every ``chunks`` form -- ``'auto'``
(where dask picks the block shape from the chunk-size config), an int, a
tuple, or a dict -- resolves to the same block grid the real array would
have. dtype is ``float32`` to match the grid ``_make_data`` builds. The call
only sizes the block tuples, so it stays cheap even when the count is huge.
"""
from dask.array.core import normalize_chunks
normalized = normalize_chunks(chunks, shape=shape, dtype="float32")
n = 1
for dim in normalized:
n *= len(dim)
return n


def _cf_crs_attrs(crs):
"""CF Conventions grid-mapping attributes for an EPSG code.

Expand Down Expand Up @@ -316,8 +341,9 @@ def from_template(name: str,
longer applies. When omitted, the dask backends use ``'auto'``. The
data stays lazy, but a very fine resolution still builds one task per
chunk, so an extreme shape with small chunks can make a task graph
large enough to bog down the client; coarsen the resolution or use
larger chunks if that happens.
large enough to bog down the client. To prevent that, a grid that would
split into more than 1,000,000 chunks raises ``ValueError``; coarsen the
resolution or use larger chunks.

Returns
-------
Expand Down Expand Up @@ -410,6 +436,19 @@ def from_template(name: str,
f"Use a coarser resolution, or pass chunks=... for a lazy dask grid."
)

effective_chunks = "auto" if chunks is None else chunks
if is_dask:
n_chunks = _estimate_n_chunks((height, width), effective_chunks)
if n_chunks > _MAX_CHUNKS:
raise ValueError(
f"resolution {(res_x, res_y)} produces a {height} x {width} "
f"grid that splits into {n_chunks:,} chunks with "
f"chunks={effective_chunks!r}, exceeding the "
f"{_MAX_CHUNKS:,}-chunk limit. A task graph this large can bog "
f"down the client even though no data is computed. Use a "
f"coarser resolution or larger chunks."
)

# Honor the requested resolution exactly: anchor the lower-left corner and
# nudge the far edges to an exact multiple of the cell size, so res comes
# back as (res_x, res_y) instead of drifting when the bbox extent isn't a
Expand All @@ -419,8 +458,7 @@ def from_template(name: str,
top = bottom + height * res_y
ys, xs = _make_output_coords((left, bottom, right, top), (height, width))

data = _make_data((height, width), fill, backend,
"auto" if chunks is None else chunks)
data = _make_data((height, width), fill, backend, effective_chunks)

attrs = {"res": (res_x, res_y), "crs": crs}
attrs.update(_cf_crs_attrs(crs))
Expand Down
53 changes: 53 additions & 0 deletions xrspatial/tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,59 @@ def test_chunks_promotes_cupy_to_dask_cupy():
assert isinstance(block, cupy.ndarray)


@dask_array_available
def test_over_fine_dask_chunk_count_raises():
# The dask path skips the cell cap, but a typo-level resolution with a fixed
# chunk size builds a runaway task graph. This is the issue #3557 repro:
# conus at 1 m / chunks=512 is ~7e13 cells / 512^2 ~= 7e7 chunks. The guard
# must raise from the estimate, BEFORE da.full builds the graph. Match the
# chunk-count cap text specifically so this can't pass on the eager
# cell-cap message (which also mentions "chunks").
with pytest.raises(ValueError, match="chunk limit"):
from_template("conus", resolution=1, chunks=512)


@dask_array_available
def test_explicit_dask_backend_chunk_count_raises():
# Same guard via the non-promotion path: an explicit dask backend with a
# fixed small chunk size on a typo-fine resolution must raise too.
with pytest.raises(ValueError, match="chunk limit"):
from_template("conus", resolution=1, backend="dask+numpy", chunks=512)


@dask_array_available
def test_auto_chunks_exempt_from_chunk_cap():
import dask.array as da
# 'auto' sizes blocks to the dask chunk-size config (~128 MB), so even a very
# fine resolution stays well under the chunk cap and builds fine. The guard
# keys on the real block count, so the auto path is not falsely tripped.
agg = from_template("conus", resolution=1, chunks="auto")
assert isinstance(agg.data, da.Array)
from xrspatial.templates import _MAX_CHUNKS
assert agg.data.npartitions <= _MAX_CHUNKS


@dask_array_available
def test_chunk_count_estimate_matches_dask():
import dask.array as da
from xrspatial.templates import _estimate_n_chunks
# The estimate must agree with the block count dask actually builds, across
# chunk forms, so the guard fires on the real graph size.
for chunks in (256, 512, "auto", (300, 400)):
built = da.full((4000, 5000), np.nan, dtype="float32", chunks=chunks)
assert _estimate_n_chunks((4000, 5000), chunks) == built.npartitions


@dask_array_available
def test_legit_large_dask_grid_passes():
import dask.array as da
# The headroom case: new_england @ 10 m / chunks=512 is past the eager cell
# cap but only ~26k chunks, far below the 1e6 chunk cap, so it must build.
agg = from_template("new_england", resolution=10, chunks=512)
assert isinstance(agg.data, da.Array)
assert agg.data.npartitions < 1_000_000


def test_single_pixel_grid():
# a resolution coarser than the whole study-area box clamps width and height
# to the max(1, ...) floor, giving a 1x1 grid that still obeys the contract.
Expand Down
Loading