Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 4 additions & 20 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,6 @@ class CheckpointManagerOptions:
deprecated, do not use. use `preservation_policy` instead.
keep_period:
deprecated, do not use. use `preservation_policy` instead.
should_keep_fn:
deprecated, do not use. use `preservation_policy` instead.
best_fn:
If set, maintains checkpoints based on the quality of given
metrics rather than recency. The function should accept a PyTree of metrics,
Expand Down Expand Up @@ -350,9 +348,9 @@ class CheckpointManagerOptions:
preservation_policy: An object used to determine which checkpoints to
preserve. If provided, overrides any other options dealing with this
subject, including `max_to_keep`, `keep_time_interval`, `keep_period`, and
`should_keep_fn`, `best_fn`, and is the sole means of determining which
checkpoints to preserve. If not provided, these other options are used
instead. Prefer to use this option over others.
`best_fn`, and is the sole means of determining which checkpoints to
preserve. If not provided, these other options are used instead. Prefer to
use this option over others.
prevent_write_metrics: False by default. If True, metrics will not be written.
enable_should_save_is_saving_in_progress_check: True by default. If False,
`should_save_fn` will not check `is_saving_in_progress`, and will assume
Expand All @@ -375,7 +373,6 @@ class CheckpointManagerOptions:
max_to_keep: Optional[int] = None
keep_time_interval: Optional[datetime.timedelta] = None
keep_period: Optional[int] = None
should_keep_fn: Optional[Callable[[int], bool]] = None
best_fn: Optional[Callable[[PyTree], float]] = None
best_mode: str = 'max'
keep_checkpoints_without_metrics: bool = True
Expand Down Expand Up @@ -464,10 +461,8 @@ def __post_init__(self):
)
if self.read_only and self.keep_period is not None:
self.keep_period = None
self.should_keep_fn = None
logging.warning(
'CheckpointManagerOptions.read_only=True, setting keep_period=None'
' and should_keep_fn=None.'
'CheckpointManagerOptions.read_only=True, setting keep_period=None.'
)
if self.read_only and self.create:
self.create = False
Expand Down Expand Up @@ -512,13 +507,6 @@ def __post_init__(self):
'CheckpointManagerOptions.read_only=True, setting'
' should_save_fn=None.'
)
if self.preservation_policy is None and self.should_keep_fn is not None:
logging.warning(
'CheckpointManagerOptions.should_keep_fn is set, setting'
' keep_period=None (was %s).',
self.keep_period,
)
self.keep_period = None
self.save_on_steps = frozenset(self.save_on_steps or ())


Expand Down Expand Up @@ -1976,10 +1964,6 @@ def _get_old_steps_to_remove(self) -> List[int]:
self._checkpoints, preservation_result
)
if not should_preserve
and (
self._options.should_keep_fn is None
or not self._options.should_keep_fn(checkpoint.step)
)
]

def _wait_for_checkpointers(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def test_side_effect_options_update_for_read_only(self, kwargs):
self.assertEmpty(options.save_on_steps)
self.assertIsNone(options.todelete_subdir)
self.assertIsNone(options.should_save_fn)
self.assertIsNone(options.should_keep_fn)

def test_replace_for_read_only(self):
options = ocp.CheckpointManagerOptions(
Expand All @@ -111,14 +110,6 @@ def test_replace_for_read_only(self):
updated_options = dataclasses.replace(options, step_prefix='prefix')
self.assertEmpty(updated_options.save_on_steps)

def test_replace_for_should_keep_fn(self):
options = ocp.CheckpointManagerOptions(
keep_period=1,
should_keep_fn=lambda step: True,
)
self.assertIsNone(options.keep_period)
self.assertIsNotNone(options.should_keep_fn)

@parameterized.named_parameters(
dict(
testcase_name='error_step_name_format_false',
Expand Down
Loading