Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def apply_where( # numpydoc ignore=GL08
f2: Callable[..., Array],
/,
*,
kwargs: dict[str, Array] | None = None,
xp: ModuleType | None = None,
) -> Array: ...

Expand All @@ -53,6 +54,7 @@ def apply_where( # numpydoc ignore=GL08
/,
*,
fill_value: Array | complex,
kwargs: dict[str, Array] | None = None,
xp: ModuleType | None = None,
) -> Array: ...

Expand All @@ -65,6 +67,7 @@ def apply_where( # numpydoc ignore=PR01,PR02
/,
*,
fill_value: Array | complex | None = None,
kwargs: dict[str, Array] | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Expand All @@ -91,6 +94,9 @@ def apply_where( # numpydoc ignore=PR01,PR02
It does not need to be scalar; it needs however to be broadcastable with
`cond` and `args`.
Mutually exclusive with `f2`. You must provide one or the other.
kwargs : dict of str : Array pairs
Keyword argument(s) to `f1` (and `f2`). Values must be broadcastable with
`cond`.
xp : array_namespace, optional
The standard-compatible namespace for `cond` and `args`. Default: infer.

Expand Down Expand Up @@ -129,6 +135,11 @@ def apply_where( # numpydoc ignore=PR01,PR02
args_ = list(args) if isinstance(args, tuple) else [args]
del args

kwargs_ = {} if kwargs is None else kwargs
kwkeys = list(kwargs_.keys())
args_ = [*args_, *kwargs_.values()]
del kwargs

xp = array_namespace(cond, fill_value, *args_) if xp is None else xp

if isinstance(fill_value, int | float | complex | NoneType):
Expand All @@ -139,8 +150,11 @@ def apply_where( # numpydoc ignore=PR01,PR02
if is_dask_namespace(xp):
meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp)
# map_blocks doesn't descend into tuples of Arrays
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp)
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
return xp.map_blocks(
_apply_where, cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=meta_xp
)

return _apply_where(cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=xp)


def _apply_where( # numpydoc ignore=PR01,RT01
Expand All @@ -149,15 +163,26 @@ def _apply_where( # numpydoc ignore=PR01,RT01
f2: Callable[..., Array] | None,
fill_value: Array | int | float | complex | bool | None,
*args: Array,
kwkeys: list[str],
xp: ModuleType,
) -> Array:
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""

nargs = len(args) - len(kwkeys)
kwargs = dict(zip(kwkeys, args[nargs:], strict=True))
args = args[:nargs]

if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
# jax.jit does not support assignment by boolean mask
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
return xp.where(
cond,
f1(*args, **kwargs),
f2(*args, **kwargs) if f2 is not None else fill_value,
)

temp1 = f1(*(arr[cond] for arr in args))
temp1 = f1(
*(arr[cond] for arr in args), **{key: val[cond] for key, val in kwargs.items()}
)

if f2 is None:
dtype = xp.result_type(temp1, fill_value)
Expand All @@ -167,7 +192,10 @@ def _apply_where( # numpydoc ignore=PR01,RT01
out = xp.astype(fill_value, dtype, copy=True)
else:
ncond = ~cond
temp2 = f2(*(arr[ncond] for arr in args))
temp2 = f2(
*(arr[ncond] for arr in args),
**{key: val[ncond] for key, val in kwargs.items()},
)
dtype = xp.result_type(temp1, temp2)
out = xp.empty_like(cond, dtype=dtype)
out = at(out, ncond).set(temp2)
Expand Down
42 changes: 30 additions & 12 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def test_device(self, xp: ModuleType, device: Device):
)
@given(
n_arrays=st.integers(min_value=1, max_value=3),
n_kwarrays=st.integers(min_value=1, max_value=3),
rng_seed=st.integers(min_value=1000000000, max_value=9999999999),
dtype=npst.floating_dtypes(sizes=(32, 64)),
p=st.floats(min_value=0, max_value=1),
Expand All @@ -219,6 +220,7 @@ def test_device(self, xp: ModuleType, device: Device):
def test_hypothesis(
self,
n_arrays: int,
n_kwarrays: int,
rng_seed: int,
dtype: np.dtype[Any],
p: float,
Expand All @@ -233,9 +235,13 @@ def test_hypothesis(
):
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")

mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays + 1, min_side=0)
mbs = npst.mutually_broadcastable_shapes(
num_shapes=1 + n_arrays + n_kwarrays, min_side=0
)
input_shapes, _ = data.draw(mbs)
cond_shape, *shapes = input_shapes
cond_shape = input_shapes[0]
shapes = input_shapes[1 : 1 + n_arrays]
kwshapes = input_shapes[1 + n_arrays :]

# cupy/cupy#8382
# https://github.com/jax-ml/jax/issues/26658
Expand All @@ -257,22 +263,34 @@ def test_hypothesis(
for shape in shapes
)

def f1(*args: Array) -> Array:
return cast(Array, sum(args))
kwargs = {
str(n): xp.asarray(
data.draw(npst.arrays(dtype=dtype.type, shape=shape, elements=elements))
)
for n, shape in enumerate(kwshapes)
}
kwkeys = kwargs.keys()

def f1(*args: Array, **kwargs: dict[str, Array]) -> Array:
assert set(kwargs.keys()) == set(kwkeys)
args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
return cast(Array, sum(args_kwargs))

def f2(*args: Array) -> Array:
return cast(Array, sum(args) / 2)
def f2(*args: Array, **kwargs: dict[str, Array]) -> Array:
assert set(kwargs.keys()) == set(kwkeys)
args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
return cast(Array, sum(args_kwargs) / 2)

rng = np.random.default_rng(rng_seed)
cond = xp.asarray(rng.random(size=cond_shape) > p)

res1 = apply_where(cond, arrays, f1, fill_value=fill_value)
res2 = apply_where(cond, arrays, f1, f2)
res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value)
res1 = apply_where(cond, arrays, f1, fill_value=fill_value, kwargs=kwargs)
res2 = apply_where(cond, arrays, f1, f2, kwargs=kwargs)
res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value, kwargs=kwargs)

ref1 = xp.where(cond, f1(*arrays), fill_value)
ref2 = xp.where(cond, f1(*arrays), f2(*arrays))
ref3 = xp.where(cond, f1(*arrays), float_fill_value)
ref1 = xp.where(cond, f1(*arrays, **kwargs), fill_value)
ref2 = xp.where(cond, f1(*arrays, **kwargs), f2(*arrays, **kwargs))
ref3 = xp.where(cond, f1(*arrays, **kwargs), float_fill_value)

xp_assert_close(res1, ref1, rtol=2e-16)
xp_assert_equal(res2, ref2)
Expand Down