Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,10 @@
def _maybe_pick_first_pmapped(tree):
if jax.local_device_count() == 1:
return tree
# Avoid degraded performance under the new jax.pmap. See
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
if jax.config.jax_pmap_shmap_merge:
return jax.tree_util.tree_map(
lambda x: x.addressable_shards[0].data.squeeze(0), tree)
return jax.tree_util.tree_map(lambda x: x[0], tree)
# Avoid degraded performance under the new jax.pmap.
return jax.tree_util.tree_map(
lambda x: x.addressable_shards[0].data.squeeze(0), tree
)


@jax.jit
Expand Down
37 changes: 15 additions & 22 deletions clrs/_src/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,8 +645,7 @@ def __call__(self, features_list: List[_FeaturesChunked],
lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]),
lstm_state)
mp_state.lstm_state = lstm_state
# Avoid degraded performance under the new jax.pmap. See
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
# Avoid degraded performance under the new jax.pmap.
def _get_first(x):
# Handle non-JAX arrays (e.g., numpy arrays) by direct indexing.
if not isinstance(x, jax.Array):
Expand All @@ -660,26 +659,20 @@ def _get_first(x):
x.sharding, jax.sharding.SingleDeviceSharding
):
return x
# Under the new jax.pmap (jax_pmap_shmap_merge), integer indexing
# into sharded arrays triggers expensive cross-device copies. Handle
# specially to avoid this performance degradation.
if jax.config.jax_pmap_shmap_merge:
# Single-device case: no cross-device copy, safe to index directly.
if len(jax.local_devices()) == 1:
return x[0]
# Fully-replicated arrays have identical data on all shards;
# extract from the first addressable shard to avoid copies.
if x.sharding.is_fully_replicated:
return x.addressable_shards[0].data
# For non-replicated sharded arrays, get data from the first shard.
# If the shard has a leading dimension of 1 (from the pmap batch
# axis), squeeze it out to match the expected shape.
shard_data = x.addressable_shards[0].data
if shard_data.shape and shard_data.shape[0] == 1:
return shard_data.squeeze(0)
return shard_data
# Legacy pmap path: direct indexing is safe and efficient.
return x[0]
# Single-device case: no cross-device copy, safe to index directly.
if len(jax.local_devices()) == 1:
return x[0]
# Fully-replicated arrays have identical data on all shards;
# extract from the first addressable shard to avoid copies.
if x.sharding.is_fully_replicated:
return x.addressable_shards[0].data
# For non-replicated sharded arrays, get data from the first shard.
# If the shard has a leading dimension of 1 (from the pmap batch
# axis), squeeze it out to match the expected shape.
shard_data = x.addressable_shards[0].data
if shard_data.shape and shard_data.shape[0] == 1:
return shard_data.squeeze(0)
return shard_data

mp_state.inputs = jax.tree_util.tree_map(_get_first, inputs)
mp_state.hints = jax.tree_util.tree_map(_get_first, hints)
Expand Down
Loading