diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 8929c6614..0b3647b5e 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -265,8 +265,6 @@ class CheckpointManagerOptions: deprecated, do not use. use `preservation_policy` instead. keep_period: deprecated, do not use. use `preservation_policy` instead. - should_keep_fn: - deprecated, do not use. use `preservation_policy` instead. best_fn: If set, maintains checkpoints based on the quality of given metrics rather than recency. The function should accept a PyTree of metrics, @@ -350,9 +348,9 @@ class CheckpointManagerOptions: preservation_policy: An object used to determine which checkpoints to preserve. If provided, overrides any other options dealing with this subject, including `max_to_keep`, `keep_time_interval`, `keep_period`, and - `should_keep_fn`, `best_fn`, and is the sole means of determining which - checkpoints to preserve. If not provided, these other options are used - instead. Prefer to use this option over others. + `best_fn`, and is the sole means of determining which checkpoints to + preserve. If not provided, these other options are used instead. Prefer to + use this option over others. prevent_write_metrics: False by default. If True, metrics will not be written. enable_should_save_is_saving_in_progress_check: True by default. If False, `should_save_fn` will not check `is_saving_in_progress`, and will assume @@ -375,7 +373,6 @@ class CheckpointManagerOptions: max_to_keep: Optional[int] = None keep_time_interval: Optional[datetime.timedelta] = None keep_period: Optional[int] = None - should_keep_fn: Optional[Callable[[int], bool]] = None best_fn: Optional[Callable[[PyTree], float]] = None best_mode: str = 'max' keep_checkpoints_without_metrics: bool = True @@ -464,10 +461,8 @@ def __post_init__(self): ) if self.read_only and self.keep_period is not None: self.keep_period = None - self.should_keep_fn = None logging.warning( - 'CheckpointManagerOptions.read_only=True, setting keep_period=None' - ' and should_keep_fn=None.' + 'CheckpointManagerOptions.read_only=True, setting keep_period=None.' ) if self.read_only and self.create: self.create = False @@ -512,13 +507,6 @@ def __post_init__(self): 'CheckpointManagerOptions.read_only=True, setting' ' should_save_fn=None.' ) - if self.preservation_policy is None and self.should_keep_fn is not None: - logging.warning( - 'CheckpointManagerOptions.should_keep_fn is set, setting' - ' keep_period=None (was %s).', - self.keep_period, - ) - self.keep_period = None self.save_on_steps = frozenset(self.save_on_steps or ()) @@ -1976,10 +1964,6 @@ def _get_old_steps_to_remove(self) -> List[int]: self._checkpoints, preservation_result ) if not should_preserve - and ( - self._options.should_keep_fn is None - or not self._options.should_keep_fn(checkpoint.step) - ) ] def _wait_for_checkpointers(self): diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager_options_test.py b/checkpoint/orbax/checkpoint/checkpoint_manager_options_test.py index b5f6f8a6f..5404f7177 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager_options_test.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager_options_test.py @@ -101,7 +101,6 @@ def test_side_effect_options_update_for_read_only(self, kwargs): self.assertEmpty(options.save_on_steps) self.assertIsNone(options.todelete_subdir) self.assertIsNone(options.should_save_fn) - self.assertIsNone(options.should_keep_fn) def test_replace_for_read_only(self): options = ocp.CheckpointManagerOptions( @@ -111,14 +110,6 @@ def test_replace_for_read_only(self): updated_options = dataclasses.replace(options, step_prefix='prefix') self.assertEmpty(updated_options.save_on_steps) - def test_replace_for_should_keep_fn(self): - options = ocp.CheckpointManagerOptions( - keep_period=1, - should_keep_fn=lambda step: True, - ) - self.assertIsNone(options.keep_period) - self.assertIsNotNone(options.should_keep_fn) - @parameterized.named_parameters( dict( testcase_name='error_step_name_format_false',