Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ merging.
- #v1 Add `LeafHandler` as a `CheckpointableHandler`, so that ordinary PyTree
leaves can also be saved as individual checkpointables.
- Move MTC files to multi_tier_checkpointing and use local checkpoint engine
- #v1 Allow a default `Context` to be configured on a `Checkpointer` and applied to all of its operations.

### Changed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ class Context(epy.ContextManager):
that thread will not inherit the context and will fall back to default
settings.

Note: When mixing `Checkpointer` instances and free functions (for example in
tests), wrap each free function call in its own `with ocp.Context(...)` block
so that it uses the intended configuration regardless of the surrounding
context. A context passed to a `Checkpointer` constructor serves as that
instance's default configuration; a context that is active when one of its
methods is called still takes precedence.

Example:
Basic usage and explicit inheritance::

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,37 @@ def test_default_context(self):
ctx = fake_checkpoint_operation()
self.assertEqual(ctx.array_options, ArrayOptions())

def test_get_context_with_default(self):
  """Checks how `get_context` picks between an active and a default context.

  Covers every combination of "a context is / is not active" with "a default
  is / is not supplied". The assertions establish that an active context is
  returned in preference to `default`, that `default` is returned when nothing
  is active, and that a context with default `ArrayOptions` is returned when
  neither is available.
  """
  fallback_saving = ArrayOptions.Saving(use_ocdbt=False)
  fallback_context = ocp.Context(
      array_options=ArrayOptions(saving=fallback_saving)
  )
  active_saving = ArrayOptions.Saving(use_zarr3=False)
  active_context = ocp.Context(
      array_options=ArrayOptions(saving=active_saving)
  )

  with self.subTest("no context set, no default provided"):
    resolved = context_lib.get_context()
    self.assertEqual(resolved.array_options, ArrayOptions())

  with self.subTest("no context set, default provided"):
    resolved = context_lib.get_context(default=fallback_context)
    self.assertIs(resolved, fallback_context)

  with self.subTest("context IS set, no default provided"):
    with active_context:
      resolved = context_lib.get_context()
      self.assertIs(resolved, active_context)

  with self.subTest("context IS set, default provided"):
    with active_context:
      resolved = context_lib.get_context(default=fallback_context)
      # The active context must win over the supplied default.
      self.assertIs(resolved, active_context)
      self.assertIsNot(resolved, fallback_context)

  with self.subTest("no context set, default=None provided"):
    resolved = context_lib.get_context(default=None)
    self.assertEqual(resolved.array_options, ArrayOptions())

def test_custom_context(self):
with ocp.Context(
array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ def get_v0_checkpointer_and_args(
checkpointables: dict[str, Any],
*,
metrics: tree_types.JsonType | None = None,
context: context_lib.Context,
) -> tuple[
async_checkpointer.AsyncCheckpointer,
composite_checkpoint_handler.CompositeArgs,
Expand All @@ -309,11 +308,11 @@ def get_v0_checkpointer_and_args(
Args:
checkpointables: A dictionary of checkpointables.
metrics: Optional metrics to add to the checkpointables.
context: The Orbax context.

Returns:
A tuple containing the V0 Checkpointer and Args.
"""
context = context_lib.get_context()
checkpointables = execution.add_internal_checkpointables(
checkpointables, context=context, metrics=metrics
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
self,
directory: path_types.PathLike,
*,
context: context_lib.Context | None = None,
save_decision_policy: (
save_decision_policies.SaveDecisionPolicy | None
) = None,
Expand Down Expand Up @@ -146,6 +147,8 @@ def __init__(
Args:
directory: The root directory where checkpoints are stored. The directory
will be created if it does not exist.
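context: A :py:class:`~orbax.checkpoint.v1.Context` used as the default
configuration for this `Checkpointer`'s operations when no other context is
active. If omitted, the context returned by `get_context()` at construction
time is used.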
save_decision_policy: A policy used to determine when a checkpoint should
be saved. If not provided, the `Checkpointer` saves as often as possible
by default (assuming no checkpoint is currently being saved), and saves
Expand All @@ -167,7 +170,7 @@ def __init__(
checkpoint steps present and checkpoint info properties like `time` and
`metrics` are not needed.
"""
context = context_lib.get_context()
self._context = context or context_lib.get_context()

default_save_decision_policy = save_decision_policies.AnySavePolicy([
save_decision_policies.InitialSavePolicy(),
Expand All @@ -188,11 +191,11 @@ def __init__(
cleanup_tmp_directories=cleanup_tmp_directories,
lightweight_initialize=lightweight_initialize,
max_to_keep=None, # Unlimited.
todelete_full_path=context.deletion_options.gcs_deletion_options.todelete_full_path,
async_options=context.async_options.v0(),
file_options=context.file_options.v0(),
multiprocessing_options=context.multiprocessing_options.v0(),
temporary_path_class=context.file_options.temporary_path_class,
todelete_full_path=self._context.deletion_options.gcs_deletion_options.todelete_full_path,
async_options=self._context.async_options.v0(),
file_options=self._context.file_options.v0(),
multiprocessing_options=self._context.multiprocessing_options.v0(),
temporary_path_class=self._context.file_options.temporary_path_class,
# Prevent the checkpoint manager from writing metrics on its own. This
# class will take responsibility for writing metrics.
prevent_write_metrics=True,
Expand Down Expand Up @@ -242,7 +245,6 @@ class that `Checkpointer` is unaware of. Note that doing this is
Returns:
A list of checkpoints, sorted ascending by step.
"""

infos = sorted(self._manager._checkpoints, key=lambda info: info.step) # pylint: disable=protected-access
return [
CheckpointMetadata[None](
Expand Down Expand Up @@ -273,8 +275,9 @@ def _resolve_existing_checkpoint(

def should_save(self, step: int) -> bool:
  """Returns whether a checkpoint should be saved at the given step.

  The decision is evaluated inside the context obtained from
  `context_lib.get_context(self._context)`, so this `Checkpointer`'s configured
  context is supplied as the default for the call.

  Args:
    step: The step to check; it is normalized with `_resolve_integer_step`
      before being handed to the underlying manager.

  Returns:
    The result of the underlying manager's `should_save` for the resolved step.
  """
  # Only the context-wrapped path is kept: resolving and returning before
  # entering the context would leave the configured context unused.
  with context_lib.get_context(self._context):
    step = _resolve_integer_step(step)
    return self._manager.should_save(step)

def save_pytree(
self,
Expand Down Expand Up @@ -550,6 +553,7 @@ def save_checkpointables_async(
StepAlreadyExistsError: If `overwrite` is False and a checkpoint at the
target `step` already exists.
"""
context = context_lib.get_context(self._context)
validation.validate_save_checkpointables(checkpointables)
if overwrite:
logging.info(
Expand All @@ -561,21 +565,22 @@ def save_checkpointables_async(
self._manager.delete(step)
except FileNotFoundError:
pass
elif step in [c.step for c in self.checkpoints]:
elif any(c.step == step for c in self.checkpoints):
raise errors.StepAlreadyExistsError(f'Step {step} already exists.')

checkpointer, args = saving.get_v0_checkpointer_and_args(
checkpointables, metrics=metrics, context=context_lib.get_context()
)
self._manager._checkpointer = checkpointer # pylint: disable=protected-access
saved = self._manager.save(
step,
args=args,
metrics=metrics,
force=force,
custom_metadata=custom_metadata,
)
return _AsyncSaveResponse(self._manager, saved)
with context:
checkpointer, args = saving.get_v0_checkpointer_and_args(
checkpointables, metrics=metrics
)
self._manager._checkpointer = checkpointer # pylint: disable=protected-access
saved = self._manager.save(
step,
args=args,
metrics=metrics,
force=force,
custom_metadata=custom_metadata,
)
return _AsyncSaveResponse(self._manager, saved)

def load_pytree(
self,
Expand Down Expand Up @@ -752,11 +757,12 @@ def load_checkpointables(
returns only the keys specified in that dict, otherwise returns all
keys saved with `save_checkpointables`.
"""
step = self._resolve_existing_checkpoint(step).step
return loading.load_checkpointables(
self.directory / self._step_name_format.build_name(step),
abstract_checkpointables,
)
with context_lib.get_context(self._context):
step = self._resolve_existing_checkpoint(step).step
return loading.load_checkpointables(
self.directory / self._step_name_format.build_name(step),
abstract_checkpointables,
)

def load_pytree_async(
self,
Expand Down Expand Up @@ -797,23 +803,24 @@ def pytree_metadata(
:py:class:`.PyTreeMetadata`, along with checkpoint timestamp and metrics
information.
"""
checkpoint = self._resolve_existing_checkpoint(step)
del step
checkpoint_metadata = metadata_loading.pytree_metadata(
self._manager.directory
/ self._step_name_format.build_name(checkpoint.step)
)
return training_metadata_types.CheckpointMetadata[
metadata_types.PyTreeMetadata
](
step=checkpoint.step,
path=checkpoint_metadata.path,
metadata=checkpoint_metadata.metadata,
init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
custom_metadata=checkpoint_metadata.custom_metadata,
metrics=checkpoint.metrics,
)
with context_lib.get_context(self._context):
checkpoint = self._resolve_existing_checkpoint(step)
del step
checkpoint_metadata = metadata_loading.pytree_metadata(
self._manager.directory
/ self._step_name_format.build_name(checkpoint.step)
)
return training_metadata_types.CheckpointMetadata[
metadata_types.PyTreeMetadata
](
step=checkpoint.step,
path=checkpoint_metadata.path,
metadata=checkpoint_metadata.metadata,
init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
custom_metadata=checkpoint_metadata.custom_metadata,
metrics=checkpoint.metrics,
)

def checkpointables_metadata(
self, step: int | CheckpointMetadata | None = None
Expand All @@ -834,29 +841,31 @@ def checkpointables_metadata(
describing the checkpointables, along with checkpoint timestamp and
metrics information.
"""
checkpoint = self._resolve_existing_checkpoint(step)
del step
checkpoint_metadata = metadata_loading.checkpointables_metadata(
self._manager.directory
/ self._step_name_format.build_name(checkpoint.step)
)
return training_metadata_types.CheckpointMetadata[dict[str, Any]](
step=checkpoint.step,
path=checkpoint_metadata.path,
metadata=checkpoint_metadata.metadata,
init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
custom_metadata=checkpoint_metadata.custom_metadata,
metrics=checkpoint.metrics,
)
with context_lib.get_context(self._context):
checkpoint = self._resolve_existing_checkpoint(step)
del step
checkpoint_metadata = metadata_loading.checkpointables_metadata(
self._manager.directory
/ self._step_name_format.build_name(checkpoint.step)
)
return training_metadata_types.CheckpointMetadata[dict[str, Any]](
step=checkpoint.step,
path=checkpoint_metadata.path,
metadata=checkpoint_metadata.metadata,
init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
custom_metadata=checkpoint_metadata.custom_metadata,
metrics=checkpoint.metrics,
)

def root_metadata(
    self,
) -> training_metadata_types.RootMetadata:
  """Returns metadata stored at the root of the checkpoint directory.

  The underlying manager is queried inside the context obtained from
  `context_lib.get_context(self._context)`, so this `Checkpointer`'s configured
  context is supplied as the default for the call.

  Returns:
    A `RootMetadata` carrying this checkpointer's `directory` and the
    `custom_metadata` reported by the underlying manager.
  """
  # Only the context-wrapped path is kept: reading and returning the metadata
  # before entering the context would leave the configured context unused.
  with context_lib.get_context(self._context):
    metadata = self._manager.metadata(None)
    return RootMetadata(
        directory=self.directory, custom_metadata=metadata.custom_metadata
    )

def reload(self):
"""Reloads internal properties from the root directory.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def test_save_restore_pytree(self):
test_utils.assert_tree_equal(self, self.pytree, loaded)

with self.subTest('without_abstract_pytree'):
if multihost.is_pathways_backend():
self.skipTest('Must provide abstract_pytree for Pathways.')
loaded = checkpointer.load_pytree(0)
test_utils.assert_tree_equal(self, self.pytree, loaded)

Expand Down Expand Up @@ -435,6 +437,8 @@ def test_custom_checkpointables(self):
self.save_checkpointables(checkpointer, 0, checkpointables)

with self.subTest('load'):
if multihost.is_pathways_backend():
self.skipTest('Sharding metadata not present in Pathways.')
loaded = checkpointer.load_checkpointables(0)
self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar'])
test_utils.assert_tree_equal(
Expand All @@ -443,7 +447,16 @@ def test_custom_checkpointables(self):
self.assertEqual(checkpointables['foo'], loaded['foo'])
self.assertEqual(checkpointables['bar'], loaded['bar'])
with self.subTest('load_with_free_function'):
loaded = ocp.load_checkpointables(self.directory / '0')
if multihost.is_pathways_backend():
self.skipTest('Sharding metadata not present in Pathways.')
checkpointables_options = (
ocp.options.CheckpointablesOptions.create_with_handlers(
foo=handler_utils.FooHandler,
bar=handler_utils.BarHandler,
)
)
with ocp.Context(checkpointables_options=checkpointables_options):
loaded = ocp.load_checkpointables(self.directory / '0')
self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar'])
test_utils.assert_tree_equal(
self, checkpointables['pytree'], loaded['pytree']
Expand All @@ -462,6 +475,8 @@ def test_custom_checkpointables(self):
self.assertEqual(checkpointables['foo'], loaded['foo'])
self.assertEqual(checkpointables['bar'], loaded['bar'])
with self.subTest('load_with_abstract_checkpointables_none_values'):
if multihost.is_pathways_backend():
self.skipTest('Sharding metadata not present in Pathways.')
abstract_checkpointables = {
'pytree': None,
'foo': None,
Expand Down Expand Up @@ -683,6 +698,54 @@ def test_gcs_deletion_options(self):
)


def test_context_constructor_override(self):
  """Checks constructor-supplied context usage and its local override.

  A `Checkpointer` is built with a context that disables OCDBT and enables
  partial loading. The subtests verify that saves and loads performed through
  that checkpointer honor those options, and that a context entered around a
  later save takes precedence over the constructor-supplied one.
  """
  constructor_array_options = ocp.options.ArrayOptions(
      saving=ocp.options.ArrayOptions.Saving(use_ocdbt=False)
  )
  constructor_pytree_options = ocp.options.PyTreeOptions(
      loading=ocp.options.PyTreeOptions.Loading(partial_load=True)
  )
  constructor_context = ocp.Context(
      array_options=constructor_array_options,
      pytree_options=constructor_pytree_options,
  )
  checkpointer = Checkpointer(self.directory, context=constructor_context)
  self.enter_context(checkpointer)
  self.save_pytree(checkpointer, 0, self.pytree)

  with self.subTest('constructor_override_ocdbt'):
    # OCDBT is on by default, so a missing manifest shows that the
    # constructor-supplied context (use_ocdbt=False) was applied.
    pytree_dir = self.directory / '0' / 'pytree'
    manifest_present = (pytree_dir / 'manifest.ocdbt').exists()
    self.assertFalse(
        manifest_present,
        f'Expected NO manifest.ocdbt under {pytree_dir}',
    )

  with self.subTest('constructor_override_partial_load'):
    full_tree = checkpointer.load_pytree(0, self.abstract_pytree)
    test_utils.assert_tree_equal(self, self.pytree, full_tree)

    # Request a single key; this relies on partial_load=True from the
    # constructor-supplied context.
    subset_abstract = {'jax_array': self.abstract_pytree['jax_array']}
    subset_tree = checkpointer.load_pytree(0, subset_abstract)
    subset_expected = {'jax_array': self.pytree['jax_array']}
    test_utils.assert_tree_equal(self, subset_expected, subset_tree)

  with self.subTest('local_context_override'):
    # A locally entered context with use_ocdbt=True should take precedence
    # over the constructor-supplied one for this save.
    local_context = ocp.Context(
        array_options=ocp.options.ArrayOptions(
            saving=ocp.options.ArrayOptions.Saving(use_ocdbt=True)
        )
    )
    with local_context:
      self.save_pytree(checkpointer, 1, self.pytree)

    pytree_dir_1 = self.directory / '1' / 'pytree'
    manifest_present = (pytree_dir_1 / 'manifest.ocdbt').exists()
    self.assertTrue(
        manifest_present,
        f'Expected manifest.ocdbt under {pytree_dir_1}',
    )

@parameterized.named_parameters(
dict(
testcase_name='true',
Expand Down
Loading