def to_gcsfuse_path(path: 'epath.PathLike') -> str:
  """Converts a GCS path to its gcsfuse path string.

  Paths below /gcs/ are served by the file API whenever gcsfuse is enabled,
  so a gs:// URI is rewritten onto that mount point. A path that already
  lives under /gcs/ is returned unchanged.

  Args:
    path: A GCS path, given either as a string or an epath.Path.

  Returns:
    A gcsfuse path string starting with /gcs/.

  Raises:
    ValueError: If `path` is not a GCS path.
  """
  path_text = str(path)
  if path_text.startswith('/gcs/'):
    # Already a gcsfuse path; nothing to rewrite.
    return path_text
  if path_text.startswith('gs://'):
    return '/gcs/' + path_text.removeprefix('gs://')
  raise ValueError(f'Path is not a GCS path: {path}')
async def load_single_host_gcsfuse(
    self, gcs_path_str: str, abstract_pytree: dict[str, Any] | None
) -> dict[str, np.ndarray]:
  """Downloads tensors from GCS via gcsfuse using parallel reads.

  Computes the bounding byte range that covers every requested tensor,
  splits that range into large chunks, and reads the chunks in parallel
  with `os.pread` on a thread pool to achieve high download bandwidth.
  Individual tensors are then sliced out of the combined buffer with a
  `memoryview` (zero-copy) before being wrapped as NumPy arrays.

  Args:
    gcs_path_str: The gcsfuse path (/gcs/...) to the safetensors file.
    abstract_pytree: A flat dictionary mapping tensor names to
      jax.ShapeDtypeStruct objects. Only tensors present in this dict will
      be loaded. If None, every tensor in the file is loaded.

  Returns:
    A dictionary mapping tensor names to loaded NumPy arrays.

  Raises:
    EOFError: If the file is truncated or reading fails unexpectedly.
    ValueError: If a requested tensor is missing from the file header, or
      if non-finite values are found in a loaded tensor.
  """
  header, data_start_offset = await self.read_header()

  # Resolve which tensors to load; "__metadata__" is a safetensors header
  # entry, not a tensor, so it is always excluded.
  if abstract_pytree is None:
    tensor_names = [n for n in header if n != "__metadata__"]
  else:
    tensor_names = [n for n in abstract_pytree if n != "__metadata__"]
    for name in tensor_names:
      if name not in header:
        raise ValueError(f"Tensor {name} not found in header.")

  if not tensor_names:
    # Nothing to read. (Also avoids float("inf") leaking into offset
    # arithmetic when the bounding box below would be empty.)
    return {}

  # 1. Bounding byte range [min_start, max_end) of all requested tensor
  # data, relative to the start of the data section. Reading one contiguous
  # range is faster than issuing a read per tensor.
  all_offsets = [header[name]["data_offsets"] for name in tensor_names]
  min_start = min(start for start, _ in all_offsets)
  max_end = max(end for _, end in all_offsets)

  read_base = data_start_offset + min_start
  total_length = max_end - min_start

  # Split the range into chunks so the reads can proceed in parallel.
  chunk_size = 1024 * 1024 * 1024  # 1 GiB per parallel read.
  chunks = []
  position = 0
  while position < total_length:
    size = min(chunk_size, total_length - position)
    chunks.append((size, read_base + position))
    position += size

  def _read_single_chunk(chunk_spec):
    """Fully reads one (size, absolute_offset) chunk, retrying short reads."""
    size, abs_offset = chunk_spec
    # Each worker opens its own handle; os.pread is positional, so no seek
    # state is shared between threads.
    with open(gcs_path_str, "rb") as f:
      pieces = []
      done = 0
      while done < size:
        piece = os.pread(f.fileno(), size - done, abs_offset + done)
        if not piece:
          raise EOFError(
              f"Unexpected end of file at offset {abs_offset + done} "
              f"in file {gcs_path_str}. Expected {size} bytes, "
              f"got {done}."
          )
        pieces.append(piece)
        done += len(piece)
      return b"".join(pieces)

  # 2. Execute the parallel reads.
  max_workers = 16
  with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
    data_bytes = b"".join(executor.map(_read_single_chunk, chunks))

  # 3. Slice each tensor out of the combined buffer without copying.
  data_mv = memoryview(data_bytes)
  tensors = {}
  for name in tensor_names:
    shape, dtype = _get_array_properties(header[name])
    start_offset, end_offset = header[name]["data_offsets"]
    tensor_bytes = data_mv[start_offset - min_start : end_offset - min_start]
    np_array = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
    # Guard against corrupted downloads surfacing as NaN/Inf values.
    if not np.isfinite(np_array).all():
      raise ValueError(f"Non-finite values found in tensor {name}.")
    tensors[name] = np_array
  return tensors