diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 9fba7c0d9..a6c064c7d 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add colocated runtime helpers for Pathways MTC. - #v1 Centralize `StorageOptions` into `ArrayOptions` and implement field-level merging. +- Support single Jax.random.key item as a PyTree. ## [0.11.36] - 2026-04-14 diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index ce6f90fb5..215d06a1b 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -648,7 +648,7 @@ async def async_save( """ start_time = time.time() item = args.item - if not item: + if item is None: raise ValueError('Found empty item.') save_args = args.save_args ocdbt_target_data_file_size = args.ocdbt_target_data_file_size diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index 7b75807ee..111a8f623 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -723,7 +723,7 @@ async def _deserialize_shardings( ' checkpoint was saved with.' ) assert info.parent_dir is not None - if info.name: + if info.name is not None: tspec_sharding = ts_utils.get_sharding_tensorstore_spec( info.parent_dir.as_posix(), info.name ) diff --git a/checkpoint/orbax/checkpoint/single_host_test.py b/checkpoint/orbax/checkpoint/single_host_test.py index c192a8923..7a18863ce 100644 --- a/checkpoint/orbax/checkpoint/single_host_test.py +++ b/checkpoint/orbax/checkpoint/single_host_test.py @@ -66,6 +66,34 @@ def test_save_and_restore_jax_array(self, use_zarr3): np.testing.assert_array_equal(x, restored_tree['x']) assert isinstance(restored_tree['x'], jax.Array) + @parameterized.named_parameters( + # using lambda to avoid jax creation during class initialization + ('array', lambda: jnp.array([1, 2, 3])), + ('random_key', lambda: jax.random.key(1)), + ('random_prng', lambda: jax.random.PRNGKey(2)), + ('int', lambda: 1), + ('float', lambda: 1.0), + ('np', lambda: np.array([1, 2, 3])), + ) + def test_save_and_restore_single_item(self, value_fn): + value = value_fn() + handler = PyTreeCheckpointHandler(use_ocdbt=True, use_zarr3=True) + handler.save( + self.ckpt_dir, + args=pytree_checkpoint_handler.PyTreeSaveArgs(value), + ) + restored_tree = handler.restore(self.ckpt_dir) + test_utils.assert_tree_equal(self, value, restored_tree) + + # Special validation for jax.random.key type. + if isinstance(value, jax.Array) and jax.dtypes.issubdtype( + value.dtype, jax.dtypes.prng_key + ): + self.assertIsInstance(restored_tree, jax.Array) + self.assertTrue( + jax.dtypes.issubdtype(restored_tree.dtype, jax.dtypes.prng_key) + ) + def test_save_and_restore_zarrv3_jax_array_default_chunk_size(self): handler = PyTreeCheckpointHandler(use_zarr3=True) key = jax.random.PRNGKey(0)