diff --git a/checkpoint/orbax/checkpoint/options.py b/checkpoint/orbax/checkpoint/options.py index a3f636319..39e03d3a5 100644 --- a/checkpoint/orbax/checkpoint/options.py +++ b/checkpoint/orbax/checkpoint/options.py @@ -83,6 +83,12 @@ class MemoryLimitOptions: Attributes: max_transfer_concurrent_gb: The max memory limit in GB allowed for. Required if `save_device_host_concurrent_gb` is set to `"auto"`. + fallback_host_limit_gb: The fallback physical machine size in GB to use if + the profiler fails to fetch the total memory dynamically. + surge_fn: A function that takes the current step and returns the expected + memory surge in GB. """ max_transfer_concurrent_gb: int | None = None + fallback_host_limit_gb: int | None = None + surge_fn: Callable[[int], float] | None = None