From f541ad3e059310a8c23b4275a49217dc76c0328b Mon Sep 17 00:00:00 2001 From: Angel Mau Date: Thu, 16 Apr 2026 03:37:59 -0700 Subject: [PATCH] #v1 Add `LeafHandler` as a `CheckpointableHandler`, so that ordinary PyTree leaves can also be saved as individual checkpointables. PiperOrigin-RevId: 900631096 --- checkpoint/CHANGELOG.md | 2 + .../v1/_src/handlers/global_registration.py | 10 ++ .../v1/_src/handlers/leaf_handler.py | 114 ++++++++++++++ .../v1/_src/handlers/leaf_handler_test.py | 139 ++++++++++++++++++ .../v1/_src/handlers/pytree_handler.py | 20 +-- .../v1/_src/handlers/registration.py | 79 +++++++--- .../v1/_src/loading/validation.py | 6 + .../experimental/v1/_src/partial/saving.py | 2 +- .../experimental/v1/_src/saving/execution.py | 2 +- .../experimental/v1/_src/saving/saving.py | 10 +- .../experimental/v1/_src/saving/validation.py | 23 ++- .../v1/_src/serialization/registry.py | 20 ++- .../v1/_src/testing/handler_utils.py | 42 +++++- .../v1/_src/testing/save_load_test_base.py | 62 ++++++++ .../v1/_src/training/checkpointer.py | 2 +- 15 files changed, 480 insertions(+), 53 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler.py create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler_test.py diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 7b2e1b275..d6fc8f52b 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - #v1 Centralize `StorageOptions` into `ArrayOptions` and implement field-level merging. - Add Patch for Pathways CPU ids. +- #v1 Add `LeafHandler` as a `CheckpointableHandler`, so that ordinary PyTree +leaves can also be saved as individual checkpointables. 
## [0.11.36] - 2026-04-14 diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py index 8faed03a8..43aa753cd 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py @@ -23,6 +23,7 @@ from typing import Sequence, Type from orbax.checkpoint.experimental.v1._src.handlers import json_handler +from orbax.checkpoint.experimental.v1._src.handlers import leaf_handler from orbax.checkpoint.experimental.v1._src.handlers import proto_handler from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler from orbax.checkpoint.experimental.v1._src.handlers import registration @@ -66,11 +67,20 @@ def _try_register_handler( json_handler.MetricsHandler, checkpoint_layout.METRICS_CHECKPOINTABLE_KEY, ) +# Registration for leaf types that can be treated as distinct checkpointables. +_try_register_handler(leaf_handler.ShardedArrayHandler) +_try_register_handler(leaf_handler.ArrayHandler) +_try_register_handler(leaf_handler.ScalarHandler) +_try_register_handler(leaf_handler.StringHandler) + _try_register_handler( pytree_handler.PyTreeHandler, secondary_typestrs=[ 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler', 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler', + handler_types.typestr( + stateful_checkpointable_handler.StatefulCheckpointableHandler + ), ], ) _try_register_handler( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler.py new file mode 100644 index 000000000..4354626ed --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler.py @@ -0,0 +1,114 @@ +# Copyright 2026 The Orbax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper for :py:class:`serialization.LeafHandler`. + +This :py:class:`CheckpointableHandler` is a wrapper for checkpointables where +support is already implemented at the PyTree leaf level. +""" + +from typing import Any, Awaitable, TypeVar + +import jax +import numpy as np +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler +from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types +from orbax.checkpoint.experimental.v1._src.path import types as path_types +from orbax.checkpoint.experimental.v1._src.serialization import registry +from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types + + +Leaf = TypeVar('Leaf') +AbstractLeaf = TypeVar('AbstractLeaf') + + +class _LeafHandler(handler_types.CheckpointableHandler[Leaf, AbstractLeaf]): + """Base class for handlers that operate on individual PyTree leaves. + + This handler wraps `PyTreeHandler` to provide support for checkpointables + that are single leaves in a PyTree. 
+ """ + + def __init__(self): + self._context = context_lib.get_context() + + async def save( + self, directory: path_types.PathAwaitingCreation, checkpointable: Leaf + ) -> Awaitable[None]: + return await pytree_handler.PyTreeHandler().save( + directory, [checkpointable] + ) + + async def load( + self, + directory: path_types.Path, + abstract_checkpointable: AbstractLeaf | None = None, + ) -> Awaitable[Leaf]: + if abstract_checkpointable is None: + abstract_pytree = None + else: + abstract_pytree = [abstract_checkpointable] + + background_load = await pytree_handler.PyTreeHandler().load( + directory, abstract_pytree + ) + + async def background_load_wrapper() -> Leaf: + loaded_pytree = await background_load + return loaded_pytree[0] + + return background_load_wrapper() + + async def metadata(self, directory: path_types.Path) -> AbstractLeaf: + pytree_metadata = await pytree_handler.PyTreeHandler().metadata(directory) + return pytree_metadata[0] + + def is_handleable(self, checkpointable: Any) -> bool: + try: + pytree_handler.PyTreeHandler().validate_leaves_handleable( + [checkpointable] + ) + return True + except registry.UnregisteredTypeError: + return False + + def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None: + try: + pytree_handler.PyTreeHandler().validate_abstract_leaves_handleable( + [abstract_checkpointable] + ) + return True + except registry.UnregisteredTypeError: + return False + + +class ShardedArrayHandler( + _LeafHandler[jax.Array, serialization_types.AbstractShardedArray] +): + pass + + +class ArrayHandler(_LeafHandler[np.ndarray, serialization_types.AbstractArray]): + pass + + +class StringHandler(_LeafHandler[str, serialization_types.AbstractString]): + pass + + +class ScalarHandler( + _LeafHandler[serialization_types.Scalar, serialization_types.AbstractScalar] +): + pass diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler_test.py 
b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler_test.py new file mode 100644 index 000000000..5ba9b3553 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler_test.py @@ -0,0 +1,139 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Awaitable + +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +import jax +from jax import numpy as jnp +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.arrays import abstract_arrays +from orbax.checkpoint.experimental.v1._src.handlers import leaf_handler +from orbax.checkpoint.experimental.v1._src.path import types as path_types +from orbax.checkpoint.experimental.v1._src.serialization import registry +from orbax.checkpoint.experimental.v1._src.testing import handler_utils as handler_test_utils +from orbax.checkpoint.experimental.v1._src.tree import types as tree_types + + +PathAwaitingCreation = path_types.PathAwaitingCreation +PathLike = path_types.PathLike +Path = path_types.Path +Json = tree_types.JsonType +create_test_handler = handler_test_utils.create_test_handler + +Leaf = leaf_handler.Leaf +AbstractLeaf = leaf_handler.AbstractLeaf + + +async def _run_awaitable(awaitable: Awaitable[Any]) -> Any: + return await awaitable + + +class FooHandler( + leaf_handler._LeafHandler[ + 
handler_test_utils.Foo, handler_test_utils.AbstractFoo + ] +): + + def is_handleable(self, checkpointable: Any) -> bool: + return isinstance(checkpointable, handler_test_utils.Foo) + + def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None: + return isinstance(abstract_checkpointable, handler_test_utils.AbstractFoo) + + +class LeafHandlerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.directory = epath.Path( + self.create_tempdir(name='checkpointing_test').full_path + ) + self.values_and_abstract_values = { + leaf_handler.ShardedArrayHandler: [ + ( + jnp.arange(8), + abstract_arrays.to_shape_dtype_struct(jnp.arange(8)), + ), + (jnp.arange(8), jax.ShapeDtypeStruct), + ], + leaf_handler.ArrayHandler: [ + (np.arange(8), np.empty_like(np.arange(8))) + ], + leaf_handler.ScalarHandler: [ + (123, int), + (123, 0), + (123.456, float), + (123.456, 0.0), + ], + leaf_handler.StringHandler: [('test', str), ('test', '_')], + } + + def validate_load( + self, + handler: handler_test_utils.TestHandler[Leaf, AbstractLeaf], + value: Leaf, + abstract_value: AbstractLeaf, + directory: Path | None = None, + ): + directory = directory or self.directory + with self.subTest('load_with_abstract'): + restored = handler.load(directory, abstract_value) + test_utils.assert_array_equal(self, value, restored) + with self.subTest('load_without_abstract'): + restored = handler.load(directory) + test_utils.assert_array_equal(self, value, restored) + + @parameterized.parameters( + leaf_handler.ShardedArrayHandler, + leaf_handler.ArrayHandler, + leaf_handler.ScalarHandler, + leaf_handler.StringHandler, + ) + def test_save_load(self, handler_cls): + handler = create_test_handler(handler_cls) + test_cases = self.values_and_abstract_values[handler_cls] + + self.assertFalse(handler.is_handleable(handler_test_utils.Foo(1, 'hi'))) + self.assertFalse( + handler.is_abstract_handleable(handler_test_utils.AbstractFoo()) + ) + + for i, (value, abstract_value) 
in enumerate(test_cases): + name = str(i) + with self.subTest(f'value={value}, abstract_value={abstract_value}'): + logging.info( + 'Subtest: value=%s, abstract_value=%s', value, abstract_value + ) + self.assertTrue(handler.is_handleable(value)) + self.assertTrue(handler.is_abstract_handleable(abstract_value)) + handler.save(self.directory / name, value) + self.validate_load( + handler, value, abstract_value, directory=self.directory / name + ) + + def test_unregistered_type(self): + handler = create_test_handler(FooHandler) + with self.assertRaises(registry.UnregisteredTypeError): + handler.save(self.directory, handler_test_utils.Foo(1, 'hi')) + + with self.assertRaises(registry.UnregisteredTypeError): + handler.load(self.directory, handler_test_utils.AbstractFoo()) + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py index 92e356a65..95167fad1 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py @@ -385,7 +385,7 @@ async def save( self, directory: path_types.PathAwaitingCreation, checkpointable: PyTree ) -> Awaitable[None]: start_time = time.time() - self._validate_leaves_handleable(checkpointable) + self.validate_leaves_handleable(checkpointable) commit_futures = await self._handler_impl.async_save( directory.path, @@ -441,7 +441,7 @@ async def load( A awaitable which can be awaited to complete the load operation and obtain a PyTree. 
""" - self._validate_abstract_leaves_handleable(abstract_checkpointable) + self.validate_abstract_leaves_handleable(abstract_checkpointable) return self._background_load(directory, abstract_checkpointable) async def metadata( @@ -456,7 +456,7 @@ def _unwrap(metadata): return jax.tree.map(_unwrap, v0_metadata) - def _validate_leaves_handleable(self, checkpointable: PyTree): + def validate_leaves_handleable(self, checkpointable: PyTree): missing_leaf_types = set() def _validate_handleable_leaf(leaf: Any): @@ -473,14 +473,14 @@ def _validate_handleable_leaf(leaf: Any): ) if missing_leaf_types: - raise ValueError( + raise registry.UnregisteredTypeError( 'The following leaf types are not registered in the' f' `LeafHandlerRegistry`: [{missing_leaf_types}]. Please register a' ' `LeafHandler` for each type in the `LeafHandlerRegistry` and' ' assign it into the `PyTreeOptions` in the `Context`.' ) - def _validate_abstract_leaves_handleable( + def validate_abstract_leaves_handleable( self, abstract_checkpointable: PyTree ): missing_abstract_leaf_types = set() @@ -499,7 +499,7 @@ def _validate_handleable_leaf(leaf: Any): ) if missing_abstract_leaf_types: - raise ValueError( + raise registry.UnregisteredTypeError( 'The following abstract leaf types are not registered in the' f' `LeafHandlerRegistry`: [{missing_abstract_leaf_types}]. Please' ' register a `LeafHandler` for each type in the' @@ -509,9 +509,11 @@ def _validate_handleable_leaf(leaf: Any): def is_handleable(self, checkpointable: Any) -> bool: try: - # If it's a leaf or an empty pytree container, it's not handleable. - return not jax.tree_util.treedef_is_leaf( - jax.tree.structure(checkpointable) + # If it's a leaf it's not handleable. 
+ tree_structure = jax.tree.structure(checkpointable) + return not ( + jax.tree_util.treedef_is_leaf(tree_structure) + and tree_structure.num_leaves == 1 ) except Exception: # pylint: disable=broad-exception-caught return False diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py index fc9875ce5..77f9bbdeb 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py @@ -72,6 +72,8 @@ CheckpointableHandler = handler_types.CheckpointableHandler RegistryEntry = tuple[Type[CheckpointableHandler], str | None] +DEFAULT_PYTREE_HANDLER_TYPSTR = 'orbax.checkpoint.experimental.v1._src.handlers.pytree_handler.PyTreeHandler' + def add_all( registry: CheckpointableHandlerRegistry, @@ -196,6 +198,10 @@ class NoEntryError(KeyError): """Raised when no entry exists in the registry.""" +class NotHandleableError(ValueError): + """Raised when a checkpointable is not handleable by a handler.""" + + class _DefaultCheckpointableHandlerRegistry(CheckpointableHandlerRegistry): """Default implementation of :py:class:`~.v1.handlers.registration.CheckpointableHandlerRegistry`.""" @@ -515,7 +521,7 @@ def _construct_handler_instance( def _get_possible_handlers( registry: CheckpointableHandlerRegistry, - is_handleable_fn: Callable[[CheckpointableHandler, Any], bool], + is_handleable: Callable[[CheckpointableHandler, Any], bool], checkpointable: Any | None, name: str, ) -> Sequence[CheckpointableHandler]: @@ -539,7 +545,7 @@ def _get_possible_handlers( handler for handler, checkpointable_name in registry_entries if checkpointable_name is None - and is_handleable_fn(handler, checkpointable) + and is_handleable(handler, checkpointable) ] if not possible_handlers: available_handlers = [ @@ -590,25 +596,39 @@ def resolve_handler_for_save( checkpointable: A checkpointable to resolve. 
name: The name of the checkpointable. - Raises: - NoEntryError: If no compatible - :py:class:`~.v1.handlers.CheckpointableHandler` can be found. - Returns: A :py:class:`~.v1.handlers.CheckpointableHandler` instance. + + Raises: + ValueError: If the checkpointable is None. + NoEntryError: If no compatible + :py:class:`~.v1.handlers.CheckpointableHandler` can be found. """ # If explicitly registered, use that first. if registry.has(name): - return _construct_handler_instance(name, registry.get(name)) + handler = _construct_handler_instance(name, registry.get(name)) + if handler.is_handleable(checkpointable): + return handler + + logging.warning( + 'The explicitly registered handler %s for name="%s" does not handle the' + ' checkpointable item type %s. Please register a handler that is' + ' compatible via `is_handleable()`, or verify that the correct' + ' checkpointable object is passed for this name. Attempting to resolve' + ' a registered compatible handler as fallback.', + type(handler), + name, + type(checkpointable), + ) if checkpointable is None: raise ValueError('checkpointable must not be None for saving.') - def is_handleable_fn(handler: CheckpointableHandler, ckpt: Any) -> bool: + def is_handleable(handler: CheckpointableHandler, ckpt: Any) -> bool: return handler.is_handleable(ckpt) possible_handlers = _get_possible_handlers( - registry, is_handleable_fn, checkpointable, name + registry, is_handleable, checkpointable, name ) # Prefer the last handler in the absence of any other information. @@ -636,11 +656,6 @@ def resolve_handler_for_load( recently-registered handler, unless abstract_checkpointable is None, in which case raise a NoEntryError. - Raises: - NoEntryError: If no compatible - :py:class:`~.v1.handlers.CheckpointableHandler` - can be found. 
- Args: registry: The :py:class:`~.v1.handlers.registration.CheckpointableHandlerRegistry` to @@ -654,18 +669,44 @@ def resolve_handler_for_load( Returns: A :py:class:`~.v1.handlers.CheckpointableHandler` instance. + + Raises: + NoEntryError: If no compatible + :py:class:`~.v1.handlers.CheckpointableHandler` + can be found. """ # If explicitly registered, use that first. if registry.has(name): - return _construct_handler_instance(name, registry.get(name)) + handler = _construct_handler_instance(name, registry.get(name)) + if abstract_checkpointable is None: + if ( + handler_typestr == handler_types.typestr(type(handler)) + or not handler_typestr + or handler_typestr in registry.get_secondary_typestrs(type(handler)) + or not registry.get_secondary_typestrs(type(handler)) + ): + return handler + elif handler.is_abstract_handleable(abstract_checkpointable): + return handler + logging.warning( + 'The explicitly registered handler %s for name="%s" does not handle the' + ' abstract checkpointable type %s. Please register a handler that is' + ' compatible via `is_abstract_handleable()`, or verify that the correct' + ' abstract checkpointable object is passed for this name. If you meant' + ' to register a new handler to take over loading, either initialize it' + ' with no secondary typestrs, or ensure that the checkpoint handler' + ' used to save is listed as a secondary typestr on the new handler.' 
+ ' Attempting to resolve a registered compatible handler as fallback.', + type(handler), + name, + abstract_checkpointable, + ) - def is_handleable_fn( - handler: CheckpointableHandler, ckpt: Any - ) -> bool | None: + def is_handleable(handler: CheckpointableHandler, ckpt: Any) -> bool | None: return handler.is_abstract_handleable(ckpt) possible_handlers = _get_possible_handlers( - registry, is_handleable_fn, abstract_checkpointable, name + registry, is_handleable, abstract_checkpointable, name ) possible_handler_typestrs = [ handler_types.typestr(type(handler)) for handler in possible_handlers diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/validation.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/validation.py index 79f135b61..acb1ffec8 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/validation.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/validation.py @@ -59,6 +59,12 @@ def validate_abstract_checkpointables(abstract_checkpointables): """ if abstract_checkpointables is None: return + if not isinstance(abstract_checkpointables, dict): + raise ValueError( + '`abstract_checkpointables` must be a valid mapping of checkpointable' + ' names to abstract checkpointables to load, but got' + f' {type(abstract_checkpointables)}' + ) if EMPTY_CHECKPOINTABLE_KEY in abstract_checkpointables: raise ValueError( 'Empty string is not supported as a checkpointable name in' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py index aef470be3..803e42387 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py @@ -64,7 +64,7 @@ async def save( operation_id = f'{operation_id}.{directory.path.name}' # pylint: disable=protected-access - self.handler._validate_leaves_handleable(self.pytree) + 
+ self.handler.validate_leaves_handleable(self.pytree) v0_save_args = pytree_handler.create_v0_save_args( self.handler._context, self.pytree diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py index 15b47419f..3e9823d98 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py @@ -368,7 +368,7 @@ def save_checkpointables_impl( partial_save: bool = False, ) -> async_types.AsyncResponse[None]: """See caller docstrings.""" - validation.validate_abstract_checkpointables(checkpointables) + validation.validate_save_checkpointables(checkpointables) start_time = time.time() event_tracking.OperationRecorder( path, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py index 0c7e481a7..7bcadee72 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py @@ -133,7 +133,13 @@ def save_checkpointables( JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax. """ - validation.validate_abstract_checkpointables(checkpointables) + if not isinstance(checkpointables, dict): + raise ValueError( + f'`checkpointables` must be a dict, but got {type(checkpointables)}' + ) + if not checkpointables: + raise ValueError('`checkpointables` must be a non-empty dict.') + validation.validate_save_checkpointables(checkpointables) execution.save_checkpointables_impl( path, checkpointables, @@ -279,7 +285,7 @@ def save_checkpointables_async( An `AsyncResponse` that can be used to block until the save is complete. Blocking can be done using `response.result()`, which returns `None`. 
""" - validation.validate_abstract_checkpointables(checkpointables) + validation.validate_save_checkpointables(checkpointables) return execution.save_checkpointables_impl( path, checkpointables, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/validation.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/validation.py index 7ca0f68d1..a8cc84919 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/validation.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/validation.py @@ -20,25 +20,32 @@ EMPTY_CHECKPOINTABLE_KEY = checkpoint_layout.EMPTY_CHECKPOINTABLE_KEY -def validate_abstract_checkpointables(abstract_checkpointables): - """Validates the abstract_checkpointables dictionary. +def validate_save_checkpointables(checkpointables): + """Validates the checkpointables dictionary. Args: - abstract_checkpointables: A dictionary of abstract checkpointables. + checkpointables: A dictionary of checkpointables. Raises: - ValueError: If any of the keys in abstract_checkpointables are reserved. + ValueError: If checkpointables is empty, not a dict, or has invalid keys. """ - if abstract_checkpointables is None: - return - if EMPTY_CHECKPOINTABLE_KEY in abstract_checkpointables: + if not checkpointables or not isinstance( + checkpointables, dict + ): + raise ValueError( + '`checkpointables` must be a valid mapping of checkpointable names to' + ' desired checkpointables to save, but got' + f' {type(checkpointables)}' + ) + + if EMPTY_CHECKPOINTABLE_KEY in checkpointables: raise ValueError( 'Empty string is not supported as a checkpointable name in' ' `save_checkpointables`. Each checkpointable name must be a valid' ' non-empty string name.' 
) if ( - provided_reserved_keys := abstract_checkpointables.keys() + provided_reserved_keys := checkpointables.keys() & RESERVED_CHECKPOINTABLE_KEYS ): raise ValueError( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registry.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registry.py index 17f87ab34..88422b005 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registry.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registry.py @@ -30,6 +30,14 @@ import typing_extensions +class AlreadyRegisteredTypeError(ValueError): + """Raised when a leaf is already registered.""" + + +class UnregisteredTypeError(ValueError): + """Raised when a leaf is not registered.""" + + # The standard type, abstract type, and optional typestrs to handler mapping. # The type to abstract type pairs are well defined standard and users should # rarely need to override the pair. @@ -163,7 +171,7 @@ def get( self, leaf_type: type[types.Leaf] ) -> type[types.LeafHandler[types.Leaf, Any]]: if (handler_type := self._try_get(leaf_type)) is None: - raise ValueError( + raise UnregisteredTypeError( f'Unknown Leaf type: "{leaf_type!r}". Must register it with' ' LeafHandlerRegistry.' ) @@ -190,7 +198,7 @@ def get_abstract( abstract_type: type[types.AbstractLeaf], ) -> type[types.LeafHandler[Any, types.AbstractLeaf]]: if (handler_type := self._try_get_abstract(abstract_type)) is None: - raise ValueError( + raise UnregisteredTypeError( f'Unknown AbstractLeaf type: "{abstract_type!r}". Must register it' ' with LeafHandlerRegistry.' ) @@ -245,8 +253,8 @@ def add( secondary identifiers for the handler. Raises: - ValueError: If a duplicate `leaf_type` or conflicting `abstract_type` - mapping exists and `override` is False. + AlreadyRegisteredTypeError: If a duplicate `leaf_type` or conflicting + `abstract_type` mapping exists and `override` is False. 
""" # Check for exact duplicate registration @@ -294,13 +302,13 @@ def add( else: for e in self._entries: if e.leaf_type == leaf_type: - raise ValueError( + raise AlreadyRegisteredTypeError( f'leaf_type [{leaf_type}] is already handled by ' f'{e.handler_type}. Use override=True to replace its entry. ' f'Registry: {self._entries}' ) if e.abstract_type == abstract_type and e.handler_type != handler_type: - raise ValueError( + raise AlreadyRegisteredTypeError( f'abstract_type[{abstract_type}] is already handled by ' f'{e.handler_type}. Use override=True to replace all ' f'conflicting entries. Registry: {self._entries}' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py index 6c0022c50..2ce52033a 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py @@ -18,7 +18,7 @@ import dataclasses import json -from typing import Any, Awaitable, Generic, Sequence, Type, TypeVar +from typing import Any, Awaitable, Generic, Protocol, Sequence, Type, TypeVar import aiofiles from etils import epath @@ -49,33 +49,63 @@ async def _run_awaitable(awaitable: Awaitable[Any]) -> Any: return await awaitable -class _TestHandler(Generic[T, AbstractT]): +class TestHandler(Protocol[T, AbstractT]): """This class facilitates testing of :py:class:`~.v1.handlers.CheckpointableHandler` independently. Use :py:func:`.create_test_handler`. """ + def save(self, directory: Path, checkpointable: T) -> None: + ... + + def save_async(self, directory: Path, checkpointable: T) -> Awaitable[None]: + ... + + def load( + self, path: Path, abstract_checkpointable: AbstractT | None = None + ) -> T: + ... + + def load_async( + self, path: Path, abstract_checkpointable: AbstractT | None = None + ) -> Awaitable[T]: + ... + + def metadata(self, path: Path) -> AbstractT: + ... 
+ + def is_handleable(self, checkpointable: Any) -> bool: + ... + + def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None: + ... + + +class _TestHandler(Generic[T, AbstractT]): + def __init__( self, handler_class: type[CheckpointableHandler[T, AbstractT]], **kwargs ): self._handler: CheckpointableHandler[T, AbstractT] = handler_class(**kwargs) - def save(self, directory: Path, checkpointable: T): + def save(self, directory: Path, checkpointable: T) -> None: path = path_test_utils.PathAwaitingCreationWrapper(directory) awaitable = asyncio_utils.run_sync(self._handler.save(path, checkpointable)) return asyncio_utils.run_sync(_run_awaitable(awaitable)) - def save_async(self, directory: Path, checkpointable: T): + def save_async(self, directory: Path, checkpointable: T) -> Awaitable[None]: path = path_test_utils.PathAwaitingCreationWrapper(directory) return asyncio_utils.run_sync(self._handler.save(path, checkpointable)) - def load(self, path: Path, abstract_checkpointable: AbstractT | None = None): + def load( + self, path: Path, abstract_checkpointable: AbstractT | None = None + ) -> T: awaitable = self.load_async(path, abstract_checkpointable) return asyncio_utils.run_sync(_run_awaitable(awaitable)) def load_async( self, path: Path, abstract_checkpointable: AbstractT | None = None - ): + ) -> Awaitable[T]: return asyncio_utils.run_sync( self._handler.load(path, abstract_checkpointable) ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index 60e8f098a..5179d120a 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -250,6 +250,54 @@ def test_standard_leaf_types(self, value): else: self.assertEqual(loaded['k'], value) + @parameterized.parameters( + (np.arange(8),), + (2,), + (2.2,), + ('foo',), 
+ (np.asarray(3.14),), + ) + def test_standard_leaf_types_as_checkpointable(self, value): + with self.subTest('save_pytree'): + ocp.save_pytree(self.directory / 'pytree', value) + loaded = ocp.load_pytree(self.directory / 'pytree') + if isinstance(value, np.ndarray): + np.testing.assert_array_equal(loaded, value) + else: + self.assertEqual(loaded, value) + with self.subTest('save_checkpointables'): + ocp.save_checkpointables( + self.directory / 'checkpointables', {'foo': value} + ) + loaded = ocp.load_checkpointables(self.directory / 'checkpointables')[ + 'foo' + ] + if isinstance(value, np.ndarray): + np.testing.assert_array_equal(loaded, value) + else: + self.assertEqual(loaded, value) + + def test_jax_array_as_checkpointable(self): + value = jnp.arange( + 16, + device=jax.sharding.NamedSharding( + jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)), + jax.sharding.PartitionSpec(), + ), + ) + with self.subTest('save_pytree'): + ocp.save_pytree(self.directory / 'pytree', value) + loaded = ocp.load_pytree(self.directory / 'pytree') + test_utils.assert_tree_equal(self, value, loaded) + with self.subTest('save_checkpointables'): + ocp.save_checkpointables( + self.directory / 'checkpointables', {'foo': value} + ) + loaded = ocp.load_checkpointables(self.directory / 'checkpointables')[ + 'foo' + ] + test_utils.assert_tree_equal(self, value, loaded) + def test_jax_array_leaf_types(self): mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) # TODO(cpgaffney): Add support for missing arrays. 
@@ -284,6 +332,20 @@ def test_jax_array_leaf_types(self): loaded = ocp.load_pytree(self.directory / k) test_utils.assert_tree_equal(self, [v], loaded) + def test_save_unregistered_type_as_pytree(self): + with self.assertRaises(registration.NoEntryError): + ocp.save_pytree(self.directory, handler_utils.Foo(1, 'hi')) + + @parameterized.parameters( + ({},), + ([],), + ('hello',), + (None,), + ) + def test_save_checkpointables_invalid(self, checkpointables): + with self.assertRaises(ValueError): + ocp.save_checkpointables(self.directory, checkpointables) + def test_leaf_change_type(self): mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py index d3f385609..cf4fc1210 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py @@ -543,7 +543,7 @@ def save_checkpointables_async( StepAlreadyExistsError: If `overwrite` is False and a checkpoint at the target `step` already exists. """ - validation.validate_abstract_checkpointables(checkpointables) + validation.validate_save_checkpointables(checkpointables) if overwrite: logging.info( 'Specified `overwrite`: deleting existing checkpoint %d if it'