diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index b85c8d4b..e295de67 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -57,12 +57,10 @@ def _maybe_pick_first_pmapped(tree): if jax.local_device_count() == 1: return tree - # Avoid degraded performance under the new jax.pmap. See - # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays. - if jax.config.jax_pmap_shmap_merge: - return jax.tree_util.tree_map( - lambda x: x.addressable_shards[0].data.squeeze(0), tree) - return jax.tree_util.tree_map(lambda x: x[0], tree) + # Avoid degraded performance under the new jax.pmap. + return jax.tree_util.tree_map( + lambda x: x.addressable_shards[0].data.squeeze(0), tree + ) @jax.jit diff --git a/clrs/_src/nets.py b/clrs/_src/nets.py index 14fed796..6d246980 100644 --- a/clrs/_src/nets.py +++ b/clrs/_src/nets.py @@ -645,8 +645,7 @@ def __call__(self, features_list: List[_FeaturesChunked], lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]), lstm_state) mp_state.lstm_state = lstm_state - # Avoid degraded performance under the new jax.pmap. See - # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays. + # Avoid degraded performance under the new jax.pmap. def _get_first(x): # Handle non-JAX arrays (e.g., numpy arrays) by direct indexing. if not isinstance(x, jax.Array): @@ -660,26 +659,20 @@ def _get_first(x): x.sharding, jax.sharding.SingleDeviceSharding ): return x - # Under the new jax.pmap (jax_pmap_shmap_merge), integer indexing - # into sharded arrays triggers expensive cross-device copies. Handle - # specially to avoid this performance degradation. - if jax.config.jax_pmap_shmap_merge: - # Single-device case: no cross-device copy, safe to index directly. - if len(jax.local_devices()) == 1: - return x[0] - # Fully-replicated arrays have identical data on all shards; - # extract from the first addressable shard to avoid copies. - if x.sharding.is_fully_replicated: - return x.addressable_shards[0].data - # For non-replicated sharded arrays, get data from the first shard. - # If the shard has a leading dimension of 1 (from the pmap batch - # axis), squeeze it out to match the expected shape. - shard_data = x.addressable_shards[0].data - if shard_data.shape and shard_data.shape[0] == 1: - return shard_data.squeeze(0) - return shard_data - # Legacy pmap path: direct indexing is safe and efficient. - return x[0] + # Single-device case: no cross-device copy, safe to index directly. + if len(jax.local_devices()) == 1: + return x[0] + # Fully-replicated arrays have identical data on all shards; + # extract from the first addressable shard to avoid copies. + if x.sharding.is_fully_replicated: + return x.addressable_shards[0].data + # For non-replicated sharded arrays, get data from the first shard. + # If the shard has a leading dimension of 1 (from the pmap batch + # axis), squeeze it out to match the expected shape. + shard_data = x.addressable_shards[0].data + if shard_data.shape and shard_data.shape[0] == 1: + return shard_data.squeeze(0) + return shard_data mp_state.inputs = jax.tree_util.tree_map(_get_first, inputs) mp_state.hints = jax.tree_util.tree_map(_get_first, hints)