diff --git a/xrspatial/corridor.py b/xrspatial/corridor.py index 8d34ef517..1847f8be2 100644 --- a/xrspatial/corridor.py +++ b/xrspatial/corridor.py @@ -25,7 +25,7 @@ import xarray as xr from xrspatial.cost_distance import cost_distance -from xrspatial.utils import _validate_raster +from xrspatial.utils import _validate_matching_shape, _validate_raster def _scalar_to_float(da_scalar): @@ -207,6 +207,21 @@ def least_cost_corridor( # ------------------------------------------------------------------ if precomputed: cd_surfaces = list(sources) + # cost_distance() guarantees matching shapes when it runs, but the + # precomputed path bypasses it. Without this check, surfaces of + # differing shape get silently aligned on the intersection of their + # coordinates by xarray, producing a truncated, wrong-valued corridor. + # Like cost_distance, this guards shape only; mismatched coordinates + # on same-shape surfaces are out of scope. + expected_shape = cd_surfaces[0].shape + for i, cd in enumerate(cd_surfaces[1:], start=1): + _validate_matching_shape( + cd, + expected_shape, + func_name="least_cost_corridor", + name=f"precomputed surface {i}", + expected_name="precomputed surface 0", + ) # No friction surface to draw grid metadata from -- keep the # existing source-derived attrs/name behaviour. geo_source = None diff --git a/xrspatial/tests/test_corridor.py b/xrspatial/tests/test_corridor.py index 23eb2a326..c39fd0e14 100644 --- a/xrspatial/tests/test_corridor.py +++ b/xrspatial/tests/test_corridor.py @@ -376,6 +376,43 @@ def test_single_source_in_list_raises(): least_cost_corridor(friction, sources=[src]) +def test_precomputed_mismatched_shape_raises(): + """Precomputed surfaces of differing shape raise instead of aligning. + + Without the shape check, xarray silently aligns the two surfaces on the + intersection of their coordinates and returns a truncated corridor with + wrong values (e.g. 4x4 + 3x3 -> an all-zero 3x3 result). + """ + friction = _make_raster(np.ones((4, 4))) + cd_a = _make_raster(np.ones((4, 4))) + cd_b = _make_raster(np.ones((3, 3))) + with pytest.raises(ValueError, match="does not match"): + least_cost_corridor(friction, cd_a, cd_b, precomputed=True) + + +def test_precomputed_mismatched_shape_pairwise_raises(): + """Pairwise precomputed surfaces of differing shape raise.""" + friction = _make_raster(np.ones((4, 4))) + sources = [ + _make_raster(np.ones((4, 4))), + _make_raster(np.ones((3, 3))), + ] + with pytest.raises(ValueError, match="does not match"): + least_cost_corridor( + friction, sources=sources, precomputed=True, pairwise=True + ) + + +@pytest.mark.skipif(da is None, reason="dask not installed") +def test_precomputed_mismatched_shape_dask_raises(): + """The shape check fires on dask surfaces without triggering a compute.""" + friction = _make_raster(np.ones((4, 4)), backend="dask+numpy", chunks=(4, 4)) + cd_a = _make_raster(np.ones((4, 4)), backend="dask+numpy", chunks=(4, 4)) + cd_b = _make_raster(np.ones((3, 3)), backend="dask+numpy", chunks=(3, 3)) + with pytest.raises(ValueError, match="does not match"): + least_cost_corridor(friction, cd_a, cd_b, precomputed=True) + + # ----------------------------------------------------------------------- # Metadata propagation (issue #3446) # -----------------------------------------------------------------------