24 changes: 24 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/gcs_utils.py
@@ -28,6 +28,30 @@ def is_gcs_path(path: epath.Path) -> bool:
return path.as_posix().startswith(_GCS_PATH_PREFIX)


def to_gcsfuse_path(path: epath.PathLike) -> str:
"""Converts a GCS path to a gcsfuse path string.

  GCSfuse paths start with /gcs/ and are accessible through the standard file
  API when gcsfuse is enabled.

Args:
path: A GCS path which can be a string or epath.Path.

Returns:
A gcsfuse path string starting with /gcs/.

Raises:
ValueError: If path is not a GCS path.
"""
path_str = str(path)
if path_str.startswith('gs://'):
return '/gcs/' + path_str[5:]
elif path_str.startswith('/gcs/'):
return path_str
else:
raise ValueError(f'Path is not a GCS path: {path}')
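
For illustration, a quick sketch of how this helper behaves (the bucket and object names are made up):

to_gcsfuse_path('gs://my-bucket/ckpts/model.safetensors')
# -> '/gcs/my-bucket/ckpts/model.safetensors'
to_gcsfuse_path('/gcs/my-bucket/ckpts/model.safetensors')  # already converted
# -> '/gcs/my-bucket/ckpts/model.safetensors'
to_gcsfuse_path('/tmp/model.safetensors')
# -> raises ValueError: Path is not a GCS path: /tmp/model.safetensors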


def parse_gcs_path(path: epath.PathLike) -> tuple[str, str]:
parsed = parse.urlparse(str(path))
assert parsed.scheme == 'gs', f'Unsupported scheme for GCS: {parsed.scheme}'
@@ -15,8 +15,10 @@
"""Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats."""

import asyncio
from concurrent import futures
import dataclasses
import json
import os
import time
from typing import Any, Awaitable, cast

@@ -26,6 +28,7 @@
import numpy as np
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.tree import utils as tree_utils
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
@@ -438,8 +441,112 @@ async def _read_bundle(
bundle_start_offset = 0
return bundle_bytes, bundle_start_offset

async def load_single_host_gcsfuse(
self, gcs_path_str: str, abstract_pytree: dict[str, Any] | None
) -> dict[str, np.ndarray]:
"""Downloads tensors from Google Cloud Storage using high-bandwidth parallel reads.

This method uses `os.pread` with a thread pool to achieve high-bandwidth
parallel downloads from GCS via gcsfuse. It first calculates the bounding
box of the required tensor data and then reads chunks within that range.

Args:
gcs_path_str: The gcsfuse path to the safetensors file.
      abstract_pytree: A flat dictionary mapping tensor names to
        jax.ShapeDtypeStruct objects. Only tensors present in this dict are
        loaded; if None, every tensor in the header is loaded.

Returns:
A dictionary mapping tensor names to loaded NumPy arrays.

Raises:
      EOFError: If the file is truncated or a read fails unexpectedly.
      ValueError: If a requested tensor is missing from the header, or if a
        loaded tensor contains non-finite values.
"""

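    # The safetensors header maps each tensor name to its dtype, shape, and
    # byte offsets relative to the start of the data section.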
header, data_start_offset = await self.read_header()
tensors = {}

min_start = float("inf")
max_end = 0

    # 1. Compute the exact bounding box of the requested tensor data.
if abstract_pytree is None:
tensor_names = header.keys()
else:
tensor_names = abstract_pytree.keys()
for t_name in tensor_names:
if t_name == "__metadata__":
continue
if t_name not in header:
# Raise an error if the tensor is not found in the header.
raise ValueError(f"Tensor {t_name} not found in header.")
start, end = header[t_name]["data_offsets"]
if start < min_start:
min_start = start
if end > max_end:
max_end = end

offset = data_start_offset + min_start
length = max_end - min_start
chunks = []
bytes_read = 0
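    # Split the [min_start, max_end) byte range into 1 GiB chunks so the
    # reads can run in parallel.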
chunk_size = 1024 * 1024 * 1024
while bytes_read < length:
current_chunk_size = min(chunk_size, length - bytes_read)
current_offset = offset + bytes_read
chunks.append((current_chunk_size, current_offset))
bytes_read += current_chunk_size

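    # Each worker re-opens the file and reads with os.pread, which is safe to
    # call concurrently because it never touches a shared file offset.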
def _read_single_chunk(chunk_data):
chunk_size, offset = chunk_data
with open(gcs_path_str, "rb") as f:
bytes_read = 0
chunk_pieces = []
while bytes_read < chunk_size:
piece = os.pread(
f.fileno(), chunk_size - bytes_read, offset + bytes_read
)
if not piece:
raise EOFError(
f"Unexpected end of file at offset {offset + bytes_read} "
f"in file {gcs_path_str}. Expected {chunk_size} bytes, "
f"got {bytes_read}."
)
chunk_pieces.append(piece)
bytes_read += len(piece)
return b"".join(chunk_pieces)

    # 2. Execute the parallel reads across a thread pool.
    max_workers = 16
    with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
      read_chunks = list(executor.map(_read_single_chunk, chunks))

data_bytes = b"".join(read_chunks)
data_mv = memoryview(data_bytes)
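    # 3. Slice each tensor out of the shared buffer. np.frombuffer returns a
    # read-only view over the bytes, so no per-tensor copy is made here.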
for name in tensor_names:
if name == "__metadata__":
continue
shape, dtype = _get_array_properties(header[name])
start_offset, end_offset = header[name]["data_offsets"]
tensor_bytes = data_mv[
start_offset - min_start : end_offset - min_start
]
np_array = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
if not np.isfinite(np_array).all():
raise ValueError(f"Non-finite values found in tensor {name}.")
tensors[name] = np_array
return tensors
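
For illustration, a minimal sketch of the bounding-box and chunk planning above, with made-up offsets and a tiny chunk size (the real code uses 1 GiB chunks). Note that the gap between the two tensors is fetched too, since a single contiguous range is read:

data_start_offset = 128                 # hypothetical end of the JSON header
offsets = [(0, 16), (24, 40)]           # data_offsets of the requested tensors
min_start = min(start for start, _ in offsets)  # 0
max_end = max(end for _, end in offsets)        # 40
offset, length = data_start_offset + min_start, max_end - min_start
chunk_size = 10
chunks = [(min(chunk_size, length - b), offset + b)
          for b in range(0, length, chunk_size)]
# chunks == [(10, 128), (10, 138), (10, 148), (10, 158)]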

async def load_single_host(
self, abstract_pytree: dict[str, Any] | None
) -> dict[str, np.ndarray]:
"""Loads tensors from a safetensors file into host NumPy arrays."""
if gcs_utils.is_gcs_path(self.path):
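      # Fast path: read through the gcsfuse mount with parallel os.pread.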
gcs_path_str = gcs_utils.to_gcsfuse_path(self.path)
return await self.load_single_host_gcsfuse(gcs_path_str, abstract_pytree)
header, data_start_offset = await self.read_header()
tensors = {}
async with async_path.open_file(self.path, mode="rb") as f:
@@ -585,7 +692,7 @@ async def _load_single_host(self, abstract_pytree: dict[str, Any]) -> Any:
start = time.time()
load_ops = []
for loader in await self._get_loaders():
load_ops.append(loader.load_single_host(abstract_pytree))

restored_pytree = {}
for file_tensors in await asyncio.gather(*load_ops):
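
For context, a minimal runnable sketch of this asyncio fan-out; FakeLoader is a hypothetical stand-in for the per-file loader objects returned by _get_loaders():

import asyncio

class FakeLoader:
  def __init__(self, tensors):
    self._tensors = tensors

  async def load_single_host(self, abstract_pytree):
    # The real loaders read bytes from disk or GCS and build np.ndarrays.
    return self._tensors

async def main():
  loaders = [FakeLoader({'layer0/w': 0}), FakeLoader({'layer1/w': 1})]
  load_ops = [loader.load_single_host(None) for loader in loaders]
  restored_pytree = {}
  for file_tensors in await asyncio.gather(*load_ops):
    restored_pytree.update(file_tensors)
  return restored_pytree  # {'layer0/w': 0, 'layer1/w': 1}

print(asyncio.run(main()))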