Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/async_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ async def unlink(path: epath.Path, missing_ok: bool = False):
return await asyncio.to_thread(path.unlink, missing_ok=missing_ok)


async def is_absolute(path: epath.Path):
  """Returns the result of `path.is_absolute()`, evaluated via a worker thread.

  Args:
    path: The path to inspect.

  Returns:
    Whatever `path.is_absolute()` returns.
  """
  check = path.is_absolute
  return await asyncio.to_thread(check)


class AsyncFile:
"""Async wrapper for file operations."""

Expand Down
55 changes: 25 additions & 30 deletions checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from absl import logging
from etils import epath
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import utils as ocp_path_utils


Expand Down Expand Up @@ -69,11 +70,11 @@ def __init__(

async def create_snapshot(self) -> None:
"""Creates a deep copy of the checkpoint."""
if not await asyncio.to_thread(self._snapshot.is_absolute):
if not await async_path.is_absolute(self._snapshot):
raise ValueError(
f"Snapshot destination must be absolute, but was '{self._snapshot}'."
)
if not await asyncio.to_thread(self._source.exists):
if not await async_path.exists(self._source):
raise ValueError(f"Snapshot source does not exist: {self._source}'.")

t = ocp_path_utils.Timer()
Expand All @@ -89,10 +90,10 @@ async def create_snapshot(self) -> None:

async def release_snapshot(self) -> None:
  """Deletes a snapshot of the checkpoint.

  Raises:
    FileNotFoundError: If the snapshot path does not exist.
  """
  # One existence check and one removal, both through the `async_path`
  # helpers. Issuing the removal a second time would target a path that
  # has already been deleted.
  if not await async_path.exists(self._snapshot):
    raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}")

  await async_path.rmtree(self._snapshot)

# TODO(b/434025182): Handle recovery path upon restart.
async def replace_source(self) -> None:
Expand All @@ -110,12 +111,9 @@ async def replace_source(self) -> None:
self._source.parent / f"{self._source.name}._recovery_{time.time()}"
)

def _swap_source_and_snapshot():
self._source.rename(recovery_path)
self._snapshot.rename(self._source)
recovery_path.rmtree()

await asyncio.to_thread(_swap_source_and_snapshot)
await async_path.rename(self._source, recovery_path)
await async_path.rename(self._snapshot, self._source)
await async_path.rmtree(recovery_path)



Expand All @@ -136,16 +134,16 @@ def __init__(
self._snapshot = epath.Path(dst)

async def create_snapshot(self) -> None:
  """Creates the snapshot directory, including any missing parents.

  An already-existing directory is accepted (`exist_ok=True`).

  Raises:
    ValueError: If the snapshot destination is not an absolute path.
  """
  # Validate once, then create once, both via the `async_path` helpers.
  if not await async_path.is_absolute(self._snapshot):
    raise ValueError(
        f"Snapshot destination must be absolute, but was '{self._snapshot}'."
    )
  await async_path.mkdir(self._snapshot, parents=True, exist_ok=True)

async def release_snapshot(self) -> None:
  """Removes the snapshot directory if present; otherwise does nothing."""
  # A missing snapshot is not an error here: return quietly.
  if not await async_path.exists(self._snapshot):
    return
  # Remove exactly once; a repeated removal would hit a missing path.
  await async_path.rmtree(self._snapshot)

async def replace_source(self) -> None:
if not self._snapshot.is_absolute():
Expand All @@ -157,22 +155,19 @@ async def replace_source(self) -> None:
f"Snapshot source must be absolute, but was '{self._source}'."
)

def _move_items_into_source():
if not self._snapshot.exists():
raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}")
if not self._source.exists():
self._source.mkdir(parents=True, exist_ok=True)
# Move files from inside the tmp snapshot into the original source
# directory under a pending suffix. This is to avoid potentially wiping
# out previous files.
# Partial saving relies on this behavior to accumulate `pending_*`
# directories in the shared parent path before merging them sequentially
# upon finalization.
pending_suffix = f"{PENDING_DIR_SUFFIX}{uuid.uuid4()}"
dst_path = self._source / f"{self._source.name}{pending_suffix}"
self._snapshot.rename(dst_path)

await asyncio.to_thread(_move_items_into_source)
if not await async_path.exists(self._snapshot):
raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}")
if not await async_path.exists(self._source):
await async_path.mkdir(self._source, parents=True, exist_ok=True)
# Move files from inside the tmp snapshot into the original source
# directory under a pending suffix. This is to avoid potentially wiping
# out previous files.
# Partial saving relies on this behavior to accumulate `pending_*`
# directories in the shared parent path before merging them sequentially
# upon finalization.
pending_suffix = f"{PENDING_DIR_SUFFIX}{uuid.uuid4()}"
dst_path = self._source / f"{self._source.name}{pending_suffix}"
await async_path.rename(self._snapshot, dst_path)


def create_instance(
Expand Down
Loading