diff --git a/xrspatial/perlin.py b/xrspatial/perlin.py index 5c82a2525..eb4e81896 100644 --- a/xrspatial/perlin.py +++ b/xrspatial/perlin.py @@ -2,7 +2,6 @@ # std lib import math -from functools import partial # 3rd-party import numpy as np @@ -112,6 +111,7 @@ def _perlin_dask_numpy(data: da.Array, freq: tuple, seed: int) -> da.Array: p = _make_perm_table(seed) + out_dtype = data.dtype height, width = data.shape linx = da.linspace(0, freq[0], width, endpoint=False, dtype=np.float32, @@ -120,8 +120,10 @@ def _perlin_dask_numpy(data: da.Array, chunks=data.chunks[0][0]) x, y = da.meshgrid(linx, liny) - _func = partial(_perlin, p) - data = da.map_blocks(_func, x, y, meta=np.array((), dtype=np.float32), + def _func(x_blk, y_blk): + return _perlin(p, x_blk, y_blk).astype(out_dtype) + + data = da.map_blocks(_func, x, y, meta=np.array((), dtype=out_dtype), **_dask_task_name_kwargs('xrspatial.perlin')) # min and ptp go out in one dask.compute call, which shares the noise @@ -259,6 +261,7 @@ def _perlin_dask_cupy(data: da.Array, freq: tuple, seed: int) -> da.Array: p = cupy.asarray(_make_perm_table(seed)) + out_dtype = data.dtype height, width = data.shape @@ -271,13 +274,13 @@ def _chunk_perlin(block, block_info=None): y0 = freq[1] * y_start / height y1 = freq[1] * y_end / height - out = cupy.empty(block.shape, dtype=cupy.float32) + out = cupy.empty(block.shape, dtype=out_dtype) griddim, blockdim = cuda_args(block.shape) _perlin_gpu[griddim, blockdim](p, x0, x1, y0, y1, 1.0, out) return out - data = da.map_blocks(_chunk_perlin, data, dtype=cupy.float32, - meta=cupy.array((), dtype=cupy.float32), + data = da.map_blocks(_chunk_perlin, data, dtype=out_dtype, + meta=cupy.array((), dtype=out_dtype), **_dask_task_name_kwargs('xrspatial.perlin')) # min and max go out in one dask.compute call, which shares the noise diff --git a/xrspatial/tests/test_perlin.py b/xrspatial/tests/test_perlin.py index 547f4c645..a96a8983f 100644 --- a/xrspatial/tests/test_perlin.py +++ 
b/xrspatial/tests/test_perlin.py
@@ -163,6 +163,53 @@ def test_perlin_float64_input():
     assert result.data.max() <= 1.0
 
 
+@dask_array_available
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_perlin_dask_cpu_preserves_dtype(dtype):
+    """Dask+NumPy output keeps the input dtype, both declared and computed."""
+    # Regression: the dask backends used to hardcode float32, silently
+    # downcasting float64 input while numpy/cupy preserved it.
+    import dask.array as da
+    data = da.from_array(np.zeros((20, 20), dtype=dtype), chunks=(10, 10))
+    raster = xr.DataArray(data, dims=['y', 'x'])
+    result = perlin(raster)
+    # The lazy dtype only reflects the dask meta ...
+    assert result.dtype == dtype
+    computed = result.data.compute()
+    # ... so also check the dtype of the blocks that are actually produced.
+    assert computed.dtype == dtype
+    assert np.isfinite(computed).all()
+    assert computed.min() >= 0.0
+    assert computed.max() <= 1.0
+
+
+@cuda_and_cupy_available
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_perlin_gpu_preserves_dtype(dtype):
+    """CuPy output keeps the input dtype."""
+    import cupy
+    data = cupy.zeros((20, 20), dtype=dtype)
+    raster = xr.DataArray(data, dims=['y', 'x'])
+    result = perlin(raster)
+    assert result.dtype == dtype
+
+
+@cuda_and_cupy_available
+@dask_array_available
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_perlin_dask_gpu_preserves_dtype(dtype):
+    """Dask+CuPy output keeps the input dtype, both declared and computed."""
+    import cupy
+    import dask.array as da
+    data = da.from_array(cupy.zeros((20, 20), dtype=dtype), chunks=(10, 10))
+    raster = xr.DataArray(data, dims=['y', 'x'])
+    result = perlin(raster)
+    # The lazy dtype only reflects the dask meta; compute as well so a
+    # mismatch between declared and produced chunk dtype is caught.
+    assert result.dtype == dtype
+    assert result.data.compute().dtype == dtype
+
+
 # ---------------------------------------------------------------------------
 # Parameter coverage (freq / seed / name)
 # ---------------------------------------------------------------------------