diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index 7b75807ee..a12b7e238 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -136,7 +136,7 @@ def _is_replicated_sharding(sharding: jax.sharding.Sharding) -> bool: else: return False elif isinstance(sharding, jax.sharding.SingleDeviceSharding): - return True + return False else: logging.warning( 'Unsupported sharding type, assuming not replicated: %s', sharding @@ -1094,26 +1094,29 @@ async def serialize( self._ext_metadata = dict() arrays = [] for v, info in zip(values, infos): - if ( - isinstance(v, jax.Array) - and jax.process_count() > 1 - and v.is_fully_addressable - ): - debug_param_info = ( - f'ParamInfo=[name={info.name},value_typestr={info.value_typestr}]' - ) - debug_array = ( - f'jax.Array=[value={v},shape={v.shape},dtype={v.dtype},' - f'sharding={v.sharding},device={v.device}]' - ) - raise ValueError( - f'Cannot serialize host local jax.Array ({debug_param_info},' - f' {debug_array}) in multi-host setting. Arrays like this are' - ' typically obtained using pmap. Consider using' - ' fully_replicated_host_local_array_to_global_array in' - ' orbax/checkpoint/utils.py to convert your arrays into' - f' serializable objects. Array.sharding: {v.sharding}' - ) + if isinstance(v, jax.Array): + if jax.process_count() > 1 or multihost.is_pathways_backend(): + if isinstance(v.sharding, jax.sharding.SingleDeviceSharding): + raise ValueError( + 'Orbax does not support saving arrays with' + ' SingleDeviceSharding in multi-host environments.' + ) + if jax.process_count() > 1 and v.is_fully_addressable: + debug_param_info = ( + f'ParamInfo=[name={info.name},value_typestr={info.value_typestr}]' + ) + debug_array = ( + f'jax.Array=[value={v},shape={v.shape},dtype={v.dtype},' + f'sharding={v.sharding},device={v.device}]' + ) + raise ValueError( + f'Cannot serialize host local jax.Array ({debug_param_info},' + f' {debug_array}) in multi-host setting. Arrays like this are' + ' typically obtained using pmap. Consider using' + ' fully_replicated_host_local_array_to_global_array in' + ' orbax/checkpoint/utils.py to convert your arrays into' + f' serializable objects. Array.sharding: {v.sharding}' + ) if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key): # a JAX random key diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index 60e8f098a..c869af346 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -252,10 +252,9 @@ def test_standard_leaf_types(self, value): def test_jax_array_leaf_types(self): mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) - # TODO(cpgaffney): Add support for missing arrays. values = { - # 'simple_array': jnp.arange(16), - # 'single_device_array': jnp.arange(8, device=jax.local_devices()[0]), + 'simple_array': jnp.arange(16), + 'single_device_array': jnp.arange(8, device=jax.local_devices()[0]), 'replicated_array': jnp.arange( 12, device=jax.sharding.NamedSharding( @@ -268,12 +267,28 @@ def test_jax_array_leaf_types(self): mesh, jax.sharding.PartitionSpec(('devices',)) ), ), - # 'single_device_cpu_array': jnp.arange( - # 24, device=jax.local_devices(backend='cpu')[0] - # ), } + if not multihost.is_pathways_backend(): + values['single_device_cpu_array'] = jnp.arange( + 24, device=jax.local_devices(backend='cpu')[0] + ) + for k, v in values.items(): with self.subTest(k): + if ( + jax.process_count() > 1 + or multihost.is_pathways_backend() + ) and k in [ + 'simple_array', + 'single_device_array', + 'single_device_cpu_array', + ]: + with self.assertRaisesRegex( + ValueError, 'with SingleDeviceSharding' + ): + ocp.save_pytree(self.directory / k, [v]) + continue + ocp.save_pytree(self.directory / k, [v]) with self.subTest('with_abstract_pytree'): loaded = ocp.load_pytree(self.directory / k, [as_abstract_type(v)])