Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion xrspatial/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@ def _available_memory_bytes():

@ngjit
def _cpu_binary(data, values):
out = np.empty(data.shape, dtype=data.dtype)
# Output float32 to match the cupy/dask+cupy backends (which always
# allocate 'f4') and the other classifiers, which route through _cpu_bin
# and likewise emit float32. Preserving the input dtype here made binary()
# the lone op whose result dtype diverged across backends (float64 on
# numpy/dask vs float32 on cupy) and could not hold the NaN sentinel for
# integer input.
out = np.empty(data.shape, dtype=np.float32)
out[:] = np.nan
rows, cols = data.shape
for y in range(0, rows):
Expand Down
26 changes: 22 additions & 4 deletions xrspatial/tests/test_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,23 @@ def test_binary_numpy(result_binary):
values, expected_result = result_binary
numpy_agg = input_data()
numpy_result = binary(numpy_agg, values)
general_output_checks(numpy_agg, numpy_result, expected_result)
general_output_checks(numpy_agg, numpy_result, expected_result, verify_dtype=True)


@dask_array_available
def test_binary_dask_numpy(result_binary):
values, expected_result = result_binary
dask_agg = input_data(backend='dask')
dask_result = binary(dask_agg, values)
general_output_checks(dask_agg, dask_result, expected_result)
general_output_checks(dask_agg, dask_result, expected_result, verify_dtype=True)


@cuda_and_cupy_available
def test_binary_cupy(result_binary):
values, expected_result = result_binary
cupy_agg = input_data(backend='cupy')
cupy_result = binary(cupy_agg, values)
general_output_checks(cupy_agg, cupy_result, expected_result)
general_output_checks(cupy_agg, cupy_result, expected_result, verify_dtype=True)


@dask_array_available
Expand All @@ -69,7 +69,25 @@ def test_binary_dask_cupy(result_binary):
values, expected_result = result_binary
dask_cupy_agg = input_data(backend='dask+cupy')
dask_cupy_result = binary(dask_cupy_agg, values)
general_output_checks(dask_cupy_agg, dask_cupy_result, expected_result)
general_output_checks(dask_cupy_agg, dask_cupy_result, expected_result, verify_dtype=True)


def test_binary_output_dtype_float32():
# binary() must emit float32 regardless of input dtype so its result
# dtype matches the cupy/dask+cupy backends and the other classifiers
# (regression for the numpy/dask paths returning the input dtype).
for in_dtype in (np.float64, np.float32, np.int32):
data = np.array([[1, 2, 3], [4, 5, 6]], dtype=in_dtype)
result = binary(xr.DataArray(data), [2, 5])
assert result.data.dtype == np.float32


@dask_array_available
def test_binary_dask_output_dtype_float32():
data = np.array([[1., 2., 3.], [4., 5., 6.]], dtype=np.float64)
dask_agg = xr.DataArray(da.from_array(data, chunks=(1, 3)))
result = binary(dask_agg, [2, 5])
assert result.data.compute().dtype == np.float32


@pytest.fixture
Expand Down
Loading