Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions astraflow/core/weight_manager/tests/test_megatron_hf_offload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Integration test: Megatron HF-export -> WeightManager buffer -> HF tensors.

Validates the full PR2/PR3 path without RaaS:
1. MegatronEngine.get_hf_weight_metadata() sizes the buffer.
2. WeightManager.offload(export_hf_named_params()) streams HF tensors into
the shared-memory double buffer (writer rank only).
3. We read the buffer back, reinterpret per tensors_meta, and assert it
equals the reference HF checkpoint bit-for-bit.

This proves the bytes the sender will TCP to RaaS (full mode) are correct,
and that they live in HF layout (so the sender's HF-space delta is valid).

Run:
torchrun --nproc_per_node=<N> \
astraflow/core/weight_manager/tests/test_megatron_hf_offload.py \
--model /shared/models/Qwen3-0.6B --tp 2 --pp 1 --ep 1
"""

from __future__ import annotations

import argparse
import glob
import os
import sys

import torch
import torch.distributed as dist


def _ref(model_path):
from safetensors.torch import load_file

ref = {}
for f in sorted(glob.glob(os.path.join(model_path, "*.safetensors"))):
ref.update(load_file(f))
return ref


def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("--model", required=True)
ap.add_argument("--tp", type=int, default=1)
ap.add_argument("--pp", type=int, default=1)
ap.add_argument("--ep", type=int, default=1)
args = ap.parse_args()

from astraflow.train_worker.api.alloc_mode import ParallelStrategy
from astraflow.train_worker.api.cli_args import TrainEngineConfig
from astraflow.train_worker.api.io_struct import FinetuneSpec
from astraflow.train_worker.engine.megatron_engine import MegatronEngine

world = int(os.environ["WORLD_SIZE"])
dp = world // (args.tp * args.pp * args.ep)

engine = MegatronEngine(TrainEngineConfig(path=args.model, dtype="bfloat16"))
engine.create_process_group(
parallel_strategy=ParallelStrategy(
data_parallel_size=dp,
tensor_parallel_size=args.tp,
pipeline_parallel_size=args.pp,
expert_parallel_size=args.ep,
)
)
engine.initialize(
addr=None,
ft_spec=FinetuneSpec(total_train_epochs=1, dataset_size=1, train_batch_size=1),
)

rank = dist.get_rank()

# 1. Metadata (lockstep on all ranks).
hf_meta = engine.get_hf_weight_metadata()

# 2. Build a minimal WeightManager-like buffer write on the writer rank,
# reusing the real _offload_megatron_hf logic. We construct a real
# WeightManager but stub out the sender (no subprocess) by writing
# into a plain CPU tensor of the right size.
from math import prod

sizes = [(n, prod(sh) * (2 if dt == "bfloat16" else 4)) for n, (sh, dt) in hf_meta]
total = sum(s for _, s in sizes)

# Only rank 0 holds the "buffer" and checks; others just drive collectives.
is_writer = rank == 0
buf = torch.zeros(2 * total, dtype=torch.uint8) if is_writer else None

# Stream export and write to buf[half 0] in order (mirrors
# WeightManager._offload_megatron_hf with inactive_buf_idx=0).
offset = 0
written = {}
for name, tensor in engine.export_hf_named_params():
nbytes = tensor.numel() * tensor.element_size()
if is_writer:
u8 = tensor.contiguous().view(-1).view(torch.uint8)
buf[offset : offset + nbytes].copy_(u8)
written[name] = (offset, list(tensor.shape), str(tensor.dtype))
offset += nbytes

result = 0
if is_writer:
ref = _ref(args.model)
import json

tie = json.load(open(os.path.join(args.model, "config.json"))).get(
"tie_word_embeddings", False
)
# Read back each tensor from the buffer per metadata and compare.
off = 0
nbad = 0
nchk = 0
for name, (shape, dt) in hf_meta:
numel = prod(shape)
nbytes = numel * (2 if dt == "bfloat16" else 4)
raw = buf[off : off + nbytes]
tdtype = torch.bfloat16 if dt == "bfloat16" else torch.float32
t = raw.view(tdtype).view(*shape) if shape else raw.view(tdtype)
off += nbytes
if name not in ref:
if not (tie and name == "lm_head.weight"):
print(f"[FAIL] {name} not in ref", flush=True)
nbad += 1
continue
r = ref[name].to(torch.bfloat16)
if not torch.equal(t, r):
md = (t.float() - r.float()).abs().max().item()
print(f"[FAIL] {name} max|diff|={md:.3e}", flush=True)
nbad += 1
nchk += 1
print(
f"\n=== buffer roundtrip: total_bytes={total} checked={nchk} bad={nbad} ===",
flush=True,
)
result = 0 if nbad == 0 else 1

res_t = torch.tensor([result], device=f"cuda:{os.environ.get('LOCAL_RANK', 0)}")
dist.all_reduce(res_t, op=dist.ReduceOp.MAX)
if rank == 0:
print("PASS" if res_t.item() == 0 else "FAIL", flush=True)
engine.destroy()
return int(res_t.item())


if __name__ == "__main__":
sys.exit(main())
124 changes: 116 additions & 8 deletions astraflow/core/weight_manager/weight_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@

logger = logging.getLogger(__name__)

_DTYPE_SIZES = {
"float32": 4, "float16": 2, "bfloat16": 2,
"int64": 8, "int32": 4, "int16": 2, "int8": 1, "uint8": 1,
}


def _nbytes(shape: List[int], dtype: str) -> int:
from math import prod

return int(prod(shape)) * _DTYPE_SIZES.get(dtype, 2)


class WeightManager:
"""Single owner of all weight transfer state and logic.
Expand Down Expand Up @@ -85,6 +96,7 @@ def initialize(
global_rank: int,
lora_config: dict | None = None,
megatron_metadata: Optional[dict] = None,
megatron_hf_meta: Optional[list] = None,
dp_replicate_rank: int = 0,
) -> None:
"""Initialize buffers and start sender agent.
Expand All @@ -103,11 +115,19 @@ def initialize(
``target_modules``). Forwarded to the sender agent so the
RaaS receiver can save weights in PEFT adapter format.
megatron_metadata : dict, optional
If provided, enables Megatron shard-direct mode. Contains
``tp_size``, ``tp_rank``, ``dp_rank``, ``shard_specs`` (per-param
TP metadata), and ``conversion_config`` (for CPU-side HF conversion
in the sender agent). Buffer layout is computed from full
(gathered) param sizes; each TP rank writes only its shard.
(Legacy TP-only shard-direct mode — superseded by
``megatron_hf_meta``.) If provided, the sender reassembles raw
TP shards into HF format on CPU. Only correct for PP=1/EP=1 and
cannot compute deltas in HF space; kept for backward compat.
megatron_hf_meta : list, optional
Megatron **HF-export** mode (preferred). The ordered HF weight
layout ``[(hf_name, (shape, dtype_str)), ...]`` from
``hf_weight_metadata``. The buffer is sized for the full HF
model and ``offload`` writes already-converted HF tensors (from
``export_hf_named_params``) on the DP-head rank. Because the
buffer holds HF bytes, the sender's standard full/delta path is
correct under any TP/PP/EP/VPP combination — see
``docs/en/architecture/megatron-weight-sync.md``.
dp_replicate_rank : int
HSDP replica group index. 0 = primary replica (owns the shm
buffer and offloads weights). >0 = secondary replica (skips
Expand All @@ -117,8 +137,18 @@ def initialize(
self._global_rank = global_rank
self._hsdp_replica_rank = dp_replicate_rank
self._megatron_metadata = megatron_metadata

if megatron_metadata is not None:
self._megatron_hf_meta = megatron_hf_meta

if megatron_hf_meta is not None:
# HF-export mode: buffer holds the full HF model in HF layout.
# The sender treats it exactly like FSDP (no reassembly, delta
# in HF space), so megatron_metadata stays None for the sender.
meta_size = [
(name, _nbytes(shape, dtype))
for name, (shape, dtype) in megatron_hf_meta
]
tensors_meta = list(megatron_hf_meta)
elif megatron_metadata is not None:
meta_size, tensors_meta = self._compute_megatron_buffer_layout(
megatron_metadata["shard_specs"]
)
Expand All @@ -128,11 +158,15 @@ def initialize(

# Only the primary HSDP replica (replica_rank=0) runs the sender
# agent and owns the shm buffer. Secondary replicas skip entirely.
# In HF-export mode the buffer already holds HF bytes, so the sender
# runs the plain (FSDP) path — pass megatron_metadata=None.
if local_rank == 0 and dp_replicate_rank == 0:
self._start_sender_agent(
meta_size, tensors_meta,
lora_config=lora_config,
megatron_metadata=megatron_metadata,
megatron_metadata=(
None if megatron_hf_meta is not None else megatron_metadata
),
)

self._broadcast_shm_buffer()
Expand Down Expand Up @@ -346,6 +380,13 @@ def offload(
dict
Weight transfer metrics for wandb logging. Empty on non-rank-0.
"""
# Megatron HF-export mode: ``named_params`` is a fresh generator that
# yields gathered HF tensors. It must be streamed (not list()-ed) and
# iterated in lockstep on every rank (it runs TP/PP/EP collectives),
# but only the writer rank copies into the buffer.
if self._megatron_hf_meta is not None:
return self._offload_megatron_hf(named_params, version, rank, world_size)

params_list = list(named_params)

# Guard: wait if previous delta is still reading the inactive half.
Expand Down Expand Up @@ -429,6 +470,73 @@ def offload(

return metrics

def _offload_megatron_hf(
self,
hf_named_params: Iterator[Tuple[str, torch.Tensor]],
version: int,
rank: int,
world_size: int,
) -> dict:
"""Stream gathered HF tensors into the buffer (Megatron HF-export mode).
Every rank iterates ``hf_named_params`` in lockstep — it drives the
TP/PP/EP collectives inside ``export_hf_named_params`` — but only the
writer rank (global rank 0, which owns the shm buffer) copies the
yielded tensors into the inactive half, in the fixed order that
matches ``megatron_hf_meta``. Because the bytes are HF-layout, the
sender's standard full/delta path is correct (delta in HF space).
"""
t_guard_start = _time.perf_counter()
self._wait_previous_delta()
if dist.is_initialized():
dist.barrier()
t0 = _time.perf_counter()
guard_time = t0 - t_guard_start

is_writer = self._buffer is not None and self._local_rank == 0
half_base = self._inactive_buf_idx * self._single_buffer_length
offset = 0
n_written = 0
for _name, tensor in hf_named_params:
nbytes = tensor.numel() * tensor.element_size()
if is_writer:
t_u8 = tensor.contiguous().view(-1).view(torch.uint8)
self._buffer[half_base + offset: half_base + offset + nbytes].copy_(
t_u8
)
n_written += 1
offset += nbytes
t1 = _time.perf_counter()

if dist.is_initialized():
dist.barrier()
t2 = _time.perf_counter()

ack = self._notify_buffer_ready(version)
t3 = _time.perf_counter()

metrics: dict = {}
if rank == 0:
metrics = {
"weight_transfer/offload_guard_time": guard_time,
"weight_transfer/offload_copy_time": t1 - t0,
"weight_transfer/offload_barrier_time": t2 - t1,
"weight_transfer/offload_notify_time": t3 - t2,
"weight_transfer/offload_total_time": t3 - t_guard_start,
}
if self._last_delta_metrics:
metrics.update(self._last_delta_metrics)
self._last_delta_metrics = None
print(
f"[WeightManager] offload mode=megatron_hf_export, "
f"wrote={n_written} tensors, total_bytes={offset}, "
f"guard={guard_time:.3f}s, copy={t1 - t0:.3f}s, "
f"barrier={t2 - t1:.3f}s, notify={t3 - t2:.3f}s, "
f"total={t3 - t_guard_start:.3f}s",
flush=True,
)
return metrics

# ------------------------------------------------------------------
# Copy strategies
# ------------------------------------------------------------------
Expand Down
Loading