From af74e15bda0e1a7ea6ee5eefc06d7979104c8014 Mon Sep 17 00:00:00 2001 From: Angel Mau Date: Thu, 16 Apr 2026 02:49:57 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 900612779 --- .../_src/serialization/jax_array_handlers.py | 22 +---------- .../v1/_src/testing/save_load_test_base.py | 37 ++++++++++++------- 2 files changed, 25 insertions(+), 34 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index 7b75807ee..aa03a16d8 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -136,7 +136,7 @@ def _is_replicated_sharding(sharding: jax.sharding.Sharding) -> bool: else: return False elif isinstance(sharding, jax.sharding.SingleDeviceSharding): - return True + return False else: logging.warning( 'Unsupported sharding type, assuming not replicated: %s', sharding @@ -1094,26 +1094,6 @@ async def serialize( self._ext_metadata = dict() arrays = [] for v, info in zip(values, infos): - if ( - isinstance(v, jax.Array) - and jax.process_count() > 1 - and v.is_fully_addressable - ): - debug_param_info = ( - f'ParamInfo=[name={info.name},value_typestr={info.value_typestr}]' - ) - debug_array = ( - f'jax.Array=[value={v},shape={v.shape},dtype={v.dtype},' - f'sharding={v.sharding},device={v.device}]' - ) - raise ValueError( - f'Cannot serialize host local jax.Array ({debug_param_info},' - f' {debug_array}) in multi-host setting. Arrays like this are' - ' typically obtained using pmap. Consider using' - ' fully_replicated_host_local_array_to_global_array in' - ' orbax/checkpoint/utils.py to convert your arrays into' - f' serializable objects. Array.sharding: {v.sharding}' - ) if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key): # a JAX random key diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index 60e8f098a..ebd483233 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -252,10 +252,9 @@ def test_standard_leaf_types(self, value): def test_jax_array_leaf_types(self): mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) - # TODO(cpgaffney): Add support for missing arrays. values = { - # 'simple_array': jnp.arange(16), - # 'single_device_array': jnp.arange(8, device=jax.local_devices()[0]), + 'simple_array': jnp.arange(16), + 'single_device_array': jnp.arange(8, device=jax.local_devices()[0]), 'replicated_array': jnp.arange( 12, device=jax.sharding.NamedSharding( @@ -268,21 +267,33 @@ def test_jax_array_leaf_types(self): mesh, jax.sharding.PartitionSpec(('devices',)) ), ), - # 'single_device_cpu_array': jnp.arange( - # 24, device=jax.local_devices(backend='cpu')[0] - # ), } + # Pathways does not support the CPU backend string. + if not multihost.is_pathways_backend(): + values['single_device_cpu_array'] = jnp.arange( + 24, device=jax.local_devices(backend='cpu')[0] + ) for k, v in values.items(): with self.subTest(k): ocp.save_pytree(self.directory / k, [v]) + + v_expected = v + # On Pathways, we cannot restore into SingleDeviceSharding. + # We must coerce the target structure. + if multihost.is_pathways_backend() and isinstance( + v.sharding, jax.sharding.SingleDeviceSharding + ): + # Coerce to a globally replicated sharding on the current mesh + replicated_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec() + ) + v_expected = jax.device_put(v, replicated_sharding) + with self.subTest('with_abstract_pytree'): - loaded = ocp.load_pytree(self.directory / k, [as_abstract_type(v)]) - test_utils.assert_tree_equal(self, [v], loaded) - with self.subTest('without_abstract_pytree'): - if multihost.is_pathways_backend(): - self.skipTest('Must provide abstract_pytree for Pathways.') - loaded = ocp.load_pytree(self.directory / k) - test_utils.assert_tree_equal(self, [v], loaded) + loaded = ocp.load_pytree( + self.directory / k, [as_abstract_type(v_expected)] + ) + test_utils.assert_tree_equal(self, [v_expected], loaded) def test_leaf_change_type(self): mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',))