diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py index 2836d15f7..48795d2c4 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py @@ -21,6 +21,8 @@ import jax from orbax.checkpoint.experimental.v1 import training from orbax.checkpoint.experimental.v1._src.tree import types as tree_types +from pathwaysutils.experimental import concatenate_by_mesh_axis +from pathwaysutils.experimental import split_by_mesh_axis class Snapshotter: @@ -28,8 +30,9 @@ class Snapshotter: _snapshots: collections.deque[tuple[tree_types.PyTree, int]] - def __init__(self): + def __init__(self, *, replica_axis_index: int = 0): self._snapshots = collections.deque(maxlen=2) + self.replica_axis_index = replica_axis_index def save_pytree(self, step: int, state: Any) -> None: """Move arrays onto CPU worker devices.""" @@ -61,11 +64,40 @@ def load_pytree(self, abstract_state: Any) -> tree_types.PyTree: pinned_state, _ = self._snapshots[-1] + def is_replica_active(arr): + try: + jax.block_until_ready(arr) + return True + except jax.errors.JaxRuntimeError as _: + return False + + def get_active_pytree(x): + mesh_axis_name = x.sharding.mesh.axis_names[self.replica_axis_index] + all_replicas = split_by_mesh_axis.split_by_mesh_axis( + x, + mesh_axis_name, + ) + + active_replicas = [ + replica for replica in all_replicas if is_replica_active(replica) + ] + + reconstructed_state = concatenate_by_mesh_axis.concatenate_by_mesh_axis( + active_replicas, + mesh_axis_name, + ) + return reconstructed_state + + pinned_state = jax.tree.map(get_active_pytree, pinned_state) + # Re-shard on host to the target device mesh host_target_shardings = jax.tree.map( lambda x: x.sharding.with_memory_kind("pinned_host"), abstract_state ) - host_target_state = jax.device_put(pinned_state, host_target_shardings) + + host_target_state = jax.device_put( + pinned_state, host_target_shardings + ) # Move from host back to device (TPU) memory. restored_state = jax.device_put(