Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions xrspatial/tests/test_cost_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,3 +861,127 @@ def test_cupy_memory_guard_passes_for_small_raster():
out = _compute(result)
assert out[0, 0] == 0.0
np.testing.assert_allclose(out[0, 1], 1.0, atol=1e-5)


# -----------------------------------------------------------------------
# Single-pixel (1x1) raster — degenerate no-neighbour case (Issue #3341)
# -----------------------------------------------------------------------

@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy', 'dask+cupy'])
def test_single_pixel_source(backend):
"""A 1x1 source raster has no neighbours; the source cost is 0."""
source = np.array([[1.0]])
friction_data = np.array([[1.0]])

raster = _make_raster(source, backend=backend, chunks=(1, 1))
friction = _make_raster(friction_data, backend=backend, chunks=(1, 1))

result = cost_distance(raster, friction)
out = _compute(result)

assert out.shape == (1, 1)
assert out[0, 0] == 0.0


@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy', 'dask+cupy'])
def test_single_pixel_no_source(backend):
"""A 1x1 raster with no source pixel is all-NaN (nothing reachable)."""
source = np.array([[0.0]])
friction_data = np.array([[1.0]])

raster = _make_raster(source, backend=backend, chunks=(1, 1))
friction = _make_raster(friction_data, backend=backend, chunks=(1, 1))

result = cost_distance(raster, friction)
out = _compute(result)

assert out.shape == (1, 1)
assert np.isnan(out[0, 0])


# -----------------------------------------------------------------------
# Inf friction is impassable, like NaN (Issue #3341)
# -----------------------------------------------------------------------

@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy', 'dask+cupy'])
def test_inf_friction_is_impassable(backend):
"""An Inf-friction cell blocks paths just like NaN/zero friction."""
source = np.zeros((3, 3))
source[0, 0] = 1.0

friction_data = np.ones((3, 3))
friction_data[1, 1] = np.inf # impassable barrier at the centre

raster = _make_raster(source, backend=backend, chunks=(3, 3))
friction = _make_raster(friction_data, backend=backend, chunks=(3, 3))

result = cost_distance(raster, friction)
out = _compute(result)

# The Inf cell itself is unreachable (cannot be traversed onto).
assert np.isnan(out[1, 1])
# Cells around the barrier are still reachable by routing around it.
assert np.isfinite(out[0, 1])
assert np.isfinite(out[1, 0])
assert np.isfinite(out[2, 2])


# -----------------------------------------------------------------------
# Metadata preservation: attrs / coords / dims survive the call (Issue #3341)
# -----------------------------------------------------------------------

def _make_meta_raster(data, backend='numpy', chunks=(3, 3)):
"""Like _make_raster but with non-default attrs and named coords."""
h, w = data.shape
raster = xr.DataArray(
data.astype(np.float64),
dims=['y', 'x'],
attrs={'res': (2.0, 3.0), 'crs': 'EPSG:5070', 'units': 'm'},
)
raster['y'] = np.linspace((h - 1) * 2.0, 0, h)
raster['x'] = np.linspace(0, (w - 1) * 3.0, w)
if 'dask' in backend and da is not None:
raster.data = da.from_array(raster.data, chunks=chunks)
if 'cupy' in backend and has_cuda_and_cupy():
import cupy
if isinstance(raster.data, da.Array):
raster.data = raster.data.map_blocks(cupy.asarray)
else:
raster.data = cupy.asarray(raster.data)
return raster


@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy', 'dask+cupy'])
def test_metadata_preserved(backend):
"""Output must carry the input's attrs, coords, and dim names."""
source = np.zeros((6, 6))
source[0, 0] = 1.0
friction_data = np.ones((6, 6))

raster = _make_meta_raster(source, backend=backend, chunks=(3, 3))
friction = _make_meta_raster(friction_data, backend=backend, chunks=(3, 3))

result = cost_distance(raster, friction, max_cost=20.0)

assert result.dims == raster.dims
assert result.attrs == raster.attrs
np.testing.assert_array_equal(result['x'].data, raster['x'].data)
np.testing.assert_array_equal(result['y'].data, raster['y'].data)


def test_custom_dim_names_preserved():
"""lat/lon dim names must not be silently renamed to y/x."""
data = np.zeros((4, 4))
data[0, 0] = 1.0
raster = xr.DataArray(data, dims=['lat', 'lon'], attrs={'res': (1.0, 1.0)})
raster['lat'] = np.arange(4, dtype=np.float64)
raster['lon'] = np.arange(4, dtype=np.float64)
friction = xr.DataArray(np.ones((4, 4)), dims=['lat', 'lon'],
attrs={'res': (1.0, 1.0)})
friction['lat'] = np.arange(4, dtype=np.float64)
friction['lon'] = np.arange(4, dtype=np.float64)

result = cost_distance(raster, friction, x='lon', y='lat')

assert result.dims == ('lat', 'lon')
assert 'lat' in result.coords and 'lon' in result.coords
Loading