From 9bc25aaa1416d6e703878ab9db85d97ed1a2dc65 Mon Sep 17 00:00:00 2001 From: Angel Mau Date: Sun, 3 May 2026 17:41:51 -0700 Subject: [PATCH] #v1 Allow a context to be default-configured for all `Checkpointer` operations. PiperOrigin-RevId: 909720210 --- checkpoint/CHANGELOG.md | 1 + .../experimental/v1/_src/context/context.py | 6 + .../v1/_src/context/context_test.py | 31 ++++ .../experimental/v1/_src/saving/saving.py | 3 +- .../v1/_src/training/checkpointer.py | 135 ++++++++++-------- .../_src/training/checkpointer_test_base.py | 65 ++++++++- 6 files changed, 175 insertions(+), 66 deletions(-) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 820923ed6..600c4d97e 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -21,6 +21,7 @@ merging. - #v1 Add `LeafHandler` as a `CheckpointableHandler`, so that ordinary PyTree leaves can also be saved as individual checkpointables. - Move MTC files to multi_tier_checkpointing and use local checkpoint engine +- #v1 Allow a context to be default-configured for all `Checkpointer` operations. ### Changed diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py index 7210470a8..eb27f71df 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py @@ -60,6 +60,12 @@ class Context(epy.ContextManager): that thread will not inherit the context and will fall back to default settings. + Note: When testing or mixing checkpointer instances and free functions, + explicitly wrap free functions inside their own `with ocp.Context(...)` block, + or pass explicit contexts to Checkpointer constructors, to ensure each actor + receives its correct active configuration independent of the surrounding + context. + Example: Basic usage and explicit inheritance:: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py index 592bcc7a3..29c3d3296 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py @@ -43,6 +43,37 @@ def test_default_context(self): ctx = fake_checkpoint_operation() self.assertEqual(ctx.array_options, ArrayOptions()) + def test_get_context_with_default(self): + default_ctx = ocp.Context( + array_options=ArrayOptions(saving=ArrayOptions.Saving(use_ocdbt=False)) + ) + custom_ctx = ocp.Context( + array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)) + ) + + with self.subTest("no context set, no default provided"): + ctx = context_lib.get_context() + self.assertEqual(ctx.array_options, ArrayOptions()) + + with self.subTest("no context set, default provided"): + ctx = context_lib.get_context(default=default_ctx) + self.assertIs(ctx, default_ctx) + + with self.subTest("context IS set, no default provided"): + with custom_ctx: + ctx = context_lib.get_context() + self.assertIs(ctx, custom_ctx) + + with self.subTest("context IS set, default provided"): + with custom_ctx: + ctx = context_lib.get_context(default=default_ctx) + self.assertIs(ctx, custom_ctx) + self.assertIsNot(ctx, default_ctx) + + with self.subTest("no context set, default=None provided"): + ctx = context_lib.get_context(default=None) + self.assertEqual(ctx.array_options, ArrayOptions()) + def test_custom_context(self): with ocp.Context( array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py index 9f62e7843..883c7019a 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py @@ -299,7 +299,6 @@ def get_v0_checkpointer_and_args( checkpointables: dict[str, Any], *, metrics: tree_types.JsonType | None = None, - context: context_lib.Context, ) -> tuple[ async_checkpointer.AsyncCheckpointer, composite_checkpoint_handler.CompositeArgs, @@ -309,11 +308,11 @@ def get_v0_checkpointer_and_args( Args: checkpointables: A dictionary of checkpointables. metrics: Optional metrics to add to the checkpointables. - context: The Orbax context. Returns: A tuple containing the V0 Checkpointer and Args. """ + context = context_lib.get_context() checkpointables = execution.add_internal_checkpointables( checkpointables, context=context, metrics=metrics ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py index 147bb6ac9..3789a053c 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py @@ -84,6 +84,7 @@ def __init__( self, directory: path_types.PathLike, *, + context: context_lib.Context | None = None, save_decision_policy: ( save_decision_policies.SaveDecisionPolicy | None ) = None, @@ -146,6 +147,8 @@ def __init__( Args: directory: The root directory where checkpoints are stored. The directory will be created if it does not exist. + context: A :py:class:`~orbax.checkpoint.v1.Context` object that will be + used to wrap all function calls for this `Checkpointer`. save_decision_policy: A policy used to determine when a checkpoint should be saved. If not provided, the `Checkpointer` saves as often as possible by default (assuming no checkpoint is currently being saved), and saves @@ -167,7 +170,7 @@ def __init__( checkpoint steps present and checkpoint info properties like `time` and `metrics` are not needed. """ - context = context_lib.get_context() + self._context = context or context_lib.get_context() default_save_decision_policy = save_decision_policies.AnySavePolicy([ save_decision_policies.InitialSavePolicy(), @@ -188,11 +191,11 @@ def __init__( cleanup_tmp_directories=cleanup_tmp_directories, lightweight_initialize=lightweight_initialize, max_to_keep=None, # Unlimited. - todelete_full_path=context.deletion_options.gcs_deletion_options.todelete_full_path, - async_options=context.async_options.v0(), - file_options=context.file_options.v0(), - multiprocessing_options=context.multiprocessing_options.v0(), - temporary_path_class=context.file_options.temporary_path_class, + todelete_full_path=self._context.deletion_options.gcs_deletion_options.todelete_full_path, + async_options=self._context.async_options.v0(), + file_options=self._context.file_options.v0(), + multiprocessing_options=self._context.multiprocessing_options.v0(), + temporary_path_class=self._context.file_options.temporary_path_class, # Prevent the checkpoint manager from writing metrics on its own. This # class will take responsibility for writing metrics. prevent_write_metrics=True, @@ -242,7 +245,6 @@ class that `Checkpointer` is unaware of. Note that doing this is Returns: A list of checkpoints, sorted ascending by step. """ - infos = sorted(self._manager._checkpoints, key=lambda info: info.step) # pylint: disable=protected-access return [ CheckpointMetadata[None]( @@ -273,8 +275,9 @@ def _resolve_existing_checkpoint( def should_save(self, step: int) -> bool: """Returns whether a checkpoint should be saved at the given step.""" - step = _resolve_integer_step(step) - return self._manager.should_save(step) + with context_lib.get_context(self._context): + step = _resolve_integer_step(step) + return self._manager.should_save(step) def save_pytree( self, @@ -550,6 +553,7 @@ def save_checkpointables_async( StepAlreadyExistsError: If `overwrite` is False and a checkpoint at the target `step` already exists. """ + context = context_lib.get_context(self._context) validation.validate_save_checkpointables(checkpointables) if overwrite: logging.info( @@ -561,21 +565,22 @@ def save_checkpointables_async( self._manager.delete(step) except FileNotFoundError: pass - elif step in [c.step for c in self.checkpoints]: + elif any(c.step == step for c in self.checkpoints): raise errors.StepAlreadyExistsError(f'Step {step} already exists.') - checkpointer, args = saving.get_v0_checkpointer_and_args( - checkpointables, metrics=metrics, context=context_lib.get_context() - ) - self._manager._checkpointer = checkpointer # pylint: disable=protected-access - saved = self._manager.save( - step, - args=args, - metrics=metrics, - force=force, - custom_metadata=custom_metadata, - ) - return _AsyncSaveResponse(self._manager, saved) + with context: + checkpointer, args = saving.get_v0_checkpointer_and_args( + checkpointables, metrics=metrics + ) + self._manager._checkpointer = checkpointer # pylint: disable=protected-access + saved = self._manager.save( + step, + args=args, + metrics=metrics, + force=force, + custom_metadata=custom_metadata, + ) + return _AsyncSaveResponse(self._manager, saved) def load_pytree( self, @@ -752,11 +757,12 @@ def load_checkpointables( returns only the keys specified in that dict, otherwise returns all keys saved with `save_checkpointables`. """ - step = self._resolve_existing_checkpoint(step).step - return loading.load_checkpointables( - self.directory / self._step_name_format.build_name(step), - abstract_checkpointables, - ) + with context_lib.get_context(self._context): + step = self._resolve_existing_checkpoint(step).step + return loading.load_checkpointables( + self.directory / self._step_name_format.build_name(step), + abstract_checkpointables, + ) def load_pytree_async( self, @@ -797,23 +803,24 @@ def pytree_metadata( :py:class:`.PyTreeMetadata`, along with checkpoint timestamp and metrics information. """ - checkpoint = self._resolve_existing_checkpoint(step) - del step - checkpoint_metadata = metadata_loading.pytree_metadata( - self._manager.directory - / self._step_name_format.build_name(checkpoint.step) - ) - return training_metadata_types.CheckpointMetadata[ - metadata_types.PyTreeMetadata - ]( - step=checkpoint.step, - path=checkpoint_metadata.path, - metadata=checkpoint_metadata.metadata, - init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs, - commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs, - custom_metadata=checkpoint_metadata.custom_metadata, - metrics=checkpoint.metrics, - ) + with context_lib.get_context(self._context): + checkpoint = self._resolve_existing_checkpoint(step) + del step + checkpoint_metadata = metadata_loading.pytree_metadata( + self._manager.directory + / self._step_name_format.build_name(checkpoint.step) + ) + return training_metadata_types.CheckpointMetadata[ + metadata_types.PyTreeMetadata + ]( + step=checkpoint.step, + path=checkpoint_metadata.path, + metadata=checkpoint_metadata.metadata, + init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs, + commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs, + custom_metadata=checkpoint_metadata.custom_metadata, + metrics=checkpoint.metrics, + ) def checkpointables_metadata( self, step: int | CheckpointMetadata | None = None @@ -834,29 +841,31 @@ def checkpointables_metadata( describing the checkpointables, along with checkpoint timestamp and metrics information. """ - checkpoint = self._resolve_existing_checkpoint(step) - del step - checkpoint_metadata = metadata_loading.checkpointables_metadata( - self._manager.directory - / self._step_name_format.build_name(checkpoint.step) - ) - return training_metadata_types.CheckpointMetadata[dict[str, Any]]( - step=checkpoint.step, - path=checkpoint_metadata.path, - metadata=checkpoint_metadata.metadata, - init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs, - commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs, - custom_metadata=checkpoint_metadata.custom_metadata, - metrics=checkpoint.metrics, - ) + with context_lib.get_context(self._context): + checkpoint = self._resolve_existing_checkpoint(step) + del step + checkpoint_metadata = metadata_loading.checkpointables_metadata( + self._manager.directory + / self._step_name_format.build_name(checkpoint.step) + ) + return training_metadata_types.CheckpointMetadata[dict[str, Any]]( + step=checkpoint.step, + path=checkpoint_metadata.path, + metadata=checkpoint_metadata.metadata, + init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs, + commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs, + custom_metadata=checkpoint_metadata.custom_metadata, + metrics=checkpoint.metrics, + ) def root_metadata( self, ) -> training_metadata_types.RootMetadata: - metadata = self._manager.metadata(None) - return RootMetadata( - directory=self.directory, custom_metadata=metadata.custom_metadata - ) + with context_lib.get_context(self._context): + metadata = self._manager.metadata(None) + return RootMetadata( + directory=self.directory, custom_metadata=metadata.custom_metadata + ) def reload(self): """Reloads internal properties from the root directory. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py index 27e6a371d..3d952f84d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py @@ -121,6 +121,8 @@ def test_save_restore_pytree(self): test_utils.assert_tree_equal(self, self.pytree, loaded) with self.subTest('without_abstract_pytree'): + if multihost.is_pathways_backend(): + self.skipTest('Must provide abstract_pytree for Pathways.') loaded = checkpointer.load_pytree(0) test_utils.assert_tree_equal(self, self.pytree, loaded) @@ -435,6 +437,8 @@ def test_custom_checkpointables(self): self.save_checkpointables(checkpointer, 0, checkpointables) with self.subTest('load'): + if multihost.is_pathways_backend(): + self.skipTest('Sharding metadata not present in Pathways.') loaded = checkpointer.load_checkpointables(0) self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar']) test_utils.assert_tree_equal( @@ -443,7 +447,16 @@ def test_custom_checkpointables(self): self.assertEqual(checkpointables['foo'], loaded['foo']) self.assertEqual(checkpointables['bar'], loaded['bar']) with self.subTest('load_with_free_function'): - loaded = ocp.load_checkpointables(self.directory / '0') + if multihost.is_pathways_backend(): + self.skipTest('Sharding metadata not present in Pathways.') + checkpointables_options = ( + ocp.options.CheckpointablesOptions.create_with_handlers( + foo=handler_utils.FooHandler, + bar=handler_utils.BarHandler, + ) + ) + with ocp.Context(checkpointables_options=checkpointables_options): + loaded = ocp.load_checkpointables(self.directory / '0') self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar']) test_utils.assert_tree_equal( self, checkpointables['pytree'], loaded['pytree'] @@ -462,6 +475,8 @@ def test_custom_checkpointables(self): self.assertEqual(checkpointables['foo'], loaded['foo']) self.assertEqual(checkpointables['bar'], loaded['bar']) with self.subTest('load_with_abstract_checkpointables_none_values'): + if multihost.is_pathways_backend(): + self.skipTest('Sharding metadata not present in Pathways.') abstract_checkpointables = { 'pytree': None, 'foo': None, @@ -683,6 +698,54 @@ def test_gcs_deletion_options(self): ) + def test_context_constructor_override(self): + ctx1 = ocp.Context( + array_options=ocp.options.ArrayOptions( + saving=ocp.options.ArrayOptions.Saving(use_ocdbt=False) + ), + pytree_options=ocp.options.PyTreeOptions( + loading=ocp.options.PyTreeOptions.Loading(partial_load=True) + ), + ) + checkpointer = Checkpointer(self.directory, context=ctx1) + self.enter_context(checkpointer) + self.save_pytree(checkpointer, 0, self.pytree) + + with self.subTest('constructor_override_ocdbt'): + # Default use_ocdbt is True, so set to False to prove constructor arg is + # used. + pytree_dir = self.directory / '0' / 'pytree' + self.assertFalse( + (pytree_dir / 'manifest.ocdbt').exists(), + f'Expected NO manifest.ocdbt under {pytree_dir}', + ) + + with self.subTest('constructor_override_partial_load'): + loaded = checkpointer.load_pytree(0, self.abstract_pytree) + test_utils.assert_tree_equal(self, self.pytree, loaded) + + # Test partial load override. + partial_abstract = {'jax_array': self.abstract_pytree['jax_array']} + loaded_partial = checkpointer.load_pytree(0, partial_abstract) + expected_pytree = {'jax_array': self.pytree['jax_array']} + test_utils.assert_tree_equal(self, expected_pytree, loaded_partial) + + with self.subTest('local_context_override'): + # Override with local context setting use_ocdbt=True + ctx2 = ocp.Context( + array_options=ocp.options.ArrayOptions( + saving=ocp.options.ArrayOptions.Saving(use_ocdbt=True) + ) + ) + with ctx2: + self.save_pytree(checkpointer, 1, self.pytree) + + pytree_dir_1 = self.directory / '1' / 'pytree' + self.assertTrue( + (pytree_dir_1 / 'manifest.ocdbt').exists(), + f'Expected manifest.ocdbt under {pytree_dir_1}', + ) + @parameterized.named_parameters( dict( testcase_name='true',