From 0d54e77a1614c4dea5191551bff7aad47065bcf0 Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Mon, 20 Apr 2026 12:03:23 -0700 Subject: [PATCH] Not meant to be checked in. PiperOrigin-RevId: 902754830 --- .../_src/multihost/colocated_transport.py | 57 ++++++-- .../checkpoint/_src/multihost/dispatchers.py | 55 +++++++- .../checkpoint/_src/multihost/multihost.py | 24 ++-- .../checkpoint/_src/multihost/multislice.py | 6 +- .../checkpoint/_src/multihost/pathways.py | 3 +- .../_src/serialization/jax_array_handlers.py | 124 ++++++++++++------ .../_src/serialization/ocdbt_utils.py | 20 ++- .../_src/serialization/serialization.py | 8 +- .../_src/serialization/tensorstore_utils.py | 2 +- .../_src/serialization/type_handlers.py | 24 ++-- .../replica_parallel/llama-70b-v5p-64-pw.yaml | 40 ++++++ .../_src/testing/benchmarks/run_benchmarks.py | 1 + .../_src/testing/benchmarks/xpk/Dockerfile | 2 +- .../_src/testing/benchmarks/xpk/launch_xpk.py | 5 +- 14 files changed, 261 insertions(+), 110 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/replica_parallel/llama-70b-v5p-64-pw.yaml diff --git a/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py b/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py index 5c486d5b4..605d6d2e2 100644 --- a/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py +++ b/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py @@ -252,7 +252,11 @@ def unique_colocated_cpu_devices( devices: Sequence[jax.Device], ) -> tuple[jax.Device, ...]: """Returns one colocated CPU device per worker.""" + logging.info('unique_colocated_cpu_devices: input devices=%s', devices) all_cpu = tuple(cp.colocated_cpu_devices(tuple(devices))) + logging.info( + 'unique_colocated_cpu_devices: colocated_cpu_devices returned=%s', all_cpu + ) unique_cpu = [] seen_ids = set() for device in all_cpu: @@ -260,6 +264,7 @@ def unique_colocated_cpu_devices( continue seen_ids.add(device.id) unique_cpu.append(device) + logging.info('unique_colocated_cpu_devices: unique_cpu=%s', unique_cpu) return tuple(unique_cpu) @@ -267,16 +272,26 @@ def colocated_cpu_sharding( sharding: jax.sharding.Sharding, ) -> jax.sharding.Sharding: """Returns a CPU sharding colocated with the given sharding.""" + logging.info('colocated_cpu_sharding: input sharding=%s', sharding) if isinstance(sharding, jax.sharding.SingleDeviceSharding): cpu_devices = cp.colocated_cpu_devices(list(sharding.device_set)) - return jax.sharding.SingleDeviceSharding( + result = jax.sharding.SingleDeviceSharding( cpu_devices[0], memory_kind=sharding.memory_kind ) + logging.info( + 'colocated_cpu_sharding: returning SingleDeviceSharding=%s', result + ) + return result if isinstance(sharding, jax.sharding.NamedSharding): cpu_mesh = cp.colocated_cpu_devices(sharding.mesh) - return jax.sharding.NamedSharding( + result = jax.sharding.NamedSharding( cpu_mesh, sharding.spec, memory_kind=sharding.memory_kind ) + logging.info('colocated_cpu_sharding: returning NamedSharding=%s', result) + return result + logging.error( + 'colocated_cpu_sharding: unsupported sharding type=%s', type(sharding) + ) raise TypeError( f'Sharding type {type(sharding)} not supported in to_colocated_python.' ) @@ -284,12 +299,15 @@ def colocated_cpu_sharding( def to_colocated_python(input_tree: PyTree) -> PyTree: """Copies a pytree of arrays to colocated CPU devices.""" + logging.info( + 'to_colocated_python: starting with tree structure=%s', + jax.tree.structure(input_tree), + ) def _get_sharding(x: Any) -> jax.sharding.Sharding | None: if isinstance(x, jax.Array): cpu_sharding = colocated_cpu_sharding(x.sharding) - logging.vlog( - 1, + logging.info( 'Staging array from %s to colocated CPU sharding %s', x.sharding, cpu_sharding, @@ -298,7 +316,9 @@ def _get_sharding(x: Any) -> jax.sharding.Sharding | None: return None cpu_sharding_tree = jax.tree.map(_get_sharding, input_tree) - return jax.device_put(input_tree, cpu_sharding_tree, may_alias=True) + result = jax.device_put(input_tree, cpu_sharding_tree, may_alias=True) + logging.info('to_colocated_python: finished device_put') + return result def make_scalar_array_like( @@ -332,8 +352,7 @@ def convert_array_restore_args( """Converts ArrayRestoreArgs to use colocated CPU devices.""" if restore_args.mesh is not None: cpu_mesh = cp.colocated_cpu_devices(restore_args.mesh) - logging.vlog( - 1, + logging.info( 'Converting restore mesh with axis names %s to colocated CPU mesh.', restore_args.mesh.axis_names, ) @@ -342,8 +361,7 @@ def convert_array_restore_args( return restore_args if isinstance(restore_args.sharding, jax.sharding.Sharding): cpu_sharding = colocated_cpu_sharding(restore_args.sharding) - logging.vlog( - 1, + logging.info( 'Converting restore sharding from %s to colocated CPU sharding %s', restore_args.sharding, cpu_sharding, @@ -352,8 +370,7 @@ def convert_array_restore_args( if isinstance(restore_args.sharding, sharding_metadata.ShardingMetadata): sharding = restore_args.sharding.to_jax_sharding() cpu_sharding = colocated_cpu_sharding(sharding) - logging.vlog( - 1, + logging.info( 'Converting restore sharding metadata %s to colocated CPU sharding %s', type(restore_args.sharding).__name__, cpu_sharding, @@ -388,24 +405,37 @@ def convert_single_replica_restore_args( def transform_tree_shardings(input_tree: PyTree) -> Any: """Converts shardings/specs/restore-args/arrays to colocated CPU devices.""" + logging.info('transform_tree_shardings: starting mapping') def _transform_leaf_sharding(leaf: Any) -> Any: if isinstance(leaf, jax.sharding.Sharding): + logging.info('transform_tree_shardings: converting Sharding=%s', leaf) return colocated_cpu_sharding(leaf) if isinstance(leaf, jax.ShapeDtypeStruct) and hasattr(leaf, 'sharding'): + logging.info( + 'transform_tree_shardings: ShapeDtypeStruct sharding=%s', + leaf.sharding, + ) cpu_sharding = colocated_cpu_sharding(leaf.sharding) return jax.ShapeDtypeStruct( leaf.shape, leaf.dtype, sharding=cpu_sharding ) if isinstance(leaf, jax_array_restore_args.SingleReplicaArrayRestoreArgs): + logging.info( + 'transform_tree_shardings: SingleReplicaArrayRestoreArgs=%s', leaf + ) return convert_single_replica_restore_args(leaf) if isinstance(leaf, jax_array_restore_args.ArrayRestoreArgs): + logging.info('transform_tree_shardings: ArrayRestoreArgs=%s', leaf) return convert_array_restore_args(leaf) if isinstance(leaf, jax.Array): + logging.info('transform_tree_shardings: Array of shape %s', leaf.shape) return to_colocated_python(leaf) return leaf - return jax.tree.map(_transform_leaf_sharding, input_tree) + result = jax.tree.map(_transform_leaf_sharding, input_tree) + logging.info('transform_tree_shardings: finished mapping') + return result def to_final_specs( @@ -416,8 +446,7 @@ def to_final_specs( def _to_final_spec(leaf: Any, tpu_or_cpu_spec: Any) -> Any: if isinstance(leaf, jax.Array) and hasattr(tpu_or_cpu_spec, 'sharding'): - logging.vlog( - 1, + logging.info( 'Transferring array from %s to final sharding %s', leaf.sharding, tpu_or_cpu_spec.sharding, diff --git a/checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py b/checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py index 15c731297..75b6d5b7e 100644 --- a/checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py +++ b/checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py @@ -16,6 +16,7 @@ import abc from collections.abc import Sequence +import time from typing import Any, Callable from absl import logging @@ -95,9 +96,8 @@ def _make_dummy_result_array( def _vlog_dispatch(fn: Callable[..., Any], dispatcher_name: str): - if logging.vlog_is_on(1): - logging.vlog( - 1, + if True: # pylint: disable=using-constant-test + logging.info( 'Executing function %r via %s on process=%s/%s', fn, dispatcher_name, @@ -228,13 +228,36 @@ def dispatch( @cp.colocated_python def _cp_wrapper(inp: PyTree) -> PyTree: + logging.info( + 'Entering _cp_wrapper on process=%s/%s for func=%s', + multihost.process_index(), + multihost.process_count(), + func, + ) _vlog_dispatch(func, 'ColocatedPythonDispatcher') args = (inp,) + cpu_args if is_input_arrays_provided else cpu_args + logging.info( + '_cp_wrapper: about to execute user function with %d args, %d kwargs', + len(args), + len(cpu_kwargs), + ) + start_time = time.time() if is_func_output_discarded: func(*args, **cpu_kwargs) - return _make_dummy_result_array(inp) + res = _make_dummy_result_array(inp) else: - return func(*args, **cpu_kwargs) + res = func(*args, **cpu_kwargs) + logging.info( + 'Worker execution of %r took %f seconds', + func, + time.time() - start_time, + ) + logging.info( + 'Exiting _cp_wrapper on process=%s/%s', + multihost.process_index(), + multihost.process_count(), + ) + return res result_specs = result_specs or _make_dummy_result_array( input_arrays, abstract=True @@ -244,7 +267,25 @@ def _cp_wrapper(inp: PyTree) -> PyTree: out_specs_fn=lambda _: cpu_result_specs ) - result = specialized_wrapper(self.to_colocated_python(input_arrays)) - return self._to_final_specs(result, result_specs) + start_time = time.time() + input_on_cpu = self.to_colocated_python(input_arrays) + logging.info( + 'to_colocated_python took %f seconds', time.time() - start_time + ) + + start_time = time.time() + result = specialized_wrapper(input_on_cpu) + logging.info( + 'specialized_wrapper took %f seconds', time.time() - start_time + ) + + start_time = time.time() + final_result = self._to_final_specs(result, result_specs) + logging.info( + '_to_final_specs took %f seconds, result_specs=%s', + time.time() - start_time, + result_specs, + ) + return final_result diff --git a/checkpoint/orbax/checkpoint/_src/multihost/multihost.py b/checkpoint/orbax/checkpoint/_src/multihost/multihost.py index 3978c00d2..4273df71a 100644 --- a/checkpoint/orbax/checkpoint/_src/multihost/multihost.py +++ b/checkpoint/orbax/checkpoint/_src/multihost/multihost.py @@ -128,8 +128,7 @@ def initialize_distributed_to_device_ids(): ] assert None not in results _DISTRIBUTED_TO_DEVICE_IDS = results - logging.vlog( - 1, + logging.info( '[process=%s][thread=%s] distributed_to_device_ids: %s', own_distributed_id, threading.current_thread().name, @@ -162,8 +161,7 @@ def initialize_runtime_to_distributed_ids(): for key, distributed_id in ids: runtime_id = int(key.split('/')[-1]) _RUNTIME_TO_DISTRIBUTED_ID[runtime_id] = int(distributed_id) - logging.vlog( - 1, + logging.info( '[process=%s][thread=%s] runtime_to_distributed_id: %s', process_index(), threading.current_thread().name, @@ -266,8 +264,7 @@ def get_barrier_sync_fn( def _fn(*, key: str, timeout_ms: int) -> None: key = _unique_barrier_key(key) - logging.vlog( - 1, + logging.info( '[process=%s][thread=%s] Waiting at barrier: %s', process_index(), threading.current_thread().name, @@ -276,16 +273,14 @@ def _fn(*, key: str, timeout_ms: int) -> None: if processes is None: client.wait_at_barrier(key, timeout_ms) else: - logging.vlog( - 1, + logging.info( '[process=%s][thread=%s] Barrier processes: %s', process_index(), threading.current_thread().name, barrier_processes, ) client.wait_at_barrier(key, timeout_ms, process_ids=barrier_processes) - logging.vlog( - 1, + logging.info( '[process=%s][thread=%s] Done waiting at barrier: %s', process_index(), threading.current_thread().name, @@ -339,8 +334,7 @@ def sync_global_processes( synchronization. """ if should_skip_process_sync(processes): - logging.vlog( - 1, + logging.info( '[process=%s][thread=%s] Skipping global process sync, barrier' ' name: %s', process_index(), @@ -365,16 +359,14 @@ def sync_global_processes( # Temporarily default to existing behavior to minimize risk of breakage. if processes is None and not use_distributed_barrier: key = _unique_barrier_key(name) - logging.vlog( - 1, + logging.info( '[process=%s][thread=%s] Begin jax/sync_global_devices("%s")', process_index(), threading.current_thread().name, key, ) multihost_utils.sync_global_devices(key) - logging.vlog( - 1, + logging.info( '[process=%s][thread=%s] Done jax/sync_global_devices("%s"): %s secs', process_index(), threading.current_thread().name, diff --git a/checkpoint/orbax/checkpoint/_src/multihost/multislice.py b/checkpoint/orbax/checkpoint/_src/multihost/multislice.py index 3fb5a44a2..163b79500 100644 --- a/checkpoint/orbax/checkpoint/_src/multihost/multislice.py +++ b/checkpoint/orbax/checkpoint/_src/multihost/multislice.py @@ -233,8 +233,7 @@ def _globalize_single_replica_arrays( f' sharding.spec {sharding.spec}' ) global_shape = (num_replicas,) + local_replica_shape - logging.vlog( - 1, + logging.info( 'Globalizing array with local shape %s to Global shape: %s', local_replica_shape, global_shape, @@ -264,8 +263,7 @@ def _globalize_single_replica_arrays( zero_data = jnp.zeros(slice_shape, dtype=inp.dtype, device=d) device_buffers.append(zero_data) - logging.vlog( - 1, + logging.info( 'Device buffers: %r', {d.device: d for d in device_buffers}, ) diff --git a/checkpoint/orbax/checkpoint/_src/multihost/pathways.py b/checkpoint/orbax/checkpoint/_src/multihost/pathways.py index 66b746f2d..4b69162b6 100644 --- a/checkpoint/orbax/checkpoint/_src/multihost/pathways.py +++ b/checkpoint/orbax/checkpoint/_src/multihost/pathways.py @@ -110,8 +110,7 @@ def group_devices_by_worker( for d in devices: key = _get_device_worker_key(d) worker_devices[key].append(d) - logging.vlog( - 1, + logging.info( 'Grouped %d devices into %d Pathways workers: %s', len(devices), len(worker_devices), diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index 7b75807ee..11eb688b8 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -234,6 +234,7 @@ def _worker_serialize_arrays( ext_metadata: Dict[str, Any], ): """Worker function to serialize arrays.""" + start_time = time.time() rslices_per_array = _get_replica_slices( arrays, replica_id, @@ -241,7 +242,9 @@ def _worker_serialize_arrays( min_slice_bytes_for_replica_parallel, max_replicas_for_replica_parallel, ) + logging.info('_get_replica_slices took %f seconds', time.time() - start_time) + start_time = time.time() asyncio_utils.run_sync( _async_serialize_replica_slices( rslices_per_array, @@ -255,6 +258,10 @@ def _worker_serialize_arrays( ext_metadata=ext_metadata, ) ) + logging.info( + '_async_serialize_replica_slices (run_sync) took %f seconds', + time.time() - start_time, + ) def _is_prioritized_for_saving(info: types.ParamInfo) -> bool: @@ -469,30 +476,44 @@ def _serialize_batch( batch_args: Sequence[types.SaveArgs], batch_arrays: Sequence[jax.Array], ): - ret = dispatcher.dispatch( - _worker_serialize_arrays, - input_arrays=batch_arrays, - func_kwargs={ - 'infos': batch_infos, - 'args': batch_args, - 'replica_id': replica_id, - 'use_replica_parallel': use_replica_parallel, - 'min_slice_bytes_for_replica_parallel': ( - min_slice_bytes_for_replica_parallel - ), - 'max_replicas_for_replica_parallel': ( - max_replicas_for_replica_parallel - ), - 'primary_host': primary_host, - 'metadata_key': metadata_key, - 'array_metadata_store': array_metadata_store, - 'enable_replica_parallel_separate_folder': ( - enable_replica_parallel_separate_folder - ), - 'ext_metadata': ext_metadata, - }, + logging.info( + '_serialize_batch: starting dispatch for %d arrays with' + ' dispatcher %s', + len(batch_arrays), + dispatcher.name(), ) - jax.block_until_ready(ret) + try: + ret = dispatcher.dispatch( + _worker_serialize_arrays, + input_arrays=batch_arrays, + func_kwargs={ + 'infos': batch_infos, + 'args': batch_args, + 'replica_id': replica_id, + 'use_replica_parallel': use_replica_parallel, + 'min_slice_bytes_for_replica_parallel': ( + min_slice_bytes_for_replica_parallel + ), + 'max_replicas_for_replica_parallel': ( + max_replicas_for_replica_parallel + ), + 'primary_host': primary_host, + 'metadata_key': metadata_key, + 'array_metadata_store': array_metadata_store, + 'enable_replica_parallel_separate_folder': ( + enable_replica_parallel_separate_folder + ), + 'ext_metadata': ext_metadata, + }, + ) + logging.info( + '_serialize_batch: successfully dispatched, blocking until ready' + ) + jax.block_until_ready(ret) + logging.info('_serialize_batch: finished block_until_ready') + except Exception as e: + logging.error('_serialize_batch: dispatcher failed with error: %s', e) + raise # Enqueue D2H operation for prioritized values. if prioritized: @@ -518,12 +539,28 @@ def _serialize_batch( all_infos = infos async def _serialize(): + logging.info( + '_serialize (async): awaiting path creation for %d items', + len(all_infos), + ) for info in all_infos: await info.await_path_creation() + logging.info('_serialize (async): finished awaiting path creation') if prioritized: + logging.info( + '_serialize (async): processing prioritized batch of size %d', + len(prioritized), + ) arrays, infos, args = zip(*prioritized) _serialize_batch(infos, args, arrays) + logging.info( + '_serialize (async): finished processing prioritized batch' + ) if deprioritized: + logging.info( + '_serialize (async): processing deprioritized items with limit=%s', + device_host_max_bytes, + ) assert device_host_max_bytes is not None for ( b_arrays, @@ -535,7 +572,15 @@ async def _serialize(): replica_id=replica_id, dispatcher=dispatcher, ): + logging.info( + '_serialize (async): processing deprioritized chunk of size %d', + len(b_arrays), + ) _serialize_batch(b_infos, b_args, b_arrays) + logging.info( + '_serialize (async): finished processing deprioritized chunk' + ) + logging.info('_serialize (async): all batches processed successfully') return future.CommitFutureAwaitingContractedSignals( _serialize(), @@ -591,16 +636,15 @@ async def _async_serialize_replica_slices( tspec = array_write_spec.json ts_context = info.ts_context - if logging.vlog_is_on(1): - logging.vlog(1, 'info: %s', info) - logging.vlog(1, 'arg: %s', arg) - logging.vlog( - 1, + if True: # pylint: disable=using-constant-test + logging.info('info: %s', info) + logging.info('arg: %s', arg) + logging.info( 'value.global_shape: %s, value.sharding: %s', value.global_shape, value.sharding, ) - logging.vlog(1, 'tspec: %s', tspec) + logging.info('tspec: %s', tspec) write_coros.append( serialization.async_serialize_from_host( @@ -622,7 +666,12 @@ async def _async_serialize_replica_slices( ) ) + start_time = time.time() await asyncio.gather(*write_coros) + logging.info( + '_async_serialize_replica_slices: asyncio.gather took %f seconds', + time.time() - start_time, + ) if ocdbt_transaction is not None: await ocdbt_transaction.commit_async() @@ -634,7 +683,7 @@ def _wrap_random_key_data( ) -> Sequence[jax.Array]: """Parse array_metadatas and wrap deserialized_arrays as random keys.""" - logging.vlog(1, 'array_metadatas = %s', array_metadatas) + logging.info('array_metadatas = %s', array_metadatas) if not isinstance(array_metadatas, Dict): raise ValueError( 'Expecting array_metadatas to be a "Dict" but got' @@ -657,8 +706,7 @@ def _wrap_random_key_data( if impl := meta.ext_metadata.get(array_metadata_lib.RANDOM_KEY_IMPL): # pytype: disable=attribute-error deserialized_arrays[i] = jax.random.wrap_key_data(v, impl=impl) - logging.vlog( - 1, + logging.info( '%s: recreated as a random key: %s', info.name, deserialized_arrays[i], @@ -786,12 +834,12 @@ async def _async_deserialize( if jax.dtypes.issubdtype(arg.dtype, jax.dtypes.prng_key) else arg.dtype ) - if logging.vlog_is_on(1): - logging.vlog(1, 'tspec = %s', tspec) - logging.vlog(1, 'info = %s', info) - logging.vlog(1, 'arg = %s', arg) - logging.vlog(1, 'dtype = %s', dtype) - logging.vlog(1, 'sharding = %s', sharding) + if True: # pylint: disable=using-constant-test + logging.info('tspec = %s', tspec) + logging.info('info = %s', info) + logging.info('arg = %s', arg) + logging.info('dtype = %s', dtype) + logging.info('sharding = %s', sharding) deserialize_ops.append( serialization.async_deserialize( sharding, diff --git a/checkpoint/orbax/checkpoint/_src/serialization/ocdbt_utils.py b/checkpoint/orbax/checkpoint/_src/serialization/ocdbt_utils.py index b95f03f78..f7415b407 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/ocdbt_utils.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/ocdbt_utils.py @@ -71,9 +71,8 @@ async def _validate_params( without_zarray = set() for ts_param in raw_ts_params: ts_param = ts_param.decode('utf-8') - if logging.vlog_is_on(1): - logging.vlog( - 1, + if True: # pylint: disable=using-constant-test + logging.info( '[process=%s][thread=%s] Validating raw param: %s', process_index, current_thread_name, @@ -85,9 +84,8 @@ async def _validate_params( # a/.zarray -> a ts_param = re.sub(_ZARRAY_SUFFIX_RE, '', ts_param) with_zarray.add(ts_param) - if logging.vlog_is_on(1): - logging.vlog( - 1, + if True: # pylint: disable=using-constant-test + logging.info( '[process=%s][thread=%s] Collecting param with .zarray: %s', process_index, current_thread_name, @@ -96,9 +94,8 @@ async def _validate_params( else: # b -> b without_zarray.add(ts_param) - if logging.vlog_is_on(1): - logging.vlog( - 1, + if True: # pylint: disable=using-constant-test + logging.info( '[process=%s][thread=%s] Collecting param without .zarray: %s', process_index, current_thread_name, @@ -106,8 +103,7 @@ async def _validate_params( ) unique = with_zarray | without_zarray - logging.vlog( - 1, + logging.info( '[process=%s][thread=%s] Validating params in TensorStore KvStore.', process_index, current_thread_name, @@ -162,7 +158,7 @@ async def merge_ocdbt_per_process_files( directory.as_posix(), process_dir.name, ) - logging.vlog(1, 'child_kvstore_tspec: %s', child_kvstore_tspec) + logging.info('child_kvstore_tspec: %s', child_kvstore_tspec) open_ops.append(ts_utils.open_kv_store(child_kvstore_tspec, ts_context)) if not open_ops: # No per-process OCDBT checkpoint found! logging.warning( diff --git a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py index 0f0accd18..ae9667802 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py @@ -19,8 +19,9 @@ from collections.abc import Mapping import os import re +import time from typing import Any, Dict, Optional, Sequence, Union - +from absl import logging import jax from jax.experimental import layout import jax.numpy as jnp @@ -263,7 +264,12 @@ async def write_fragment(fragment: fragments.ConcreteFragment): write_fragment(fragment) for fragment in rslices_on_host.to_fragments().fragments ] + start_time = time.time() await asyncio.gather(*write_coros) + logging.info( + 'async_serialize_from_host: asyncio.gather took %f seconds', + time.time() - start_time, + ) def estimate_write_memory_footprint(arr: np.ndarray) -> int: diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py index 1013584d3..5c4a724a2 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py @@ -865,6 +865,6 @@ def print_ts_debug_data(key: str | None, infos: Sequence[types.ParamInfo]): ] for metrics in ts_metrics: - logging.vlog(1, 'ts_metric: %s', metrics) + logging.info('ts_metric: %s', metrics) return json.dumps(ts_metrics) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py index cc5727f2a..e6a1b0301 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py @@ -138,10 +138,10 @@ async def _background_serialize( metadata_key=self._metadata_key, ) tspec = array_write_spec.json - if logging.vlog_is_on(1): - logging.vlog(1, 'tspec = %s', tspec) - logging.vlog(1, 'infos = %s', info) - logging.vlog(1, 'args = %s', arg) + if True: # pylint: disable=using-constant-test + logging.info('tspec = %s', tspec) + logging.info('infos = %s', info) + logging.info('args = %s', arg) if multihost.process_index() == 0: ts_context = info.ts_context write_coros.append(self._open_and_write(value, tspec, ts_context)) @@ -157,7 +157,7 @@ async def serialize( args = args or [types.SaveArgs()] * len(values) types.check_input_arguments(values, infos, args) check_array_values(values, infos) - if logging.vlog_is_on(1): + if True: # pylint: disable=using-constant-test ts_utils.print_ts_debug_data(self._metadata_key, infos) copied_values = [copy.deepcopy(v) for v in values] return [ @@ -193,10 +193,10 @@ async def deserialize( ) tspec = array_read_spec.json - if logging.vlog_is_on(1): - logging.vlog(1, 'tspec = %s', tspec) - logging.vlog(1, 'infos = %s', infos) - logging.vlog(1, 'args = %s', args) + if True: # pylint: disable=using-constant-test + logging.info('tspec = %s', tspec) + logging.info('infos = %s', infos) + logging.info('args = %s', args) open_futures += [ ts.open(ts.Spec(tspec), open=True, context=info.ts_context) ] @@ -204,10 +204,10 @@ async def deserialize( read_ops = [t.read() for t in tensorstores] ret = await asyncio.gather(*read_ops) - if logging.vlog_is_on(1): + if True: # pylint: disable=using-constant-test for a in ret: - logging.vlog( - 1, 'restored ndarray.shape = %s, array.dtype = %s', a.shape, a.dtype + logging.info( + 'restored ndarray.shape = %s, array.dtype = %s', a.shape, a.dtype ) ts_utils.print_ts_debug_data(self._metadata_key, infos) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/replica_parallel/llama-70b-v5p-64-pw.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/replica_parallel/llama-70b-v5p-64-pw.yaml new file mode 100644 index 000000000..38aefb9b3 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/replica_parallel/llama-70b-v5p-64-pw.yaml @@ -0,0 +1,40 @@ +# The name for the entire test suite run. +suite_name: "Llama 3.1 70B v5p-64" +num_repeats: 1 + +mesh_configs: + - mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 4, "tensor": 1} + - mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 8, "tensor": 1} + - mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 16, "tensor": 1} + - mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 32, "tensor": 1} + - mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 64, "tensor": 1} + - mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 128, "tensor": 1} + - mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 256, "tensor": 1} + +# The checkpoint configuration, shared across all generated benchmarks. +checkpoint_configs: + # - path: "gs://orbax-benchmarks/checkpoints/llama-3.1-70B-checkpoints/0/items" + # sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-64-data-1-fsdp-32-tensor-1/abstract_state.json" + - spec: + # # a_1d: {dtype: "float32", shape: [32], sharding: [null]} + # # b_1d: {dtype: "float32", shape: [32], sharding: ["tensor"]} + array_1gb: {dtype: "float32", shape: [8192, 1024, 32], sharding: ["tensor", fsdp]} # 1gb + # # custom_array: {dtype: "float32", shape: [8192, 1024, 1024], sharding: ["fsdp", tensor]} # 32gb + + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.pytree_checkpoint_benchmark.PyTreeCheckpointBenchmark" + # - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_colocated_python: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py index 02cb0c464..6c9b672e7 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py @@ -30,6 +30,7 @@ from etils import epath import jax from orbax.checkpoint._src.testing.benchmarks.core import config_parsing + try: import pathwaysutils # pylint: disable=g-import-not-at-top diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile index b9bfcebac..dcd6e85b6 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile @@ -100,7 +100,7 @@ RUN if [ "$JAX_VERSION" = "newest" ]; then \ # 4. Install Orbax from Source WORKDIR /app/orbax_repo/checkpoint RUN pip install --no-cache-dir . -RUN pip install pathwaysutils tensorboard +RUN pip install pathwaysutils tensorboard cloudpickle # 5. Environment Setup # Set PYTHONPATH so 'import orbax' works from the correctly mapped directory diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py index d29db90f3..a88574c5c 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py @@ -684,6 +684,7 @@ def construct_workload_command( if enable_pathways: env_vars = [ 'export JAX_PLATFORMS=proxy', + 'export JAX_BACKEND_TARGET=grpc://${PATHWAYS_HEAD}:29000', 'export ENABLE_PATHWAYS_PERSISTENCE=1', 'export ENABLE_PJRT_COMPATIBILITY=true', ] @@ -732,8 +733,8 @@ def construct_workload_command( python_cmd += ' --jax_cpu_collectives_implementation=gloo' if enable_pathways: python_cmd = ( - 'python3 -c "import pathwaysutils;' - ' pathwaysutils.initialize()" && ' + 'python3 -c "import pathwaysutils; pathwaysutils.initialize();' + " print('Pathwaysutils initialized.')\" && " + python_cmd )