From ce110631fd885b3ef8156cde9ff588e934f77236 Mon Sep 17 00:00:00 2001
From: Justin Pan
Date: Mon, 27 Apr 2026 13:34:52 -0700
Subject: [PATCH] Internal change.

PiperOrigin-RevId: 906519454
---
 .../_src/serialization/jax_array_handlers.py | 212 +++++++++++++++---
 1 file changed, 177 insertions(+), 35 deletions(-)

diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py
index d7cf764f4..3d95244b3 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py
@@ -49,8 +49,8 @@
 from orbax.checkpoint._src.serialization import types
 from orbax.checkpoint._src.serialization import worker_memory_utils
 from orbax.checkpoint._src.tree import utils as tree_utils
-import tensorstore as ts
 
+import tensorstore as ts
 
 Pytree: TypeAlias = Any
 ArrayRestoreArgs = jax_array_restore_args.ArrayRestoreArgs
@@ -200,26 +200,148 @@ async def _async_serialize_shardings(
   await sharding_metadata_txn.commit_async()
 
 
-def _get_replica_slices(
-    arrays: Sequence[jax.Array],
+def _get_replica_slices_with_global(
+    arr: jax.Array,
+    g_sharding: jax.sharding.Sharding,
+    g_shape: Tuple[int, ...],
     replica_id: int,
     use_replica_parallel: bool,
     min_slice_bytes_for_replica_parallel: int | None = None,
     max_replicas_for_replica_parallel: int | None = None,
-) -> Sequence[replica_slices.ReplicaSlices]:
-  """Returns ReplicaSlices for arrays."""
-  rslices_per_array = [
-      replica_slices.get_replica_slices(
+) -> replica_slices.ReplicaSlices:
+  """Returns ReplicaSlices for an array with global sharding and shape."""
+  replica_count = replica_slices._sharding_num_replicas(  # pylint: disable=protected-access
+      g_sharding, g_shape
+  )
+
+  shard0 = arr.addressable_shards[0]
+  axis = None
+  for axis_index, axis_size in enumerate(shard0.data.shape):
+    if axis_size % replica_count == 0:
+      axis = axis_index
+      break
+
+  if replica_count > 1 and axis is not None and use_replica_parallel:
+    local_shape = tuple(
+        axis_size // (replica_count if axis_index == axis else 1)
+        for axis_index, axis_size in enumerate(shard0.data.shape)
+    )
+    min_slice_bytes = min_slice_bytes_for_replica_parallel or 0
+    if np.prod(local_shape) * arr.dtype.itemsize >= min_slice_bytes:
+      rslices = []
+      for shard in arr.addressable_shards:
+        if shard.replica_id >= replica_count:
+          continue
+        size = local_shape[axis]
+        slize = shard.index[axis]
+        start = slize.start or 0
+
+        start_offset = shard.replica_id * size
+        end_offset = start_offset + size
+        new_slice = slice(start + start_offset, start + end_offset)
+        logging.info(
+            '[process=%d] _get_replica_slices: shard replica=%d, full'
+            ' shape=%s, full_array_bytes=%s, slice=%s,'
+            ' slice_target_bytes=%s',
+            multihost.process_index(),
+            shard.replica_id,
+            shard.data.shape,
+            shard.data.nbytes,
+            new_slice,
+            shard.data.nbytes // replica_count,
+        )
+
+        rslices.append(
+            replica_slices.ReplicaSlice(
+                index=shard.index[:axis]
+                + (new_slice,)
+                + shard.index[axis + 1 :],
+                unsliced_data=shard.data,
+                slice_args=replica_slices.SliceArgs(
+                    start_offset, end_offset, axis
+                ),
+            )
+        )
+      return replica_slices.ReplicaSlices(
+          global_shape=g_shape,
+          local_shape=local_shape,
+          sharding=g_sharding,
+          dtype=arr.dtype,
+          is_on_host=False,
+          replica_slices=rslices,
+      )
+    else:
+      return replica_slices.get_replica_slices(
           arr,
           replica_id,
-          use_replica_parallel,
+          False,
           min_slice_bytes_for_replica_parallel,
           max_replicas_for_replica_parallel,
       )
-      for arr in arrays
-  ]
+  else:
+    logging.info(
+        '[process=%d] _get_replica_slices standard: replica_count=%s,'
+        ' use_replica_parallel=False, array_bytes=%s',
+        multihost.process_index(),
+        replica_count,
+        arr.nbytes if hasattr(arr, 'nbytes') else 0,
+    )
+    return replica_slices.get_replica_slices(
+        arr,
+        replica_id,
+        False,
+        min_slice_bytes_for_replica_parallel,
+        max_replicas_for_replica_parallel,
+    )
+
+
+def _get_replica_slices(
+    arrays: Sequence[jax.Array],
+    replica_id: int,
+    use_replica_parallel: bool,
+    min_slice_bytes_for_replica_parallel: int | None = None,
+    max_replicas_for_replica_parallel: int | None = None,
+    global_shardings: Sequence[jax.sharding.Sharding] | None = None,
+    global_shapes: Sequence[Tuple[int, ...]] | None = None,
+) -> Sequence[replica_slices.ReplicaSlices]:
+  """Returns ReplicaSlices for arrays."""
+  rslices_per_array = []
+  for i, arr in enumerate(arrays):
+    if global_shardings and global_shapes:
+      rslices_per_array.append(
+          _get_replica_slices_with_global(
+              arr,
+              global_shardings[i],
+              global_shapes[i],
+              replica_id,
+              use_replica_parallel,
+              min_slice_bytes_for_replica_parallel,
+              max_replicas_for_replica_parallel,
+          )
+      )
+    else:
+      rslices_per_array.append(
+          replica_slices.get_replica_slices(
+              arr,
+              replica_id,
+              use_replica_parallel,
+              min_slice_bytes_for_replica_parallel,
+              max_replicas_for_replica_parallel,
+          )
+      )
+
   # D2H copy is performed automatically as part of dispatcher call, but
   # we must set properties correctly to pass later consistency checks.
+  def _get_sliced_host_data(rslice: replica_slices.ReplicaSlice) -> np.ndarray:
+    np_data = np.asarray(rslice.unsliced_data)
+    if rslice.slice_args is not None:
+      slicer = [slice(None)] * np_data.ndim
+      slicer[rslice.slice_args.axis] = slice(
+          rslice.slice_args.start_index, rslice.slice_args.limit_index
+      )
+      np_data = np_data[tuple(slicer)]
+    return np_data
+
   return [
       dataclasses.replace(
           rslices,
@@ -227,7 +349,7 @@ def _get_replica_slices(
           replica_slices=[
               dataclasses.replace(
                   rslice,
-                  unsliced_data=np.asarray(rslice.data()),
+                  unsliced_data=_get_sliced_host_data(rslice),
                   slice_args=None,
               )
               for rslice in rslices.replica_slices
@@ -250,6 +372,8 @@ def _worker_serialize_arrays(
     array_metadata_store: array_metadata_store_lib.Store | None,
     enable_replica_parallel_separate_folder: bool,
     ext_metadata: Dict[str, Any],
+    global_shardings: Sequence[jax.sharding.Sharding] | None = None,
+    global_shapes: Sequence[Tuple[int, ...]] | None = None,
 ):
   """Worker function to serialize arrays."""
   rslices_per_array = _get_replica_slices(
@@ -258,6 +382,8 @@ def _worker_serialize_arrays(
       use_replica_parallel,
       min_slice_bytes_for_replica_parallel,
       max_replicas_for_replica_parallel,
+      global_shardings=global_shardings,
+      global_shapes=global_shapes,
   )
 
   asyncio_utils.run_sync(
@@ -487,30 +613,45 @@ def _serialize_batch(
       batch_args: Sequence[types.SaveArgs],
       batch_arrays: Sequence[jax.Array],
   ):
-    ret = dispatcher.dispatch(
-        _worker_serialize_arrays,
-        input_arrays=batch_arrays,
-        func_kwargs={
-            'infos': batch_infos,
-            'args': batch_args,
-            'replica_id': replica_id,
-            'use_replica_parallel': use_replica_parallel,
-            'min_slice_bytes_for_replica_parallel': (
-                min_slice_bytes_for_replica_parallel
-            ),
-            'max_replicas_for_replica_parallel': (
-                max_replicas_for_replica_parallel
-            ),
-            'primary_host': primary_host,
-            'metadata_key': metadata_key,
-            'array_metadata_store': array_metadata_store,
-            'enable_replica_parallel_separate_folder': (
-                enable_replica_parallel_separate_folder
-            ),
-            'ext_metadata': ext_metadata,
-        },
-    )
-    jax.block_until_ready(ret)
+    groups = {}
+    for idx, arr in enumerate(batch_arrays):
+      d_set = frozenset(arr.devices())
+      if d_set not in groups:
+        groups[d_set] = []
+      groups[d_set].append(idx)
+
+    rets = []
+    for indices in groups.values():
+      sub_arrays = [batch_arrays[i] for i in indices]
+      sub_infos = [batch_infos[i] for i in indices]
+      sub_args = [batch_args[i] for i in indices]
+
+      rets.append(dispatcher.dispatch(
+          _worker_serialize_arrays,
+          input_arrays=sub_arrays,
+          func_kwargs={
+              'infos': sub_infos,
+              'args': sub_args,
+              'replica_id': replica_id,
+              'use_replica_parallel': use_replica_parallel,
+              'min_slice_bytes_for_replica_parallel': (
+                  min_slice_bytes_for_replica_parallel
+              ),
+              'max_replicas_for_replica_parallel': (
+                  max_replicas_for_replica_parallel
+              ),
+              'primary_host': primary_host,
+              'metadata_key': metadata_key,
+              'array_metadata_store': array_metadata_store,
+              'enable_replica_parallel_separate_folder': (
+                  enable_replica_parallel_separate_folder
+              ),
+              'ext_metadata': ext_metadata,
+              'global_shardings': [batch_arrays[i].sharding for i in indices],
+              'global_shapes': [batch_arrays[i].shape for i in indices],
+          },
+      ))
+    jax.block_until_ready(rets)
 
     # Enqueue D2H operation for prioritized values.
     if prioritized:
@@ -535,6 +676,7 @@ def _serialize_batch(
   )
 
   all_infos = infos
+
   async def _serialize():
     for info in all_infos:
       await info.await_path_creation()