diff --git a/checkpoint/orbax/checkpoint/_src/path/async_path.py b/checkpoint/orbax/checkpoint/_src/path/async_path.py index 7ba488e23..090995190 100644 --- a/checkpoint/orbax/checkpoint/_src/path/async_path.py +++ b/checkpoint/orbax/checkpoint/_src/path/async_path.py @@ -134,6 +134,10 @@ async def unlink(path: epath.Path, missing_ok: bool = False): return await asyncio.to_thread(path.unlink, missing_ok=missing_ok) +async def is_absolute(path: epath.Path): + return await asyncio.to_thread(path.is_absolute) + + class AsyncFile: """Async wrapper for file operations.""" diff --git a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py index 1add06df7..64b91cf3a 100644 --- a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py +++ b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py @@ -22,6 +22,7 @@ from absl import logging from etils import epath +from orbax.checkpoint._src.path import async_path from orbax.checkpoint._src.path import utils as ocp_path_utils @@ -69,11 +70,11 @@ def __init__( async def create_snapshot(self) -> None: """Creates a deep copy of the checkpoint.""" - if not await asyncio.to_thread(self._snapshot.is_absolute): + if not await async_path.is_absolute(self._snapshot): raise ValueError( f"Snapshot destination must be absolute, but was '{self._snapshot}'." ) - if not await asyncio.to_thread(self._source.exists): + if not await async_path.exists(self._source): raise ValueError(f"Snapshot source does not exist: {self._source}'.") t = ocp_path_utils.Timer() @@ -89,10 +90,10 @@ async def create_snapshot(self) -> None: async def release_snapshot(self) -> None: """Deletes a snapshot of the checkpoint.""" - if not await asyncio.to_thread(self._snapshot.exists): + if not await async_path.exists(self._snapshot): raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}") - await asyncio.to_thread(self._snapshot.rmtree) + await async_path.rmtree(self._snapshot) # TODO(b/434025182): Handle recovery path upon restart. async def replace_source(self) -> None: @@ -110,12 +111,9 @@ async def replace_source(self) -> None: self._source.parent / f"{self._source.name}._recovery_{time.time()}" ) - def _swap_source_and_snapshot(): - self._source.rename(recovery_path) - self._snapshot.rename(self._source) - recovery_path.rmtree() - - await asyncio.to_thread(_swap_source_and_snapshot) + await async_path.rename(self._source, recovery_path) + await async_path.rename(self._snapshot, self._source) + await async_path.rmtree(recovery_path) @@ -136,16 +134,16 @@ def __init__( self._snapshot = epath.Path(dst) async def create_snapshot(self) -> None: - if not await asyncio.to_thread(self._snapshot.is_absolute): + if not await async_path.is_absolute(self._snapshot): raise ValueError( f"Snapshot destination must be absolute, but was '{self._snapshot}'." ) - await asyncio.to_thread(self._snapshot.mkdir, parents=True, exist_ok=True) + await async_path.mkdir(self._snapshot, parents=True, exist_ok=True) async def release_snapshot(self) -> None: - if not await asyncio.to_thread(self._snapshot.exists): + if not await async_path.exists(self._snapshot): return - await asyncio.to_thread(self._snapshot.rmtree) + await async_path.rmtree(self._snapshot) async def replace_source(self) -> None: if not self._snapshot.is_absolute(): @@ -157,22 +155,19 @@ async def replace_source(self) -> None: f"Snapshot source must be absolute, but was '{self._source}'." ) - def _move_items_into_source(): - if not self._snapshot.exists(): - raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}") - if not self._source.exists(): - self._source.mkdir(parents=True, exist_ok=True) - # Move files from inside the tmp snapshot into the original source - # directory under a pending suffix. This is to avoid potentially wiping - # out previous files. - # Partial saving relies on this behavior to accumulate `pending_*` - # directories in the shared parent path before merging them sequentially - # upon finalization. - pending_suffix = f"{PENDING_DIR_SUFFIX}{uuid.uuid4()}" - dst_path = self._source / f"{self._source.name}{pending_suffix}" - self._snapshot.rename(dst_path) - - await asyncio.to_thread(_move_items_into_source) + if not await async_path.exists(self._snapshot): + raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}") + if not await async_path.exists(self._source): + await async_path.mkdir(self._source, parents=True, exist_ok=True) + # Move files from inside the tmp snapshot into the original source + # directory under a pending suffix. This is to avoid potentially wiping + # out previous files. + # Partial saving relies on this behavior to accumulate `pending_*` + # directories in the shared parent path before merging them sequentially + # upon finalization. + pending_suffix = f"{PENDING_DIR_SUFFIX}{uuid.uuid4()}" + dst_path = self._source / f"{self._source.name}{pending_suffix}" + await async_path.rename(self._snapshot, dst_path) def create_instance(