From 2d5803bf2ba38b294e4e3737c7d0a9118b396e27 Mon Sep 17 00:00:00 2001 From: Justin Pan Date: Tue, 21 Apr 2026 13:57:43 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 903411066 --- .../_src/serialization/jax_array_handlers.py | 161 +++++++++++++----- .../_src/serialization/replica_slices.py | 55 +++++- .../replica_slices_multiprocess_test.py | 76 +++++++++ .../_src/serialization/replica_slices_test.py | 110 +++++++++++- .../_src/serialization/worker_memory_utils.py | 7 +- 5 files changed, 360 insertions(+), 49 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/_src/serialization/replica_slices_multiprocess_test.py diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index 7b75807ee..b9e385c39 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -49,8 +49,8 @@ from orbax.checkpoint._src.serialization import types from orbax.checkpoint._src.serialization import worker_memory_utils from orbax.checkpoint._src.tree import utils as tree_utils -import tensorstore as ts +import tensorstore as ts Pytree: TypeAlias = Any ArrayRestoreArgs = jax_array_restore_args.ArrayRestoreArgs @@ -188,18 +188,22 @@ def _get_replica_slices( use_replica_parallel: bool, min_slice_bytes_for_replica_parallel: int | None = None, max_replicas_for_replica_parallel: int | None = None, + precomputed_rslices: Sequence[replica_slices.ReplicaSlices] | None = None, ) -> Sequence[replica_slices.ReplicaSlices]: """Returns ReplicaSlices for arrays.""" - rslices_per_array = [ - replica_slices.get_replica_slices( - arr, - replica_id, - use_replica_parallel, - min_slice_bytes_for_replica_parallel, - max_replicas_for_replica_parallel, - ) - for arr in arrays - ] + if precomputed_rslices is not None: + rslices_per_array = precomputed_rslices + else: + rslices_per_array = [ + replica_slices.get_replica_slices( + arr, + replica_id, + use_replica_parallel, + min_slice_bytes_for_replica_parallel, + max_replicas_for_replica_parallel, + ) + for arr in arrays + ] # D2H copy is performed automatically as part of dispatcher call, but # we must set properties correctly to pass later consistency checks. 
return [ @@ -232,6 +236,7 @@ def _worker_serialize_arrays( array_metadata_store: array_metadata_store_lib.Store | None, enable_replica_parallel_separate_folder: bool, ext_metadata: Dict[str, Any], + precomputed_rslices: Sequence[replica_slices.ReplicaSlices] | None = None, ): """Worker function to serialize arrays.""" rslices_per_array = _get_replica_slices( @@ -240,6 +245,12 @@ def _worker_serialize_arrays( use_replica_parallel, min_slice_bytes_for_replica_parallel, max_replicas_for_replica_parallel, + precomputed_rslices=precomputed_rslices, + ) + logging.info( + 'use_replica_parallel: %s, rslices_per_array: %s', + use_replica_parallel, + rslices_per_array, ) asyncio_utils.run_sync( @@ -468,33 +479,55 @@ def _serialize_batch( batch_infos: Sequence[types.ParamInfo], batch_args: Sequence[types.SaveArgs], batch_arrays: Sequence[jax.Array], + precomputed_rslices: ( + Sequence[replica_slices.ReplicaSlices] | None + ) = None, ): - ret = dispatcher.dispatch( - _worker_serialize_arrays, - input_arrays=batch_arrays, - func_kwargs={ - 'infos': batch_infos, - 'args': batch_args, - 'replica_id': replica_id, - 'use_replica_parallel': use_replica_parallel, - 'min_slice_bytes_for_replica_parallel': ( - min_slice_bytes_for_replica_parallel - ), - 'max_replicas_for_replica_parallel': ( - max_replicas_for_replica_parallel - ), - 'primary_host': primary_host, - 'metadata_key': metadata_key, - 'array_metadata_store': array_metadata_store, - 'enable_replica_parallel_separate_folder': ( - enable_replica_parallel_separate_folder - ), - 'ext_metadata': ext_metadata, - }, - ) - jax.block_until_ready(ret) + groups = {} + for idx, arr in enumerate(batch_arrays): + d_set = frozenset(arr.devices()) + if d_set not in groups: + groups[d_set] = [] + groups[d_set].append(idx) + + for indices in groups.values(): + sub_arrays = [batch_arrays[i] for i in indices] + sub_infos = [batch_infos[i] for i in indices] + sub_args = [batch_args[i] for i in indices] + sub_rslices = ( + [precomputed_rslices[i] for i in indices] + if precomputed_rslices + else None + ) + + ret = dispatcher.dispatch( + _worker_serialize_arrays, + input_arrays=sub_arrays, + func_kwargs={ + 'infos': sub_infos, + 'args': sub_args, + 'replica_id': replica_id, + 'use_replica_parallel': use_replica_parallel, + 'min_slice_bytes_for_replica_parallel': ( + min_slice_bytes_for_replica_parallel + ), + 'max_replicas_for_replica_parallel': ( + max_replicas_for_replica_parallel + ), + 'primary_host': primary_host, + 'metadata_key': metadata_key, + 'array_metadata_store': array_metadata_store, + 'enable_replica_parallel_separate_folder': ( + enable_replica_parallel_separate_folder + ), + 'ext_metadata': ext_metadata, + 'precomputed_rslices': sub_rslices, + }, + ) + jax.block_until_ready(ret) # Enqueue D2H operation for prioritized values. 
+ prioritized_rslices = None if prioritized: logging.info( 'Scheduling D2H of %d prioritized jax.Array.', @@ -503,6 +536,17 @@ def _serialize_batch( prioritized_arrays, prioritized_infos, prioritized_args = zip( *prioritized ) + # Pre-filter to target replica's shards before D2H to avoid transferring + # replicated data + prioritized_arrays, prioritized_rslices = ( + replica_slices.filter_arrays_to_replica( + prioritized_arrays, + replica_id, + use_replica_parallel, + min_slice_bytes_for_replica_parallel, + max_replicas_for_replica_parallel, + ) + ) prioritized_arrays = dispatcher.device_to_host(prioritized_arrays) prioritized = [ (v, i, a) @@ -517,12 +561,18 @@ def _serialize_batch( ) all_infos = infos + async def _serialize(): for info in all_infos: await info.await_path_creation() if prioritized: arrays, infos, args = zip(*prioritized) - _serialize_batch(infos, args, arrays) + _serialize_batch( + infos, + args, + arrays, + precomputed_rslices=prioritized_rslices, + ) if deprioritized: assert device_host_max_bytes is not None for ( @@ -535,7 +585,20 @@ async def _serialize(): replica_id=replica_id, dispatcher=dispatcher, ): - _serialize_batch(b_infos, b_args, b_arrays) + b_arrays_list = list(b_arrays) + b_arrays_on_host, b_rslices = replica_slices.filter_arrays_to_replica( + b_arrays_list, + replica_id, + use_replica_parallel, + min_slice_bytes_for_replica_parallel, + max_replicas_for_replica_parallel, + ) + _serialize_batch( + b_infos, + b_args, + b_arrays_on_host, + precomputed_rslices=b_rslices, + ) return future.CommitFutureAwaitingContractedSignals( _serialize(), @@ -965,11 +1028,27 @@ def __init__( self._primary_host = primary_host self._replica_id = replica_id self._enable_write_sharding_file = enable_write_sharding_file - self._use_replica_parallel = ( - _get_default_use_replica_parallel() - if use_replica_parallel is None - else use_replica_parallel + if use_replica_parallel is None: + self._use_replica_parallel = _get_default_use_replica_parallel() + if self._use_replica_parallel and self._replica_id is None: + self._use_replica_parallel = False + logging.warning( + 'use_replica_parallel=True overridden to False because replica_id' + ' is None.' + ) + elif not self._use_replica_parallel: + logging.warning( + 'use_replica_parallel=None and running on GPU, defaulting to' + ' use_replica_parallel=False.' 
+        )
+    else:
+      self._use_replica_parallel = use_replica_parallel
+    logging.info(
+        'ArrayHandler initialized with'
+        f' use_replica_parallel={self._use_replica_parallel},'
+        f' dispatcher={type(dispatcher)}'
     )
+
     self._min_slice_bytes_for_replica_parallel = (
         min_slice_bytes_for_replica_parallel
     )
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py b/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
index 384ce40b0..c78c84859 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
@@ -18,7 +18,7 @@
 import dataclasses
 import functools
 import math
-from typing import Optional, Sequence
+from typing import cast, Optional, Sequence

 from absl import logging
 import jax
@@ -28,7 +28,6 @@
 from orbax.checkpoint._src.arrays import types
 from orbax.checkpoint._src.multihost import multihost
-

 Shape = types.Shape
 Index = types.Index
 OptionalAxisAndShapeAndReplicaCount = tuple[
@@ -388,6 +387,58 @@ def maybe_pick_replica_parallel() -> Optional[Result]:
     )


+def filter_arrays_to_replica(
+    arrays: Sequence[jax.Array],
+    replica_id: int,
+    use_replica_parallel: bool,
+    min_slice_bytes_for_replica_parallel: Optional[int] = None,
+    max_replicas_for_replica_parallel: Optional[int] = None,
+) -> tuple[Sequence[jax.Array], Sequence[ReplicaSlices]]:
+  """Pre-filters arrays so that only this replica's data gets serialized."""
+  filtered_arrays: list[jax.Array] = []
+  rslices_list: list[ReplicaSlices] = []
+
+  for arr in arrays:
+    rslices = get_replica_slices(
+        arr,
+        replica_id,
+        use_replica_parallel,
+        min_slice_bytes_for_replica_parallel,
+        max_replicas_for_replica_parallel,
+    )
+    rslices_list.append(rslices)
+    if not rslices.replica_slices:
+      filtered_arrays.append(arr)
+      continue
+
+    has_sub_slicing = any(
+        rs.slice_args is not None for rs in rslices.replica_slices
+    )
+    if not has_sub_slicing and len(rslices.replica_slices) == len(
+        arr.addressable_shards
+    ):
+      filtered_arrays.append(arr)
+      continue
+
+    shard_data_list: list[jax.Array] = []
+    for rslice in rslices.replica_slices:
+      shard_data_list.append(cast(jax.Array, rslice.data()))
+
+    devices = np.array([sd.device for sd in shard_data_list])
+    mesh = jax.sharding.Mesh(devices, ('d',))
+    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('d'))
+
+    total_elements = sum(sd.size for sd in shard_data_list)
+    compact_shape = (total_elements,)
+
+    flat_shards = [sd.reshape(-1) for sd in shard_data_list]
+    compact_arr = jax.make_array_from_single_device_arrays(
+        compact_shape, sharding, flat_shards
+    )
+    filtered_arrays.append(compact_arr)
+  return filtered_arrays, rslices_list
+
+
 def transfer_arrays_to_host(
     arrays: Sequence[jax.Array],
     replica_id: Optional[int],
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/replica_slices_multiprocess_test.py b/checkpoint/orbax/checkpoint/_src/serialization/replica_slices_multiprocess_test.py
new file mode 100644
index 000000000..b65b6e26d
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/serialization/replica_slices_multiprocess_test.py
@@ -0,0 +1,62 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multi-process test for replica_slices."""
+
+import jax
+from jax import sharding
+import numpy as np
+from orbax.checkpoint._src.serialization import replica_slices
+from orbax.checkpoint._src.testing import multiprocess_test
+
+PartitionSpec = sharding.PartitionSpec
+NamedSharding = sharding.NamedSharding
+
+
+def make_multi_device_array():
+  """Creates a replicated array across the multi-host TPU mesh."""
+  devices = np.array(jax.devices()).reshape((len(jax.devices()),))
+  mesh = jax.sharding.Mesh(devices, axis_names=('x',))
+  spec = PartitionSpec()
+  shape = (4096,)
+  arr = jax.make_array_from_callback(
+      shape,
+      NamedSharding(mesh, spec),
+      lambda idx: np.zeros(
+          tuple(len(range(*s.indices(shape[i]))) for i, s in enumerate(idx))
+      ),
+  )
+  return arr
+
+
+class ReplicaSlicesMultiProcessTest(multiprocess_test.MultiProcessTest):
+
+  def test_replica_parallel_sub_slicing(self):
+    arr = make_multi_device_array()
+
+    filtered_arrays, _ = replica_slices.filter_arrays_to_replica(
+        [arr],
+        replica_id=0,
+        use_replica_parallel=True,
+    )
+
+    filtered = filtered_arrays[0]
+    # Under multi-process execution, each host holds only a fraction of the
+    # shards in the mesh, so the filtered size on this host is strictly
+    # smaller than the original array (exactly half when run on two hosts).
+ self.assertEqual(filtered.size, arr.size / 2) + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/_src/serialization/replica_slices_test.py b/checkpoint/orbax/checkpoint/_src/serialization/replica_slices_test.py index d8df8ba44..9d5a97187 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/replica_slices_test.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/replica_slices_test.py @@ -118,7 +118,10 @@ def test_get_replica_slices_replica_parallel(self, shape, expected_axis): self.assertEqual(rslice.slice_args.axis, expected_axis) @parameterized.product( - shape=[2 * 768, 2 * 1024,], + shape=[ + 2 * 768, + 2 * 1024, + ], partitioned=[False, True], ) def test_get_replica_slices_pinned_host_passes(self, shape, partitioned): @@ -438,5 +441,110 @@ def test_nbytes_mixed_slicing(self): ) self.assertEqual(rslices.nbytes, expected_nbytes) + +class FilterArraysToReplicaTest(parameterized.TestCase): + + @parameterized.product(partitioned=[False, True]) + def test_replicated_array_is_filtered(self, partitioned): + """A replicated array should have fewer shards after filtering.""" + if jax.device_count() < 4: + self.skipTest('Not enough devices to test') + arr, num_partitions, num_replicas = make_multi_device_array( + (64, 64), + partitioned=partitioned, + ) + + if num_replicas <= 1: + self.skipTest('Test requires multiple replicas.') + + filtered_arrays, rslices_list = replica_slices.filter_arrays_to_replica( + [arr], + replica_id=0, + use_replica_parallel=False, + ) + self.assertLen(filtered_arrays, 1) + self.assertLen(rslices_list, 1) + + filtered = filtered_arrays[0] + rslices = rslices_list[0] + + self.assertLess( + len(filtered.addressable_shards), len(arr.addressable_shards) + ) + self.assertLen(rslices.replica_slices, num_partitions) + + def test_non_replicated_array_passes_through(self): + """A fully-sharded array should be returned unchanged.""" + num_devices = len(jax.devices()) + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',)) + spec = jax.sharding.PartitionSpec('x') + sharding = jax.sharding.NamedSharding(mesh, spec) + arr = jax.device_put( + jax.random.normal(jax.random.PRNGKey(0), (num_devices * 8,)), + sharding, + ) + + filtered_arrays, _ = replica_slices.filter_arrays_to_replica( + [arr], + replica_id=0, + use_replica_parallel=False, + ) + self.assertLen(filtered_arrays, 1) + self.assertIs(filtered_arrays[0], arr) + + def test_replica_parallel_sub_slicing(self): + """With replica-parallel, filtered array should contain sub-sliced data.""" + if jax.device_count() < 4: + self.skipTest('Not enough devices to test') + arr, _, num_replicas = make_multi_device_array( + (64, 64), + partitioned=False, + ) + if num_replicas <= 1: + self.skipTest('Test requires multiple replicas.') + + filtered_arrays, rslices_list = replica_slices.filter_arrays_to_replica( + [arr], + replica_id=0, + use_replica_parallel=True, + ) + self.assertLen(filtered_arrays, 1) + + filtered = filtered_arrays[0] + rslices = rslices_list[0] + + # In a single-process test environment, all replica shards are local, so the + # filtered size matches the original. In a multi-process environment, only + # the process-local subset is retained, making it smaller. 
+    if jax.process_count() > 1:
+      self.assertLess(filtered.size, arr.size)
+    else:
+      self.assertEqual(filtered.size, arr.size)
+    self.assertTrue(
+        any(rs.slice_args is not None for rs in rslices.replica_slices)
+    )
+
+  def test_metadata_preserved(self):
+    """ReplicaSlices metadata should have correct global_shape and sharding."""
+    if jax.device_count() < 4:
+      self.skipTest('Not enough devices to test')
+    arr, _, num_replicas = make_multi_device_array(
+        (64, 64),
+        partitioned=False,
+    )
+    if num_replicas <= 1:
+      self.skipTest('Test requires multiple replicas.')
+
+    _, rslices_list = replica_slices.filter_arrays_to_replica(
+        [arr],
+        replica_id=0,
+        use_replica_parallel=False,
+    )
+    rslices = rslices_list[0]
+    self.assertEqual(rslices.global_shape, arr.shape)
+    self.assertEqual(rslices.sharding, arr.sharding)
+    self.assertEqual(rslices.dtype, arr.dtype)
+
+
 if __name__ == '__main__':
   absltest.main()
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/worker_memory_utils.py b/checkpoint/orbax/checkpoint/_src/serialization/worker_memory_utils.py
index 99918b504..eb0d893c5 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/worker_memory_utils.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/worker_memory_utils.py
@@ -174,11 +174,8 @@ def next_memory_budgeted_batch(
       }
     else:
      device_to_worker_ids_map = _device_to_worker_ids(dispatcher)
-    # NOTE: We only transfer save shards with replica_id == replica_id, but we
-    # are actually redundantly transferring all shards, thanks to remote python
-    # / colcoated python. So we set replica_id to None, to estimate memory usage
-    # for all replicas.
-    replica_id = None
+    # NOTE: With pre-filtering, only the target replica's shards are
+    # transferred to host. Use the actual replica_id for accurate estimation.

   def _no_worker_memory_usage() -> dict[int, int]:
     return {id: 0 for id in set(device_to_worker_ids_map.values())}