Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions xrspatial/perlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# std lib
import math
from functools import partial

# 3rd-party
import numpy as np
Expand Down Expand Up @@ -112,6 +111,7 @@ def _perlin_dask_numpy(data: da.Array,
freq: tuple,
seed: int) -> da.Array:
p = _make_perm_table(seed)
out_dtype = data.dtype

height, width = data.shape
linx = da.linspace(0, freq[0], width, endpoint=False, dtype=np.float32,
Expand All @@ -120,8 +120,10 @@ def _perlin_dask_numpy(data: da.Array,
chunks=data.chunks[0][0])
x, y = da.meshgrid(linx, liny)

_func = partial(_perlin, p)
data = da.map_blocks(_func, x, y, meta=np.array((), dtype=np.float32),
def _func(x_blk, y_blk):
return _perlin(p, x_blk, y_blk).astype(out_dtype)

data = da.map_blocks(_func, x, y, meta=np.array((), dtype=out_dtype),
**_dask_task_name_kwargs('xrspatial.perlin'))

# min and ptp go out in one dask.compute call, which shares the noise
Expand Down Expand Up @@ -259,6 +261,7 @@ def _perlin_dask_cupy(data: da.Array,
freq: tuple,
seed: int) -> da.Array:
p = cupy.asarray(_make_perm_table(seed))
out_dtype = data.dtype

height, width = data.shape

Expand All @@ -271,13 +274,13 @@ def _chunk_perlin(block, block_info=None):
y0 = freq[1] * y_start / height
y1 = freq[1] * y_end / height

out = cupy.empty(block.shape, dtype=cupy.float32)
out = cupy.empty(block.shape, dtype=out_dtype)
griddim, blockdim = cuda_args(block.shape)
_perlin_gpu[griddim, blockdim](p, x0, x1, y0, y1, 1.0, out)
return out

data = da.map_blocks(_chunk_perlin, data, dtype=cupy.float32,
meta=cupy.array((), dtype=cupy.float32),
data = da.map_blocks(_chunk_perlin, data, dtype=out_dtype,
meta=cupy.array((), dtype=out_dtype),
**_dask_task_name_kwargs('xrspatial.perlin'))

# min and max go out in one dask.compute call, which shares the noise
Expand Down
37 changes: 37 additions & 0 deletions xrspatial/tests/test_perlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,43 @@ def test_perlin_float64_input():
assert result.data.max() <= 1.0


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_perlin_dask_cpu_preserves_dtype(dtype):
# Regression: the dask backends used to hardcode float32, silently
# downcasting float64 input while numpy/cupy preserved it.
import dask.array as da
data = da.from_array(np.zeros((20, 20), dtype=dtype), chunks=(10, 10))
raster = xr.DataArray(data, dims=['y', 'x'])
result = perlin(raster)
assert result.dtype == dtype
computed = result.data.compute()
assert np.isfinite(computed).all()
assert computed.min() >= 0.0
assert computed.max() <= 1.0


@cuda_and_cupy_available
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_perlin_gpu_preserves_dtype(dtype):
import cupy
data = cupy.zeros((20, 20), dtype=dtype)
raster = xr.DataArray(data, dims=['y', 'x'])
result = perlin(raster)
assert result.dtype == dtype


@cuda_and_cupy_available
@dask_array_available
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_perlin_dask_gpu_preserves_dtype(dtype):
import cupy
import dask.array as da
data = da.from_array(cupy.zeros((20, 20), dtype=dtype), chunks=(10, 10))
raster = xr.DataArray(data, dims=['y', 'x'])
result = perlin(raster)
assert result.dtype == dtype


# ---------------------------------------------------------------------------
# Parameter coverage (freq / seed / name)
# ---------------------------------------------------------------------------
Expand Down
Loading