Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add colocated runtime helpers for Pathways MTC.
- #v1 Centralize `StorageOptions` into `ArrayOptions` and implement field-level
merging.
- Support single Jax.random.key item as a PyTree.

## [0.11.36] - 2026-04-14

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ async def async_save(
"""
start_time = time.time()
item = args.item
if not item:
if item is None:
raise ValueError('Found empty item.')
save_args = args.save_args
ocdbt_target_data_file_size = args.ocdbt_target_data_file_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ async def _deserialize_shardings(
' checkpoint was saved with.'
)
assert info.parent_dir is not None
if info.name:
if info.name is not None:
tspec_sharding = ts_utils.get_sharding_tensorstore_spec(
info.parent_dir.as_posix(), info.name
)
Expand Down
28 changes: 28 additions & 0 deletions checkpoint/orbax/checkpoint/single_host_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,34 @@ def test_save_and_restore_jax_array(self, use_zarr3):
np.testing.assert_array_equal(x, restored_tree['x'])
assert isinstance(restored_tree['x'], jax.Array)

@parameterized.named_parameters(
# using lambda to avoid jax creation during class initialization
('array', lambda: jnp.array([1, 2, 3])),
('random_key', lambda: jax.random.key(1)),
('random_prng', lambda: jax.random.PRNGKey(2)),
('int', lambda: 1),
('float', lambda: 1.0),
('np', lambda: np.array([1, 2, 3])),
)
def test_save_and_restore_single_item(self, value_fn):
value = value_fn()
handler = PyTreeCheckpointHandler(use_ocdbt=True, use_zarr3=True)
handler.save(
self.ckpt_dir,
args=pytree_checkpoint_handler.PyTreeSaveArgs(value),
)
restored_tree = handler.restore(self.ckpt_dir)
test_utils.assert_tree_equal(self, value, restored_tree)

# Special validation for jax.random.key type.
if isinstance(value, jax.Array) and jax.dtypes.issubdtype(
value.dtype, jax.dtypes.prng_key
):
self.assertIsInstance(restored_tree, jax.Array)
self.assertTrue(
jax.dtypes.issubdtype(restored_tree.dtype, jax.dtypes.prng_key)
)

def test_save_and_restore_zarrv3_jax_array_default_chunk_size(self):
handler = PyTreeCheckpointHandler(use_zarr3=True)
key = jax.random.PRNGKey(0)
Expand Down
Loading