212 changes: 177 additions & 35 deletions checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py
@@ -49,8 +49,8 @@
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.serialization import worker_memory_utils
from orbax.checkpoint._src.tree import utils as tree_utils

import tensorstore as ts

Pytree: TypeAlias = Any
ArrayRestoreArgs = jax_array_restore_args.ArrayRestoreArgs
@@ -200,34 +200,156 @@ async def _async_serialize_shardings(
  await sharding_metadata_txn.commit_async()


def _get_replica_slices_with_global(
    arr: jax.Array,
    g_sharding: jax.sharding.Sharding,
    g_shape: Tuple[int, ...],
    replica_id: int,
    use_replica_parallel: bool,
    min_slice_bytes_for_replica_parallel: int | None = None,
    max_replicas_for_replica_parallel: int | None = None,
) -> replica_slices.ReplicaSlices:
  """Returns ReplicaSlices for an array, given its global sharding and shape."""
  replica_count = replica_slices._sharding_num_replicas(  # pylint: disable=protected-access
      g_sharding, g_shape
  )

  # Pick the first local axis whose size divides evenly by the replica count;
  # each replica will write a distinct sub-slice of the shard along this axis.
  shard0 = arr.addressable_shards[0]
  axis = None
  for axis_index, axis_size in enumerate(shard0.data.shape):
    if axis_size % replica_count == 0:
      axis = axis_index
      break

  if replica_count > 1 and axis is not None and use_replica_parallel:
    local_shape = tuple(
        axis_size // (replica_count if axis_index == axis else 1)
        for axis_index, axis_size in enumerate(shard0.data.shape)
    )
    min_slice_bytes = min_slice_bytes_for_replica_parallel or 0
    if np.prod(local_shape) * arr.dtype.itemsize >= min_slice_bytes:
      rslices = []
      for shard in arr.addressable_shards:
        if shard.replica_id >= replica_count:
          continue
        size = local_shape[axis]
        axis_slice = shard.index[axis]
        start = axis_slice.start or 0

        # Each replica is assigned a disjoint sub-slice of the shard.
        start_offset = shard.replica_id * size
        end_offset = start_offset + size
        new_slice = slice(start + start_offset, start + end_offset)
        logging.info(
            '[process=%d] _get_replica_slices_with_global: shard replica=%d,'
            ' full shape=%s, full_array_bytes=%s, slice=%s,'
            ' slice_target_bytes=%s',
            multihost.process_index(),
            shard.replica_id,
            shard.data.shape,
            shard.data.nbytes,
            new_slice,
            shard.data.nbytes // replica_count,
        )

        rslices.append(
            replica_slices.ReplicaSlice(
                index=shard.index[:axis]
                + (new_slice,)
                + shard.index[axis + 1 :],
                unsliced_data=shard.data,
                slice_args=replica_slices.SliceArgs(
                    start_offset, end_offset, axis
                ),
            )
        )
      return replica_slices.ReplicaSlices(
          global_shape=g_shape,
          local_shape=local_shape,
          sharding=g_sharding,
          dtype=arr.dtype,
          is_on_host=False,
          replica_slices=rslices,
      )
    else:
      # Per-replica slices would fall below the minimum size; fall back to
      # the standard path.
      return replica_slices.get_replica_slices(
          arr,
          replica_id,
          use_replica_parallel,
          False,
          min_slice_bytes_for_replica_parallel,
          max_replicas_for_replica_parallel,
      )
  else:
    logging.info(
        '[process=%d] _get_replica_slices_with_global standard:'
        ' replica_count=%s, use_replica_parallel=False, array_bytes=%s',
        multihost.process_index(),
        replica_count,
        arr.nbytes if hasattr(arr, 'nbytes') else 0,
    )
    return replica_slices.get_replica_slices(
        arr,
        replica_id,
        False,
        min_slice_bytes_for_replica_parallel,
        max_replicas_for_replica_parallel,
    )
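
A standalone sketch (not part of the diff) of the replica-parallel arithmetic above: assuming a shard of shape (8, 4) replicated across two replicas, with axis 0 the first axis that divides evenly, each replica writes a disjoint slice of the shard.

import numpy as np

replica_count = 2
shard_shape = (8, 4)
axis = 0  # first axis whose size divides evenly by replica_count
size = shard_shape[axis] // replica_count  # rows per replica
for rid in range(replica_count):
  start_offset = rid * size
  end_offset = start_offset + size
  # replica 0 writes [0:4); replica 1 writes [4:8)
  print(f'replica {rid} writes slice [{start_offset}:{end_offset}) on axis {axis}')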


def _get_replica_slices(
    arrays: Sequence[jax.Array],
    replica_id: int,
    use_replica_parallel: bool,
    min_slice_bytes_for_replica_parallel: int | None = None,
    max_replicas_for_replica_parallel: int | None = None,
    global_shardings: Sequence[jax.sharding.Sharding] | None = None,
    global_shapes: Sequence[Tuple[int, ...]] | None = None,
) -> Sequence[replica_slices.ReplicaSlices]:
  """Returns ReplicaSlices for arrays."""
  rslices_per_array = []
  for i, arr in enumerate(arrays):
    if global_shardings and global_shapes:
      rslices_per_array.append(
          _get_replica_slices_with_global(
              arr,
              global_shardings[i],
              global_shapes[i],
              replica_id,
              use_replica_parallel,
              min_slice_bytes_for_replica_parallel,
              max_replicas_for_replica_parallel,
          )
      )
    else:
      rslices_per_array.append(
          replica_slices.get_replica_slices(
              arr,
              replica_id,
              use_replica_parallel,
              min_slice_bytes_for_replica_parallel,
              max_replicas_for_replica_parallel,
          )
      )

  # The D2H copy is performed automatically as part of the dispatcher call,
  # but we must set properties correctly to pass later consistency checks.
  def _get_sliced_host_data(rslice: replica_slices.ReplicaSlice) -> np.ndarray:
    np_data = np.asarray(rslice.unsliced_data)
    if rslice.slice_args is not None:
      # Keep only this replica's sub-slice along the replica-parallel axis.
      slicer = [slice(None)] * np_data.ndim
      slicer[rslice.slice_args.axis] = slice(
          rslice.slice_args.start_index, rslice.slice_args.limit_index
      )
      np_data = np_data[tuple(slicer)]
    return np_data

  return [
      dataclasses.replace(
          rslices,
          is_on_host=True,
          replica_slices=[
              dataclasses.replace(
                  rslice,
                  unsliced_data=_get_sliced_host_data(rslice),
                  slice_args=None,
              )
              for rslice in rslices.replica_slices
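
A minimal sketch (not part of the diff) of the host-side slicing performed by _get_sliced_host_data, assuming slice_args carries start_index=4, limit_index=8, axis=0 for an (8, 4) shard:

import numpy as np

np_data = np.arange(32).reshape(8, 4)  # full shard, already on host
start_index, limit_index, axis = 4, 8, 0
slicer = [slice(None)] * np_data.ndim
slicer[axis] = slice(start_index, limit_index)
sub = np_data[tuple(slicer)]  # keeps only this replica's rows
assert sub.shape == (4, 4)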
@@ -250,6 +372,8 @@ def _worker_serialize_arrays(
    array_metadata_store: array_metadata_store_lib.Store | None,
    enable_replica_parallel_separate_folder: bool,
    ext_metadata: Dict[str, Any],
    global_shardings: Sequence[jax.sharding.Sharding] | None = None,
    global_shapes: Sequence[Tuple[int, ...]] | None = None,
):
  """Worker function to serialize arrays."""
  rslices_per_array = _get_replica_slices(
@@ -258,6 +382,8 @@
      use_replica_parallel,
      min_slice_bytes_for_replica_parallel,
      max_replicas_for_replica_parallel,
      global_shardings=global_shardings,
      global_shapes=global_shapes,
  )

  asyncio_utils.run_sync(
@@ -487,30 +613,45 @@ def _serialize_batch(
    batch_args: Sequence[types.SaveArgs],
    batch_arrays: Sequence[jax.Array],
):
  # Group arrays by the set of devices they live on; arrays on different
  # device sets must be dispatched in separate worker calls.
  groups = {}
  for idx, arr in enumerate(batch_arrays):
    d_set = frozenset(arr.devices())
    if d_set not in groups:
      groups[d_set] = []
    groups[d_set].append(idx)

  rets = []
  for indices in groups.values():
    sub_arrays = [batch_arrays[i] for i in indices]
    sub_infos = [batch_infos[i] for i in indices]
    sub_args = [batch_args[i] for i in indices]

    rets.append(
        dispatcher.dispatch(
            _worker_serialize_arrays,
            input_arrays=sub_arrays,
            func_kwargs={
                'infos': sub_infos,
                'args': sub_args,
                'replica_id': replica_id,
                'use_replica_parallel': use_replica_parallel,
                'min_slice_bytes_for_replica_parallel': (
                    min_slice_bytes_for_replica_parallel
                ),
                'max_replicas_for_replica_parallel': (
                    max_replicas_for_replica_parallel
                ),
                'primary_host': primary_host,
                'metadata_key': metadata_key,
                'array_metadata_store': array_metadata_store,
                'enable_replica_parallel_separate_folder': (
                    enable_replica_parallel_separate_folder
                ),
                'ext_metadata': ext_metadata,
                'global_shardings': [batch_arrays[i].sharding for i in indices],
                'global_shapes': [batch_arrays[i].shape for i in indices],
            },
        )
    )
  jax.block_until_ready(rets)
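
  # Illustration (hypothetical values, not part of the diff): with arrays A
  # and B placed on devices {d0, d1} and C on devices {d2, d3}, the grouping
  # above yields
  #   groups = {frozenset({d0, d1}): [0, 1], frozenset({d2, d3}): [2]}
  # so A and B share one dispatcher call while C gets its own.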

  # Enqueue D2H operation for prioritized values.
  if prioritized:
@@ -535,6 +676,7 @@ def _serialize_batch(
    )

  all_infos = infos

  async def _serialize():
    for info in all_infos:
      await info.await_path_creation()