From 1acee7a5ac06eb87c07e592c03468a600ad65a53 Mon Sep 17 00:00:00 2001 From: Matt Haberland Date: Wed, 18 Feb 2026 16:46:18 -0800 Subject: [PATCH] ENH: apply_where: add kwargs support --- src/array_api_extra/_lib/_funcs.py | 38 +++++++++++++++++++++++---- tests/test_funcs.py | 42 +++++++++++++++++++++--------- 2 files changed, 63 insertions(+), 17 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 62ddfa16..460408c2 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -41,6 +41,7 @@ def apply_where( # numpydoc ignore=GL08 f2: Callable[..., Array], /, *, + kwargs: dict[str, Array] | None = None, xp: ModuleType | None = None, ) -> Array: ... @@ -53,6 +54,7 @@ def apply_where( # numpydoc ignore=GL08 /, *, fill_value: Array | complex, + kwargs: dict[str, Array] | None = None, xp: ModuleType | None = None, ) -> Array: ... @@ -65,6 +67,7 @@ def apply_where( # numpydoc ignore=PR01,PR02 /, *, fill_value: Array | complex | None = None, + kwargs: dict[str, Array] | None = None, xp: ModuleType | None = None, ) -> Array: """ @@ -91,6 +94,9 @@ def apply_where( # numpydoc ignore=PR01,PR02 It does not need to be scalar; it needs however to be broadcastable with `cond` and `args`. Mutually exclusive with `f2`. You must provide one or the other. + kwargs : dict of str : Array pairs + Keyword argument(s) to `f1` (and `f2`). Values must be broadcastable with + `cond`. xp : array_namespace, optional The standard-compatible namespace for `cond` and `args`. Default: infer. @@ -129,6 +135,11 @@ def apply_where( # numpydoc ignore=PR01,PR02 args_ = list(args) if isinstance(args, tuple) else [args] del args + kwargs_ = {} if kwargs is None else kwargs + kwkeys = list(kwargs_.keys()) + args_ = [*args_, *kwargs_.values()] + del kwargs + xp = array_namespace(cond, fill_value, *args_) if xp is None else xp if isinstance(fill_value, int | float | complex | NoneType): @@ -139,8 +150,11 @@ def apply_where( # numpydoc ignore=PR01,PR02 if is_dask_namespace(xp): meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp) # map_blocks doesn't descend into tuples of Arrays - return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp) - return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp) + return xp.map_blocks( + _apply_where, cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=meta_xp + ) + + return _apply_where(cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=xp) def _apply_where( # numpydoc ignore=PR01,RT01 @@ -149,15 +163,26 @@ def _apply_where( # numpydoc ignore=PR01,RT01 f2: Callable[..., Array] | None, fill_value: Array | int | float | complex | bool | None, *args: Array, + kwkeys: list[str], xp: ModuleType, ) -> Array: """Helper of `apply_where`. On Dask, this runs on a single chunk.""" + nargs = len(args) - len(kwkeys) + kwargs = dict(zip(kwkeys, args[nargs:], strict=True)) + args = args[:nargs] + if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]: # jax.jit does not support assignment by boolean mask - return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value) + return xp.where( + cond, + f1(*args, **kwargs), + f2(*args, **kwargs) if f2 is not None else fill_value, + ) - temp1 = f1(*(arr[cond] for arr in args)) + temp1 = f1( + *(arr[cond] for arr in args), **{key: val[cond] for key, val in kwargs.items()} + ) if f2 is None: dtype = xp.result_type(temp1, fill_value) @@ -167,7 +192,10 @@ def _apply_where( # numpydoc ignore=PR01,RT01 out = xp.astype(fill_value, dtype, copy=True) else: ncond = ~cond - temp2 = f2(*(arr[ncond] for arr in args)) + temp2 = f2( + *(arr[ncond] for arr in args), + **{key: val[ncond] for key, val in kwargs.items()}, + ) dtype = xp.result_type(temp1, temp2) out = xp.empty_like(cond, dtype=dtype) out = at(out, ncond).set(temp2) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index db64fdb6..dce7c6e1 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -211,6 +211,7 @@ def test_device(self, xp: ModuleType, device: Device): ) @given( n_arrays=st.integers(min_value=1, max_value=3), + n_kwarrays=st.integers(min_value=1, max_value=3), rng_seed=st.integers(min_value=1000000000, max_value=9999999999), dtype=npst.floating_dtypes(sizes=(32, 64)), p=st.floats(min_value=0, max_value=1), @@ -219,6 +220,7 @@ def test_device(self, xp: ModuleType, device: Device): def test_hypothesis( self, n_arrays: int, + n_kwarrays: int, rng_seed: int, dtype: np.dtype[Any], p: float, @@ -233,9 +235,13 @@ def test_hypothesis( ): pytest.xfail(reason="NumPy 1.x dtype promotion for scalars") - mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays + 1, min_side=0) + mbs = npst.mutually_broadcastable_shapes( + num_shapes=1 + n_arrays + n_kwarrays, min_side=0 + ) input_shapes, _ = data.draw(mbs) - cond_shape, *shapes = input_shapes + cond_shape = input_shapes[0] + shapes = input_shapes[1 : 1 + n_arrays] + kwshapes = input_shapes[1 + n_arrays :] # cupy/cupy#8382 # https://github.com/jax-ml/jax/issues/26658 @@ -257,22 +263,34 @@ def test_hypothesis( for shape in shapes ) - def f1(*args: Array) -> Array: - return cast(Array, sum(args)) + kwargs = { + str(n): xp.asarray( + data.draw(npst.arrays(dtype=dtype.type, shape=shape, elements=elements)) + ) + for n, shape in enumerate(kwshapes) + } + kwkeys = kwargs.keys() + + def f1(*args: Array, **kwargs: dict[str, Array]) -> Array: + assert set(kwargs.keys()) == set(kwkeys) + args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values())) + return cast(Array, sum(args_kwargs)) - def f2(*args: Array) -> Array: - return cast(Array, sum(args) / 2) + def f2(*args: Array, **kwargs: dict[str, Array]) -> Array: + assert set(kwargs.keys()) == set(kwkeys) + args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values())) + return cast(Array, sum(args_kwargs) / 2) rng = np.random.default_rng(rng_seed) cond = xp.asarray(rng.random(size=cond_shape) > p) - res1 = apply_where(cond, arrays, f1, fill_value=fill_value) - res2 = apply_where(cond, arrays, f1, f2) - res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value) + res1 = apply_where(cond, arrays, f1, fill_value=fill_value, kwargs=kwargs) + res2 = apply_where(cond, arrays, f1, f2, kwargs=kwargs) + res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value, kwargs=kwargs) - ref1 = xp.where(cond, f1(*arrays), fill_value) - ref2 = xp.where(cond, f1(*arrays), f2(*arrays)) - ref3 = xp.where(cond, f1(*arrays), float_fill_value) + ref1 = xp.where(cond, f1(*arrays, **kwargs), fill_value) + ref2 = xp.where(cond, f1(*arrays, **kwargs), f2(*arrays, **kwargs)) + ref3 = xp.where(cond, f1(*arrays, **kwargs), float_fill_value) xp_assert_close(res1, ref1, rtol=2e-16) xp_assert_equal(res2, ref2)