Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _is_replicated_sharding(sharding: jax.sharding.Sharding) -> bool:
else:
return False
elif isinstance(sharding, jax.sharding.SingleDeviceSharding):
return True
return False
else:
logging.warning(
'Unsupported sharding type, assuming not replicated: %s', sharding
Expand Down Expand Up @@ -1094,26 +1094,6 @@ async def serialize(
self._ext_metadata = dict()
arrays = []
for v, info in zip(values, infos):
if (
isinstance(v, jax.Array)
and jax.process_count() > 1
and v.is_fully_addressable
):
debug_param_info = (
f'ParamInfo=[name={info.name},value_typestr={info.value_typestr}]'
)
debug_array = (
f'jax.Array=[value={v},shape={v.shape},dtype={v.dtype},'
f'sharding={v.sharding},device={v.device}]'
)
raise ValueError(
f'Cannot serialize host local jax.Array ({debug_param_info},'
f' {debug_array}) in multi-host setting. Arrays like this are'
' typically obtained using pmap. Consider using'
' fully_replicated_host_local_array_to_global_array in'
' orbax/checkpoint/utils.py to convert your arrays into'
f' serializable objects. Array.sharding: {v.sharding}'
)

if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
# a JAX random key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,9 @@ def test_standard_leaf_types(self, value):

def test_jax_array_leaf_types(self):
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',))
# TODO(cpgaffney): Add support for missing arrays.
values = {
# 'simple_array': jnp.arange(16),
# 'single_device_array': jnp.arange(8, device=jax.local_devices()[0]),
'simple_array': jnp.arange(16),
'single_device_array': jnp.arange(8, device=jax.local_devices()[0]),
'replicated_array': jnp.arange(
12,
device=jax.sharding.NamedSharding(
Expand All @@ -268,21 +267,33 @@ def test_jax_array_leaf_types(self):
mesh, jax.sharding.PartitionSpec(('devices',))
),
),
# 'single_device_cpu_array': jnp.arange(
# 24, device=jax.local_devices(backend='cpu')[0]
# ),
}
# Pathways does not support the CPU backend string.
if not multihost.is_pathways_backend():
values['single_device_cpu_array'] = jnp.arange(
24, device=jax.local_devices(backend='cpu')[0]
)
for k, v in values.items():
with self.subTest(k):
ocp.save_pytree(self.directory / k, [v])

v_expected = v
# On Pathways, we cannot restore into SingleDeviceSharding.
# We must coerce the target structure.
if multihost.is_pathways_backend() and isinstance(
v.sharding, jax.sharding.SingleDeviceSharding
):
# Coerce to a globally replicated sharding on the current mesh
replicated_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec()
)
v_expected = jax.device_put(v, replicated_sharding)

with self.subTest('with_abstract_pytree'):
loaded = ocp.load_pytree(self.directory / k, [as_abstract_type(v)])
test_utils.assert_tree_equal(self, [v], loaded)
with self.subTest('without_abstract_pytree'):
if multihost.is_pathways_backend():
self.skipTest('Must provide abstract_pytree for Pathways.')
loaded = ocp.load_pytree(self.directory / k)
test_utils.assert_tree_equal(self, [v], loaded)
loaded = ocp.load_pytree(
self.directory / k, [as_abstract_type(v_expected)]
)
test_utils.assert_tree_equal(self, [v_expected], loaded)

def test_leaf_change_type(self):
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',))
Expand Down
Loading