Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion xrspatial/corridor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import xarray as xr

from xrspatial.cost_distance import cost_distance
from xrspatial.utils import _validate_raster
from xrspatial.utils import _validate_matching_shape, _validate_raster


def _scalar_to_float(da_scalar):
Expand Down Expand Up @@ -207,6 +207,21 @@ def least_cost_corridor(
# ------------------------------------------------------------------
if precomputed:
cd_surfaces = list(sources)
# cost_distance() guarantees matching shapes when it runs, but the
# precomputed path bypasses it. Without this check, surfaces of
# differing shape get silently aligned on the intersection of their
# coordinates by xarray, producing a truncated, wrong-valued corridor.
# Like cost_distance, this guards shape only; mismatched coordinates
# on same-shape surfaces are out of scope.
expected_shape = cd_surfaces[0].shape
for i, cd in enumerate(cd_surfaces[1:], start=1):
_validate_matching_shape(
cd,
expected_shape,
func_name="least_cost_corridor",
name=f"precomputed surface {i}",
expected_name="precomputed surface 0",
)
# No friction surface to draw grid metadata from -- keep the
# existing source-derived attrs/name behaviour.
geo_source = None
Expand Down
37 changes: 37 additions & 0 deletions xrspatial/tests/test_corridor.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,43 @@ def test_single_source_in_list_raises():
least_cost_corridor(friction, sources=[src])


def test_precomputed_mismatched_shape_raises():
"""Precomputed surfaces of differing shape raise instead of aligning.

Without the shape check, xarray silently aligns the two surfaces on the
intersection of their coordinates and returns a truncated corridor with
wrong values (e.g. 4x4 + 3x3 -> an all-zero 3x3 result).
"""
friction = _make_raster(np.ones((4, 4)))
cd_a = _make_raster(np.ones((4, 4)))
cd_b = _make_raster(np.ones((3, 3)))
with pytest.raises(ValueError, match="does not match"):
least_cost_corridor(friction, cd_a, cd_b, precomputed=True)


def test_precomputed_mismatched_shape_pairwise_raises():
"""Pairwise precomputed surfaces of differing shape raise."""
friction = _make_raster(np.ones((4, 4)))
sources = [
_make_raster(np.ones((4, 4))),
_make_raster(np.ones((3, 3))),
]
with pytest.raises(ValueError, match="does not match"):
least_cost_corridor(
friction, sources=sources, precomputed=True, pairwise=True
)


@pytest.mark.skipif(da is None, reason="dask not installed")
def test_precomputed_mismatched_shape_dask_raises():
"""The shape check fires on dask surfaces without triggering a compute."""
friction = _make_raster(np.ones((4, 4)), backend="dask+numpy", chunks=(4, 4))
cd_a = _make_raster(np.ones((4, 4)), backend="dask+numpy", chunks=(4, 4))
cd_b = _make_raster(np.ones((3, 3)), backend="dask+numpy", chunks=(3, 3))
with pytest.raises(ValueError, match="does not match"):
least_cost_corridor(friction, cd_a, cd_b, precomputed=True)


# -----------------------------------------------------------------------
# Metadata propagation (issue #3446)
# -----------------------------------------------------------------------
Expand Down
Loading