Skip to content

Commit 2db65f8

Browse files
committed
binary(): type the dask+cupy meta so the lazy dtype matches float32
_run_dask_cupy_binary passed meta=cupy.array(()), which defaults to float64, so the lazy dask array advertised float64 while the chunk function _run_cupy_binary computes float32. The eager dtype fix in the previous commit did not cover the advertised dask+cupy dtype, and the existing tests only check the computed dtype (general_output_checks computes before asserting), so the mismatch went unnoticed. Pass meta=cupy.array((), dtype='f4') and assert the lazy dtype in test_binary_dask_cupy. Same advertised-vs-computed class as aspect #2682 and focal #3217.
1 parent 9de6889 commit 2db65f8

2 files changed

Lines changed: 5 additions & 1 deletion

File tree

xrspatial/classify.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def _run_cupy_binary(data, values):
9898

9999

100100
def _run_dask_cupy_binary(data, values_cupy):
101-
out = data.map_blocks(lambda da: _run_cupy_binary(da, values_cupy), meta=cupy.array(()),
101+
out = data.map_blocks(lambda da: _run_cupy_binary(da, values_cupy),
102+
meta=cupy.array((), dtype='f4'),
102103
**_dask_task_name_kwargs('xrspatial.binary'))
103104
return out
104105

xrspatial/tests/test_classify.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def test_binary_dask_cupy(result_binary):
6969
values, expected_result = result_binary
7070
dask_cupy_agg = input_data(backend='dask+cupy')
7171
dask_cupy_result = binary(dask_cupy_agg, values)
72+
# the lazy dask array must advertise the same dtype it computes, otherwise
73+
# a downstream consumer reads float64 metadata for a float32 result
74+
assert dask_cupy_result.data.dtype == np.float32
7275
general_output_checks(dask_cupy_agg, dask_cupy_result, expected_result, verify_dtype=True)
7376

7477

0 commit comments

Comments
 (0)