Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- #v1 Centralize `StorageOptions` into `ArrayOptions` and implement field-level
merging.
- Add Patch for Pathways CPU ids.
- #v1 Add `LeafHandler` as a `CheckpointableHandler`, so that ordinary PyTree
leaves can also be saved as individual checkpointables.

## [0.11.36] - 2026-04-14

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Sequence, Type

from orbax.checkpoint.experimental.v1._src.handlers import json_handler
from orbax.checkpoint.experimental.v1._src.handlers import leaf_handler
from orbax.checkpoint.experimental.v1._src.handlers import proto_handler
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
from orbax.checkpoint.experimental.v1._src.handlers import registration
Expand Down Expand Up @@ -66,11 +67,20 @@ def _try_register_handler(
json_handler.MetricsHandler,
checkpoint_layout.METRICS_CHECKPOINTABLE_KEY,
)
# Registration for leaf types that can be treated as distinct checkpointables.
_try_register_handler(leaf_handler.ShardedArrayHandler)
_try_register_handler(leaf_handler.ArrayHandler)
_try_register_handler(leaf_handler.ScalarHandler)
_try_register_handler(leaf_handler.StringHandler)

_try_register_handler(
pytree_handler.PyTreeHandler,
secondary_typestrs=[
'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler',
'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler',
handler_types.typestr(
stateful_checkpointable_handler.StatefulCheckpointableHandler
),
],
)
_try_register_handler(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper for :py:class:`serialization.LeafHandler`.

This :py:class:`CheckpointableHandler` is a wrapper for checkpointables where
support is already implemented at the PyTree leaf level.
"""

from typing import Any, Awaitable, TypeVar

import jax
import numpy as np
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import registry
from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types


Leaf = TypeVar('Leaf')
AbstractLeaf = TypeVar('AbstractLeaf')


class _LeafHandler(handler_types.CheckpointableHandler[Leaf, AbstractLeaf]):
"""Base class for handlers that operate on individual PyTree leaves.

This handler wraps `PyTreeHandler` to provide support for checkpointables
that are single leaves in a PyTree.
"""

def __init__(self):
self._context = context_lib.get_context()

async def save(
self, directory: path_types.PathAwaitingCreation, checkpointable: Leaf
) -> Awaitable[None]:
return await pytree_handler.PyTreeHandler().save(
directory, [checkpointable]
)

async def load(
self,
directory: path_types.Path,
abstract_checkpointable: AbstractLeaf | None = None,
) -> Awaitable[Leaf]:
if abstract_checkpointable is None:
abstract_pytree = None
else:
abstract_pytree = [abstract_checkpointable]

background_load = await pytree_handler.PyTreeHandler().load(
directory, abstract_pytree
)

async def background_load_wrapper() -> Leaf:
loaded_pytree = await background_load
return loaded_pytree[0]

return background_load_wrapper()

async def metadata(self, directory: path_types.Path) -> AbstractLeaf:
pytree_metadata = await pytree_handler.PyTreeHandler().metadata(directory)
return pytree_metadata[0]

def is_handleable(self, checkpointable: Any) -> bool:
try:
pytree_handler.PyTreeHandler().validate_leaves_handleable(
[checkpointable]
)
return True
except registry.UnregisteredTypeError:
return False

def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None:
try:
pytree_handler.PyTreeHandler().validate_abstract_leaves_handleable(
[abstract_checkpointable]
)
return True
except registry.UnregisteredTypeError:
return False


class ShardedArrayHandler(
_LeafHandler[jax.Array, serialization_types.AbstractShardedArray]
):
pass


class ArrayHandler(_LeafHandler[np.ndarray, serialization_types.AbstractArray]):
pass


class StringHandler(_LeafHandler[str, serialization_types.AbstractString]):
pass


class ScalarHandler(
_LeafHandler[serialization_types.Scalar, serialization_types.AbstractScalar]
):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Awaitable

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
import jax
from jax import numpy as jnp
import numpy as np
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.arrays import abstract_arrays
from orbax.checkpoint.experimental.v1._src.handlers import leaf_handler
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import registry
from orbax.checkpoint.experimental.v1._src.testing import handler_utils as handler_test_utils
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types


PathAwaitingCreation = path_types.PathAwaitingCreation
PathLike = path_types.PathLike
Path = path_types.Path
Json = tree_types.JsonType
create_test_handler = handler_test_utils.create_test_handler

Leaf = leaf_handler.Leaf
AbstractLeaf = leaf_handler.AbstractLeaf


async def _run_awaitable(awaitable: Awaitable[Any]) -> Any:
return await awaitable


class FooHandler(
leaf_handler._LeafHandler[
handler_test_utils.Foo, handler_test_utils.AbstractFoo
]
):

def is_handleable(self, checkpointable: Any) -> bool:
return isinstance(checkpointable, handler_test_utils.Foo)

def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None:
return isinstance(abstract_checkpointable, handler_test_utils.AbstractFoo)


class LeafHandlerTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.directory = epath.Path(
self.create_tempdir(name='checkpointing_test').full_path
)
self.values_and_abstract_values = {
leaf_handler.ShardedArrayHandler: [
(
jnp.arange(8),
abstract_arrays.to_shape_dtype_struct(jnp.arange(8)),
),
(jnp.arange(8), jax.ShapeDtypeStruct),
],
leaf_handler.ArrayHandler: [
(np.arange(8), np.empty_like(np.arange(8)))
],
leaf_handler.ScalarHandler: [
(123, int),
(123, 0),
(123.456, float),
(123.456, 0.0),
],
leaf_handler.StringHandler: [('test', str), ('test', '_')],
}

def validate_load(
self,
handler: handler_test_utils.TestHandler[Leaf, AbstractLeaf],
value: Leaf,
abstract_value: AbstractLeaf,
directory: Path | None = None,
):
directory = directory or self.directory
with self.subTest('load_with_abstract'):
restored = handler.load(directory, abstract_value)
test_utils.assert_array_equal(self, value, restored)
with self.subTest('load_without_abstract'):
restored = handler.load(directory)
test_utils.assert_array_equal(self, value, restored)

@parameterized.parameters(
leaf_handler.ShardedArrayHandler,
leaf_handler.ArrayHandler,
leaf_handler.ScalarHandler,
leaf_handler.StringHandler,
)
def test_save_load(self, handler_cls):
handler = create_test_handler(handler_cls)
test_cases = self.values_and_abstract_values[handler_cls]

self.assertFalse(handler.is_handleable(handler_test_utils.Foo(1, 'hi')))
self.assertFalse(
handler.is_abstract_handleable(handler_test_utils.AbstractFoo())
)

for i, (value, abstract_value) in enumerate(test_cases):
name = str(i)
with self.subTest(f'value={value}, abstract_value={abstract_value}'):
logging.info(
'Subtest: value=%s, abstract_value=%s', value, abstract_value
)
self.assertTrue(handler.is_handleable(value))
self.assertTrue(handler.is_abstract_handleable(abstract_value))
handler.save(self.directory / name, value)
self.validate_load(
handler, value, abstract_value, directory=self.directory / name
)

def test_unregistered_type(self):
handler = create_test_handler(FooHandler)
with self.assertRaises(registry.UnregisteredTypeError):
handler.save(self.directory, handler_test_utils.Foo(1, 'hi'))

with self.assertRaises(registry.UnregisteredTypeError):
handler.load(self.directory, handler_test_utils.AbstractFoo())

if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ async def save(
self, directory: path_types.PathAwaitingCreation, checkpointable: PyTree
) -> Awaitable[None]:
start_time = time.time()
self._validate_leaves_handleable(checkpointable)
self.validate_leaves_handleable(checkpointable)

commit_futures = await self._handler_impl.async_save(
directory.path,
Expand Down Expand Up @@ -441,7 +441,7 @@ async def load(
A awaitable which can be awaited to complete the load operation and
obtain a PyTree.
"""
self._validate_abstract_leaves_handleable(abstract_checkpointable)
self.validate_abstract_leaves_handleable(abstract_checkpointable)
return self._background_load(directory, abstract_checkpointable)

async def metadata(
Expand All @@ -456,7 +456,7 @@ def _unwrap(metadata):

return jax.tree.map(_unwrap, v0_metadata)

def _validate_leaves_handleable(self, checkpointable: PyTree):
def validate_leaves_handleable(self, checkpointable: PyTree):
missing_leaf_types = set()

def _validate_handleable_leaf(leaf: Any):
Expand All @@ -473,14 +473,14 @@ def _validate_handleable_leaf(leaf: Any):
)

if missing_leaf_types:
raise ValueError(
raise registry.UnregisteredTypeError(
'The following leaf types are not registered in the'
f' `LeafHandlerRegistry`: [{missing_leaf_types}]. Please register a'
' `LeafHandler` for each type in the `LeafHandlerRegistry` and'
' assign it into the `PyTreeOptions` in the `Context`.'
)

def _validate_abstract_leaves_handleable(
def validate_abstract_leaves_handleable(
self, abstract_checkpointable: PyTree
):
missing_abstract_leaf_types = set()
Expand All @@ -499,7 +499,7 @@ def _validate_handleable_leaf(leaf: Any):
)

if missing_abstract_leaf_types:
raise ValueError(
raise registry.UnregisteredTypeError(
'The following abstract leaf types are not registered in the'
f' `LeafHandlerRegistry`: [{missing_abstract_leaf_types}]. Please'
' register a `LeafHandler` for each type in the'
Expand All @@ -509,9 +509,11 @@ def _validate_handleable_leaf(leaf: Any):

def is_handleable(self, checkpointable: Any) -> bool:
try:
# If it's a leaf or an empty pytree container, it's not handleable.
return not jax.tree_util.treedef_is_leaf(
jax.tree.structure(checkpointable)
# If it's a leaf it's not handleable.
tree_structure = jax.tree.structure(checkpointable)
return not (
jax.tree_util.treedef_is_leaf(tree_structure)
and tree_structure.num_leaves == 1
)
except Exception: # pylint: disable=broad-exception-caught
return False
Expand Down
Loading
Loading