Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@
import jax
from orbax.checkpoint.experimental.v1 import training
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
from pathwaysutils.experimental import concatenate_by_mesh_axis
from pathwaysutils.experimental import split_by_mesh_axis


class Snapshotter:
"""Manages asynchronous backups of JAX array states to pinned host memory."""

_snapshots: collections.deque[tuple[tree_types.PyTree, int]]

def __init__(self, *, replica_axis_index: int = 0):
  """Creates an empty snapshot store.

  Args:
    replica_axis_index: Position within an array's
      `sharding.mesh.axis_names` of the mesh axis along which a snapshot is
      split into replicas when it is loaded. Keyword-only; defaults to 0.
  """
  # Bounded to two entries, so older snapshots are dropped automatically.
  self._snapshots = collections.deque(maxlen=2)
  self.replica_axis_index = replica_axis_index

def save_pytree(self, step: int, state: Any) -> None:
"""Move arrays onto CPU worker devices."""
Expand Down Expand Up @@ -61,11 +64,40 @@ def load_pytree(self, abstract_state: Any) -> tree_types.PyTree:

pinned_state, _ = self._snapshots[-1]

def is_replica_active(arr):
  """Reports whether `arr` can be waited on without a JAX runtime error.

  Blocks the caller until `jax.block_until_ready(arr)` returns or raises.

  Args:
    arr: Value passed straight to `jax.block_until_ready`.

  Returns:
    True if the wait completed; False if it raised
    `jax.errors.JaxRuntimeError`. Any other exception propagates.
  """
  # Keep the try body to the single call that is expected to fail.
  try:
    jax.block_until_ready(arr)
  except jax.errors.JaxRuntimeError:
    return False
  return True

def get_active_pytree(x):
  """Rebuilds `x` from only the replicas that pass `is_replica_active`.

  The mesh axis is looked up by position (`self.replica_axis_index`) in
  `x.sharding.mesh.axis_names`. `x` is split along that axis, pieces for
  which `is_replica_active` returns False are dropped, and the remaining
  pieces are concatenated back along the same axis.

  Args:
    x: Object exposing `sharding.mesh.axis_names`.

  Returns:
    The result of `concatenate_by_mesh_axis` over the surviving pieces.
  """
  replica_axis = x.sharding.mesh.axis_names[self.replica_axis_index]
  pieces = split_by_mesh_axis.split_by_mesh_axis(x, replica_axis)
  # Materialize a list, as the original comprehension did, before handing
  # the survivors to the concatenation helper.
  healthy_pieces = list(filter(is_replica_active, pieces))
  return concatenate_by_mesh_axis.concatenate_by_mesh_axis(
      healthy_pieces,
      replica_axis,
  )

pinned_state = jax.tree.map(get_active_pytree, pinned_state)

# Re-shard on host to the target device mesh
host_target_shardings = jax.tree.map(
lambda x: x.sharding.with_memory_kind("pinned_host"), abstract_state
)
host_target_state = jax.device_put(pinned_state, host_target_shardings)

host_target_state = jax.device_put(
pinned_state, host_target_shardings
)

# Move from host back to device (TPU) memory.
restored_state = jax.device_put(
Expand Down
Loading