From 7da2538738e75be3a2ec3cf52c9a1708c594f384 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Wed, 1 Jul 2026 07:29:46 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 941083537 --- drjax/_src/primitives.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/drjax/_src/primitives.py b/drjax/_src/primitives.py index fb7c25f..c158bdc 100644 --- a/drjax/_src/primitives.py +++ b/drjax/_src/primitives.py @@ -48,7 +48,7 @@ def _define_broadcast_prim( def broadcast_prim_fn(x, *, mesh=None): return broadcast_p.bind(x, mesh=mesh) - return (broadcast_p, broadcast_prim_fn) + return (broadcast_p, broadcast_prim_fn) # pyrefly: ignore[bad-return] def _register_broadcast_impls( @@ -249,7 +249,7 @@ def _batch_agg(xs, batched_shape): # Certain jax libs can silently insert the 'batching' dim 'all the way at # the front'; we are about to destroy the front axis by agging, so move # that puppy to the back. Tell the rest of JAX what happened here. - xs = jnp.moveaxis(*xs, *batched_shape, -1) + xs = jnp.moveaxis(*xs, *batched_shape, -1) # pyrefly: ignore[bad-argument-count] return agg_prim_fn(xs), len(xs.shape) - 2 # Make sure this can also be batched / mapped. This happens when dispatching @@ -301,7 +301,7 @@ def broadcast_array_eval(x, *, mesh): _register_broadcast_impls( broadcast_p, broadcast_prim_fn, - broadcast_array_eval, + broadcast_array_eval, # pyrefly: ignore[bad-argument-type] sum_prim_fn, placement_str, n_elements, @@ -317,7 +317,7 @@ def broadcast_array_eval(x, *, mesh): mean_p, mean_prim_fn, impl_defs.mean_from_placement, - lambda x: jnp.divide(broadcast_prim_fn(x), n_elements), + lambda x: jnp.divide(broadcast_prim_fn(x), n_elements), # pyrefly: ignore[bad-argument-type] )