Skip to content

[pmap] Make clrs nets.py _get_first more robust under jax_pmap_shmap_merge#177

Closed
copybara-service[bot] wants to merge 0 commit intomasterfrom
test_864517583
Closed

[pmap] Make clrs nets.py _get_first more robust under jax_pmap_shmap_merge#177
copybara-service[bot] wants to merge 0 commit intomasterfrom
test_864517583

Conversation

@copybara-service
Copy link

[pmap] Make clrs nets.py _get_first more robust under jax_pmap_shmap_merge

With jax_pmap_shmap_merge=True, the code in nets.py's NetChunked.call
was calling addressable_shards on inputs and hints without checking if they
are JAX arrays first. When numpy arrays are passed (e.g., during init), this
causes AttributeError.

This fix adds a robust _get_first helper that:

  1. Checks if input is a JAX Array before accessing addressable_shards
  2. For numpy arrays, uses simple x[0] indexing
  3. For JAX arrays, handles 0-d arrays, SingleDeviceSharding, and properly
    extracts data from replicated shards

@copybara-service copybara-service bot closed this Feb 3, 2026
@copybara-service copybara-service bot deleted the test_864517583 branch February 3, 2026 16:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants