Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py
Original file line number Diff line number Diff line change
def unique_colocated_cpu_devices(
    devices: Sequence[jax.Device],
) -> tuple[jax.Device, ...]:
  """Returns one colocated CPU device per worker.

  Args:
    devices: Devices whose colocated CPU devices should be looked up.

  Returns:
    The colocated CPU devices with duplicates (by ``device.id``) removed,
    preserving first-seen order.
  """
  logging.info('unique_colocated_cpu_devices: input devices=%s', devices)
  all_cpu = tuple(cp.colocated_cpu_devices(tuple(devices)))
  logging.info(
      'unique_colocated_cpu_devices: colocated_cpu_devices returned=%s', all_cpu
  )
  # Deduplicate by device id while keeping the original order; several input
  # devices may map to the same colocated CPU device.
  unique_cpu = []
  seen_ids = set()
  for device in all_cpu:
    if device.id in seen_ids:
      continue
    seen_ids.add(device.id)
    unique_cpu.append(device)
  logging.info('unique_colocated_cpu_devices: unique_cpu=%s', unique_cpu)
  return tuple(unique_cpu)


def colocated_cpu_sharding(
    sharding: jax.sharding.Sharding,
) -> jax.sharding.Sharding:
  """Returns a CPU sharding colocated with the given sharding.

  Args:
    sharding: A ``SingleDeviceSharding`` or ``NamedSharding`` to translate
      onto colocated CPU devices.

  Returns:
    An equivalent sharding over colocated CPU devices, preserving the
    original ``memory_kind`` (and, for ``NamedSharding``, the ``spec``).

  Raises:
    TypeError: If ``sharding`` is any other sharding type.
  """
  logging.info('colocated_cpu_sharding: input sharding=%s', sharding)
  if isinstance(sharding, jax.sharding.SingleDeviceSharding):
    cpu_devices = cp.colocated_cpu_devices(list(sharding.device_set))
    result = jax.sharding.SingleDeviceSharding(
        cpu_devices[0], memory_kind=sharding.memory_kind
    )
    logging.info(
        'colocated_cpu_sharding: returning SingleDeviceSharding=%s', result
    )
    return result
  if isinstance(sharding, jax.sharding.NamedSharding):
    cpu_mesh = cp.colocated_cpu_devices(sharding.mesh)
    result = jax.sharding.NamedSharding(
        cpu_mesh, sharding.spec, memory_kind=sharding.memory_kind
    )
    logging.info('colocated_cpu_sharding: returning NamedSharding=%s', result)
    return result
  logging.error(
      'colocated_cpu_sharding: unsupported sharding type=%s', type(sharding)
  )
  raise TypeError(
      f'Sharding type {type(sharding)} not supported in to_colocated_python.'
  )


def to_colocated_python(input_tree: PyTree) -> PyTree:
"""Copies a pytree of arrays to colocated CPU devices."""
logging.info(
'to_colocated_python: starting with tree structure=%s',
jax.tree.structure(input_tree),
)

def _get_sharding(x: Any) -> jax.sharding.Sharding | None:
if isinstance(x, jax.Array):
cpu_sharding = colocated_cpu_sharding(x.sharding)
logging.vlog(
1,
logging.info(
'Staging array from %s to colocated CPU sharding %s',
x.sharding,
cpu_sharding,
Expand All @@ -298,7 +316,9 @@ def _get_sharding(x: Any) -> jax.sharding.Sharding | None:
return None

cpu_sharding_tree = jax.tree.map(_get_sharding, input_tree)
return jax.device_put(input_tree, cpu_sharding_tree, may_alias=True)
result = jax.device_put(input_tree, cpu_sharding_tree, may_alias=True)
logging.info('to_colocated_python: finished device_put')
return result


def make_scalar_array_like(
Expand Down Expand Up @@ -332,8 +352,7 @@ def convert_array_restore_args(
"""Converts ArrayRestoreArgs to use colocated CPU devices."""
if restore_args.mesh is not None:
cpu_mesh = cp.colocated_cpu_devices(restore_args.mesh)
logging.vlog(
1,
logging.info(
'Converting restore mesh with axis names %s to colocated CPU mesh.',
restore_args.mesh.axis_names,
)
Expand All @@ -342,8 +361,7 @@ def convert_array_restore_args(
return restore_args
if isinstance(restore_args.sharding, jax.sharding.Sharding):
cpu_sharding = colocated_cpu_sharding(restore_args.sharding)
logging.vlog(
1,
logging.info(
'Converting restore sharding from %s to colocated CPU sharding %s',
restore_args.sharding,
cpu_sharding,
Expand All @@ -352,8 +370,7 @@ def convert_array_restore_args(
if isinstance(restore_args.sharding, sharding_metadata.ShardingMetadata):
sharding = restore_args.sharding.to_jax_sharding()
cpu_sharding = colocated_cpu_sharding(sharding)
logging.vlog(
1,
logging.info(
'Converting restore sharding metadata %s to colocated CPU sharding %s',
type(restore_args.sharding).__name__,
cpu_sharding,
Expand Down Expand Up @@ -388,24 +405,37 @@ def convert_single_replica_restore_args(

def transform_tree_shardings(input_tree: PyTree) -> Any:
  """Converts shardings/specs/restore-args/arrays to colocated CPU devices.

  Args:
    input_tree: A pytree whose leaves may be shardings, ``ShapeDtypeStruct``s,
      restore args, or ``jax.Array``s.

  Returns:
    A pytree of the same structure with every supported leaf translated onto
    colocated CPU devices; unrecognized leaves are returned unchanged.
  """
  logging.info('transform_tree_shardings: starting mapping')

  def _transform_leaf_sharding(leaf: Any) -> Any:
    # NOTE: SingleReplicaArrayRestoreArgs must be checked before
    # ArrayRestoreArgs so the more specific conversion wins.
    if isinstance(leaf, jax.sharding.Sharding):
      logging.info('transform_tree_shardings: converting Sharding=%s', leaf)
      return colocated_cpu_sharding(leaf)
    if isinstance(leaf, jax.ShapeDtypeStruct) and hasattr(leaf, 'sharding'):
      logging.info(
          'transform_tree_shardings: ShapeDtypeStruct sharding=%s',
          leaf.sharding,
      )
      cpu_sharding = colocated_cpu_sharding(leaf.sharding)
      return jax.ShapeDtypeStruct(
          leaf.shape, leaf.dtype, sharding=cpu_sharding
      )
    if isinstance(leaf, jax_array_restore_args.SingleReplicaArrayRestoreArgs):
      logging.info(
          'transform_tree_shardings: SingleReplicaArrayRestoreArgs=%s', leaf
      )
      return convert_single_replica_restore_args(leaf)
    if isinstance(leaf, jax_array_restore_args.ArrayRestoreArgs):
      logging.info('transform_tree_shardings: ArrayRestoreArgs=%s', leaf)
      return convert_array_restore_args(leaf)
    if isinstance(leaf, jax.Array):
      logging.info('transform_tree_shardings: Array of shape %s', leaf.shape)
      return to_colocated_python(leaf)
    return leaf

  result = jax.tree.map(_transform_leaf_sharding, input_tree)
  logging.info('transform_tree_shardings: finished mapping')
  return result


def to_final_specs(
Expand All @@ -416,8 +446,7 @@ def to_final_specs(

def _to_final_spec(leaf: Any, tpu_or_cpu_spec: Any) -> Any:
if isinstance(leaf, jax.Array) and hasattr(tpu_or_cpu_spec, 'sharding'):
logging.vlog(
1,
logging.info(
'Transferring array from %s to final sharding %s',
leaf.sharding,
tpu_or_cpu_spec.sharding,
Expand Down
55 changes: 48 additions & 7 deletions checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import abc
from collections.abc import Sequence
import time
from typing import Any, Callable

from absl import logging
Expand Down Expand Up @@ -95,9 +96,8 @@ def _make_dummy_result_array(


def _vlog_dispatch(fn: Callable[..., Any], dispatcher_name: str):
if logging.vlog_is_on(1):
logging.vlog(
1,
if True: # pylint: disable=using-constant-test
logging.info(
'Executing function %r via %s on process=%s/%s',
fn,
dispatcher_name,
Expand Down Expand Up @@ -228,13 +228,36 @@ def dispatch(

@cp.colocated_python
def _cp_wrapper(inp: PyTree) -> PyTree:
logging.info(
'Entering _cp_wrapper on process=%s/%s for func=%s',
multihost.process_index(),
multihost.process_count(),
func,
)
_vlog_dispatch(func, 'ColocatedPythonDispatcher')
args = (inp,) + cpu_args if is_input_arrays_provided else cpu_args
logging.info(
'_cp_wrapper: about to execute user function with %d args, %d kwargs',
len(args),
len(cpu_kwargs),
)
start_time = time.time()
if is_func_output_discarded:
func(*args, **cpu_kwargs)
return _make_dummy_result_array(inp)
res = _make_dummy_result_array(inp)
else:
return func(*args, **cpu_kwargs)
res = func(*args, **cpu_kwargs)
logging.info(
'Worker execution of %r took %f seconds',
func,
time.time() - start_time,
)
logging.info(
'Exiting _cp_wrapper on process=%s/%s',
multihost.process_index(),
multihost.process_count(),
)
return res

result_specs = result_specs or _make_dummy_result_array(
input_arrays, abstract=True
Expand All @@ -244,7 +267,25 @@ def _cp_wrapper(inp: PyTree) -> PyTree:
out_specs_fn=lambda _: cpu_result_specs
)

result = specialized_wrapper(self.to_colocated_python(input_arrays))
return self._to_final_specs(result, result_specs)
start_time = time.time()
input_on_cpu = self.to_colocated_python(input_arrays)
logging.info(
'to_colocated_python took %f seconds', time.time() - start_time
)

start_time = time.time()
result = specialized_wrapper(input_on_cpu)
logging.info(
'specialized_wrapper took %f seconds', time.time() - start_time
)

start_time = time.time()
final_result = self._to_final_specs(result, result_specs)
logging.info(
'_to_final_specs took %f seconds, result_specs=%s',
time.time() - start_time,
result_specs,
)
return final_result


24 changes: 8 additions & 16 deletions checkpoint/orbax/checkpoint/_src/multihost/multihost.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ def initialize_distributed_to_device_ids():
]
assert None not in results
_DISTRIBUTED_TO_DEVICE_IDS = results
logging.vlog(
1,
logging.info(
'[process=%s][thread=%s] distributed_to_device_ids: %s',
own_distributed_id,
threading.current_thread().name,
Expand Down Expand Up @@ -162,8 +161,7 @@ def initialize_runtime_to_distributed_ids():
for key, distributed_id in ids:
runtime_id = int(key.split('/')[-1])
_RUNTIME_TO_DISTRIBUTED_ID[runtime_id] = int(distributed_id)
logging.vlog(
1,
logging.info(
'[process=%s][thread=%s] runtime_to_distributed_id: %s',
process_index(),
threading.current_thread().name,
Expand Down Expand Up @@ -266,8 +264,7 @@ def get_barrier_sync_fn(

def _fn(*, key: str, timeout_ms: int) -> None:
key = _unique_barrier_key(key)
logging.vlog(
1,
logging.info(
'[process=%s][thread=%s] Waiting at barrier: %s',
process_index(),
threading.current_thread().name,
Expand All @@ -276,16 +273,14 @@ def _fn(*, key: str, timeout_ms: int) -> None:
if processes is None:
client.wait_at_barrier(key, timeout_ms)
else:
logging.vlog(
1,
logging.info(
'[process=%s][thread=%s] Barrier processes: %s',
process_index(),
threading.current_thread().name,
barrier_processes,
)
client.wait_at_barrier(key, timeout_ms, process_ids=barrier_processes)
logging.vlog(
1,
logging.info(
'[process=%s][thread=%s] Done waiting at barrier: %s',
process_index(),
threading.current_thread().name,
Expand Down Expand Up @@ -339,8 +334,7 @@ def sync_global_processes(
synchronization.
"""
if should_skip_process_sync(processes):
logging.vlog(
1,
logging.info(
'[process=%s][thread=%s] Skipping global process sync, barrier'
' name: %s',
process_index(),
Expand All @@ -365,16 +359,14 @@ def sync_global_processes(
# Temporarily default to existing behavior to minimize risk of breakage.
if processes is None and not use_distributed_barrier:
key = _unique_barrier_key(name)
logging.vlog(
1,
logging.info(
'[process=%s][thread=%s] Begin jax/sync_global_devices("%s")',
process_index(),
threading.current_thread().name,
key,
)
multihost_utils.sync_global_devices(key)
logging.vlog(
1,
logging.info(
'[process=%s][thread=%s] Done jax/sync_global_devices("%s"): %s secs',
process_index(),
threading.current_thread().name,
Expand Down
6 changes: 2 additions & 4 deletions checkpoint/orbax/checkpoint/_src/multihost/multislice.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,7 @@ def _globalize_single_replica_arrays(
f' sharding.spec {sharding.spec}'
)
global_shape = (num_replicas,) + local_replica_shape
logging.vlog(
1,
logging.info(
'Globalizing array with local shape %s to Global shape: %s',
local_replica_shape,
global_shape,
Expand Down Expand Up @@ -264,8 +263,7 @@ def _globalize_single_replica_arrays(
zero_data = jnp.zeros(slice_shape, dtype=inp.dtype, device=d)
device_buffers.append(zero_data)

logging.vlog(
1,
logging.info(
'Device buffers: %r',
{d.device: d for d in device_buffers},
)
Expand Down
3 changes: 1 addition & 2 deletions checkpoint/orbax/checkpoint/_src/multihost/pathways.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ def group_devices_by_worker(
for d in devices:
key = _get_device_worker_key(d)
worker_devices[key].append(d)
logging.vlog(
1,
logging.info(
'Grouped %d devices into %d Pathways workers: %s',
len(devices),
len(worker_devices),
Expand Down
Loading
Loading