From 02e66108daa421bc5babc854674021a00c81cdad Mon Sep 17 00:00:00 2001 From: jsw-zorro Date: Fri, 29 May 2026 07:15:03 +0000 Subject: [PATCH 1/3] feat(megatron): streaming Megatron->HF per-tensor weight export Add export_hf_named_params: a streaming generator that reconstructs the global model from Megatron's TP/PP/EP/ETP/VPP layout and yields HF-named, HF-layout CPU tensors one at a time (OOM-safe for large / MoE models). The gather + mcore->HF conversion is delegated to mbridge's export_weights (the same bridge the engine already uses to load/save); this module adds the consumer concerns: CPU move, byte-bounded bucketing, and a metadata-only path for transfer-buffer sizing. This is the foundation for correct sparse weight sync under full Megatron parallelism. The design (the "delta is computed in HF byte space" invariant) is documented in docs/en/architecture/megatron-weight-sync.md. The sglang Docker image gains the Megatron compiled deps it was missing: Transformer Engine (fused LayerNorm + sequence parallelism) and apex (optional fused LayerNorm/Adam). megatron-core / mbridge were already in the base install. Validated (exact bf16 match vs the HF reference checkpoint): - Qwen3-0.6B TP=2 310 tensors, 0 mismatch - Qwen3-0.6B PP=2 311 tensors, 0 mismatch - Qwen3-0.6B TP=2 PP=2 311 tensors, 0 mismatch - Qwen3-30B-A3B TP=2 EP=2 PP=2 18867 tensors, 0 mismatch --- .../mcore/tests/test_hf_export_equiv.py | 152 +++++++++++++++++ .../models/mcore/weight_export.py | 151 +++++++++++++++++ docker/Dockerfile.sglang | 27 +++ docs/en/architecture/megatron-weight-sync.md | 156 ++++++++++++++++++ docs/en/index.rst | 1 + 5 files changed, 487 insertions(+) create mode 100644 astraflow/train_worker/models/mcore/tests/test_hf_export_equiv.py create mode 100644 astraflow/train_worker/models/mcore/weight_export.py create mode 100644 docs/en/architecture/megatron-weight-sync.md diff --git a/astraflow/train_worker/models/mcore/tests/test_hf_export_equiv.py b/astraflow/train_worker/models/mcore/tests/test_hf_export_equiv.py new file mode 100644 index 0000000..fc7fa8e --- /dev/null +++ b/astraflow/train_worker/models/mcore/tests/test_hf_export_equiv.py @@ -0,0 +1,152 @@ +"""Equivalence test for Megatron -> HF weight export. + +Loads an HF checkpoint into a Megatron GPTModel under a chosen parallel +strategy, exports it back to HF via ``export_hf_named_params``, and asserts +the reconstructed tensors match the original HF safetensors bit-for-bit +(bf16). This is the PR1 acceptance gate. + +Run (torchrun, multi-GPU): + torchrun --nproc_per_node= \ + astraflow/train_worker/models/mcore/tests/test_hf_export_equiv.py \ + --model /shared/models/Qwen3-0.6B --tp 2 --pp 1 --ep 1 + +Exit code 0 = all tensors match. Non-zero = mismatch (details on rank 0). +""" + +from __future__ import annotations + +import argparse +import os +import sys + +import torch +import torch.distributed as dist + + +def _load_reference_hf(model_path: str) -> dict[str, torch.Tensor]: + """Load the original HF checkpoint tensors (bf16) from safetensors.""" + import glob + + from safetensors.torch import load_file + + ref: dict[str, torch.Tensor] = {} + files = sorted(glob.glob(os.path.join(model_path, "*.safetensors"))) + if not files: + raise FileNotFoundError(f"no .safetensors in {model_path}") + for f in files: + ref.update(load_file(f)) + return ref + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--model", required=True) + ap.add_argument("--tp", type=int, default=1) + ap.add_argument("--pp", type=int, default=1) + ap.add_argument("--ep", type=int, default=1) + ap.add_argument("--atol", type=float, default=0.0, help="0 = exact bf16 match") + args = ap.parse_args() + + from astraflow.train_worker.api.alloc_mode import ParallelStrategy + from astraflow.train_worker.api.cli_args import TrainEngineConfig + from astraflow.train_worker.engine.megatron_engine import MegatronEngine + from astraflow.train_worker.models.mcore.weight_export import ( + export_hf_named_params, + ) + + world = int(os.environ["WORLD_SIZE"]) + dp = world // (args.tp * args.pp * args.ep) + assert dp >= 1, ( + f"world={world} too small for tp*pp*ep={args.tp * args.pp * args.ep}" + ) + + cfg = TrainEngineConfig(path=args.model, dtype="bfloat16") + # No optimizer -> inference-only engine, faster init. + engine = MegatronEngine(cfg) + strategy = ParallelStrategy( + data_parallel_size=dp, + tensor_parallel_size=args.tp, + pipeline_parallel_size=args.pp, + expert_parallel_size=args.ep, + ) + engine.create_process_group(parallel_strategy=strategy) + + from astraflow.train_worker.api.io_struct import FinetuneSpec + + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=1, train_batch_size=1) + engine.initialize(addr=None, ft_spec=ft_spec) + + rank = dist.get_rank() + is_writer = rank == 0 + + ref = _load_reference_hf(args.model) if is_writer else None + + n_checked = 0 + n_mismatch = 0 + seen: set[str] = set() + for name, tensor in export_hf_named_params(engine.bridge, engine.model): + if not is_writer: + continue + seen.add(name) + if name not in ref: + print(f"[FAIL] exported tensor not in reference: {name}", flush=True) + n_mismatch += 1 + continue + r = ref[name].to(torch.bfloat16) + t = tensor.to(torch.bfloat16) + if list(t.shape) != list(r.shape): + print( + f"[FAIL] shape {name}: export {list(t.shape)} vs ref {list(r.shape)}", + flush=True, + ) + n_mismatch += 1 + continue + if args.atol == 0.0: + ok = torch.equal(t, r) + else: + ok = torch.allclose(t.float(), r.float(), atol=args.atol, rtol=0) + if not ok: + md = (t.float() - r.float()).abs().max().item() + print(f"[FAIL] values {name}: max|diff|={md:.3e}", flush=True) + n_mismatch += 1 + n_checked += 1 + + if is_writer: + import json + + with open(os.path.join(args.model, "config.json")) as f: + tie = json.load(f).get("tie_word_embeddings", False) + missing = set(ref.keys()) - seen + # Benign non-exports: + # - rotary/inv_freq buffers (not weights); + # - lm_head.weight when embeddings are tied (mbridge emits only + # embed_tokens; the inference engine ties internally). + benign = {k for k in missing if "rotary" in k or "inv_freq" in k} + if tie and "lm_head.weight" in missing: + benign.add("lm_head.weight") + hard_missing = missing - benign + print( + f"\n=== export equivalence: checked={n_checked} " + f"mismatch={n_mismatch} missing={len(hard_missing)} " + f"benign_missing={len(benign)} ===", + flush=True, + ) + if hard_missing: + print( + f"[FAIL] reference keys never exported: {sorted(hard_missing)[:10]}", + flush=True, + ) + result = 0 if (n_mismatch == 0 and not hard_missing) else 1 + else: + result = 0 + + res_t = torch.tensor([result], device=f"cuda:{os.environ.get('LOCAL_RANK', 0)}") + dist.all_reduce(res_t, op=dist.ReduceOp.MAX) + if dist.get_rank() == 0: + print("PASS" if res_t.item() == 0 else "FAIL", flush=True) + engine.destroy() + return int(res_t.item()) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/astraflow/train_worker/models/mcore/weight_export.py b/astraflow/train_worker/models/mcore/weight_export.py new file mode 100644 index 0000000..5006ae3 --- /dev/null +++ b/astraflow/train_worker/models/mcore/weight_export.py @@ -0,0 +1,151 @@ +"""Streaming Megatron -> HuggingFace weight export for online weight sync. + +Reconstructs the *global* model from Megatron's sharded layout +(TP / PP / EP / ETP / VPP) and yields HuggingFace-named, HF-layout CPU +tensors **one bucket at a time**, so a large or MoE model is never +materialized in full. + +This is the single source of truth for "Megatron weights -> HF" used by the +online weight-sync path (``WeightManager.offload``). See +``docs/en/architecture/megatron-weight-sync.md`` for the design and the +HF-space delta invariant. + +Implementation note +------------------- +The heavy lifting (PP ``all_gather_object`` + broadcast, EP/ETP/TP +all-gather, local->global expert-id rewrite, and mcore->HF name/layout +conversion) is delegated to ``mbridge``'s ``Bridge.export_weights`` — the +same bridge the engine already uses to load (``_load_model_from_hf``) and +save (``_save_model_to_hf``). It is a battle-tested ``per_tensor_generator`` +(equivalent to verl's ``per_tensor_generator`` and slime's +``HfWeightIteratorDirect``) that yields ``(hf_name, full_gpu_tensor)``. + +We add the AstraFlow-specific consumer concerns on top: move each tensor to +CPU (the transfer buffer is CPU shared memory), group into byte-bounded +buckets, and a metadata-only mode for buffer sizing. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +import torch + +from astraflow.train_worker.utils import logging + +logger = logging.getLogger(__name__) + +# Default gather-bucket size, in bytes, measured on the *converted HF* +# tensors. mbridge gathers one source param at a time internally; this only +# bounds how many converted tensors we batch before handing them to the +# consumer (so the consumer can copy a run of tensors without per-tensor +# Python overhead). One bucket is alive at a time. +DEFAULT_BUCKET_BYTES = 512 << 20 # 512 MiB + + +def export_hf_named_params( + bridge, + models: list, + *, + to_cpu: bool = True, +) -> Iterator[tuple[str, torch.Tensor]]: + """Yield ``(hf_name, full_unsharded_tensor)`` for every model parameter. + + Reconstructs the global model from Megatron's TP/PP/EP/ETP/VPP layout + via ``bridge.export_weights`` and yields HF-named tensors. Only one + gathered tensor is resident at a time (plus transient collective + buffers), so this is OOM-safe for large / MoE models. + + Parameters + ---------- + bridge : + The ``mbridge`` bridge for this model (``engine.bridge``). Already + configured with the model's ``TransformerConfig`` and dtype. + models : + The engine's model chunk list (``_MegatronModelList``): VPP chunks, + each typically ``DistributedDataParallel``-wrapped. ``mbridge`` + unwraps them internally. + to_cpu : + Move each yielded tensor to CPU (default). The transfer buffer is + CPU shared memory, so this is the normal path. Set False only for + callers that consume on-GPU. + + Yields + ------ + tuple[str, torch.Tensor] + HF parameter name (e.g. ``model.layers.0.self_attn.q_proj.weight``) + and the full (unsharded) tensor, contiguous, on CPU when + ``to_cpu``. + + Notes + ----- + Every rank must call this in lockstep: ``export_weights`` runs + collectives (PP all_gather_object + broadcast, TP/EP/ETP all_gather) + across all model-parallel ranks. The yielded values are identical on + every rank in a model-parallel group, so the caller decides which rank + actually writes them to the buffer (the DP/PP/TP head). + """ + for hf_name, param in bridge.export_weights(models): + tensor = param.detach() + if to_cpu: + # bf16/contiguous on CPU — pinned-buffer copy happens in the + # consumer; .contiguous() guards against non-contiguous views + # produced by QKV/gate-up splits in the converter. + tensor = tensor.to("cpu", copy=False).contiguous() + else: + tensor = tensor.contiguous() + yield hf_name, tensor + + +def iter_param_buckets( + named_params: Iterator[tuple[str, torch.Tensor]], + bucket_bytes: int = DEFAULT_BUCKET_BYTES, +) -> Iterator[list[tuple[str, torch.Tensor]]]: + """Group a ``(name, tensor)`` stream into byte-bounded buckets. + + Yields lists whose cumulative tensor bytes stay under ``bucket_bytes`` + (a single tensor larger than the cap forms its own bucket). Lets the + consumer amortize per-tensor overhead while keeping only one bucket of + tensors alive at a time. + """ + bucket: list[tuple[str, torch.Tensor]] = [] + cur = 0 + for name, tensor in named_params: + nbytes = tensor.numel() * tensor.element_size() + if bucket and cur + nbytes > bucket_bytes: + yield bucket + bucket = [] + cur = 0 + bucket.append((name, tensor)) + cur += nbytes + if bucket: + yield bucket + + +def hf_weight_metadata( + bridge, + models: list, +) -> list[tuple[str, tuple[list[int], str]]]: + """Return the ordered HF weight layout: ``[(name, (shape, dtype_str)), ...]``. + + Drives the same ``export_weights`` generator but keeps only shape/dtype + (dropping tensor storage as it goes), so the full model is never + resident. Consumed by ``WeightManager`` to size the transfer buffer and + by the RaaS receiver (as ``tensors_meta``) to pre-allocate — both ends + then agree on layout and order. + + Must be called in lockstep on every rank (it runs the same collectives + as ``export_hf_named_params``). + """ + meta: list[tuple[str, tuple[list[int], str]]] = [] + for hf_name, param in bridge.export_weights(models): + dtype = str(param.dtype).split(".")[-1] + meta.append((hf_name, (list(param.shape), dtype))) + del param + logger.info( + "[weight_export] HF metadata: %d tensors, first=%s last=%s", + len(meta), + meta[0][0] if meta else "?", + meta[-1][0] if meta else "?", + ) + return meta diff --git a/docker/Dockerfile.sglang b/docker/Dockerfile.sglang index b2a1ef9..77a642f 100644 --- a/docker/Dockerfile.sglang +++ b/docker/Dockerfile.sglang @@ -33,4 +33,31 @@ RUN uv pip install -e ".[sglang]" # so install it explicitly with --no-build-isolation. RUN uv pip install "flash-attn==2.8.3" --no-build-isolation +# --- Megatron backend extras --- +# The Megatron training backend (engine/megatron_engine.py) uses +# Transformer Engine for fused LayerNorm + sequence parallelism, and benefits +# from apex's fused LayerNorm / Adam kernels. megatron-core / mbridge are +# already pulled in by the base install; these two are the heavy compiled +# deps that complete the stack. +ENV CUDA_HOME=/usr/local/cuda \ + NVTE_FRAMEWORK=pytorch \ + TORCH_CUDA_ARCH_LIST="8.0;8.9;9.0" + +# Transformer Engine — build against the already-installed torch. +RUN uv pip install --no-build-isolation "transformer-engine[pytorch]>=2.13.0,<2.14" + +# apex (optional perf; Megatron falls back to Torch Norm / torch Adam if absent). +# - apex reads APEX_CPP_EXT / APEX_CUDA_EXT env flags to select extensions. +# - The base image ships a CUDA toolkit whose minor version may differ from +# torch's CUDA; neutralize apex's strict version guard (the mismatch is a +# safe minor 12.x difference). +# - FORCE_CUDA=1 lets the CUDA extensions build without a visible GPU. +# - `|| echo` keeps a apex failure non-fatal (TE is the must-have). +RUN git clone --depth 1 https://github.com/NVIDIA/apex.git /tmp/apex && \ + cd /tmp/apex && \ + sed -i 's/^def check_cuda_torch_binary_vs_bare_metal(cuda_dir):/def check_cuda_torch_binary_vs_bare_metal(cuda_dir):\n return/' setup.py && \ + FORCE_CUDA=1 APEX_CPP_EXT=1 APEX_CUDA_EXT=1 \ + uv pip install -v --no-build-isolation . && \ + rm -rf /tmp/apex || echo "[apex] build failed — continuing without apex (Torch Norm fallback)" + CMD ["/bin/bash"] diff --git a/docs/en/architecture/megatron-weight-sync.md b/docs/en/architecture/megatron-weight-sync.md new file mode 100644 index 0000000..fc3a4b2 --- /dev/null +++ b/docs/en/architecture/megatron-weight-sync.md @@ -0,0 +1,156 @@ +# Megatron Weight Synchronization + +This page describes how the Megatron-LM training backend exports its +weights to RaaS, and the invariants that keep the **sparse / delta** +weight-update path correct under tensor (TP), pipeline (PP), expert +(EP), expert-tensor (ETP), virtual-pipeline (VPP), and context (CP) +parallelism. + +It complements [WeightManager](weight-manager.md) and +[Delta Weight Transfer](delta-weight-transfer.md), which describe the +backend-agnostic transport. **Read those first.** + +## The problem + +Megatron stores each parameter sharded across TP/PP/EP ranks, fused +(QKV in one `linear_qkv`, gate+up in one `linear_fc1`), and vocab-padded +— a layout that bears no resemblance to the HuggingFace checkpoint names +and byte layout that SGLang / vLLM expect (`model.layers.N.self_attn.q_proj.weight`, +…). RaaS only understands HF layout. + +The transport layer (`WeightManager` + sender agent + RaaS receiver) is +deliberately **backend-agnostic**: it moves an opaque, fixed-order CPU +byte buffer and, in delta mode, ships only the bytes that changed +between two versions of that buffer. For this to be correct, the bytes +in the buffer **must be in the same layout that RaaS applies them in**. + +For FSDP this is automatic — the buffer already holds HF-layout tensors. +For Megatron it is the central design constraint. + +## Design invariant + +> **The trainer always writes HF-named, HF-layout, full-model tensors into +> the transfer buffer. Sparsity / delta is always computed in HF byte +> space, over a double buffer. The RaaS receive path never sees a +> backend-specific layout.** + +Concretely, the Megatron backend reconstructs the global model from its +sharded layout and converts it to HF on the GPU, **before** anything +reaches the transfer buffer. The sender agent and RaaS receiver then +treat Megatron exactly like FSDP. + +This makes the delta correct **by construction**: both the old and new +buffer halves hold HF bytes, so a bytewise diff produces indices that +the receiver can scatter directly into its HF buffer. + +> ⚠️ **Historical bug (fixed by this design).** An earlier Megatron path +> wrote *raw mcore-layout shards* into the buffer and reassembled to HF in +> a separate, single-buffered region in the sender — but computed the +> delta over the *mcore-layout* buffer. mcore byte offsets ≠ HF byte +> offsets (fused QKV vs split, fused gate/up vs split, vocab padding), so +> applying an mcore-space delta to RaaS's HF-space base silently corrupted +> weights. Always diff in HF space. + +## The per-tensor generator + +The reconstruction is a **streaming generator** that yields +`(hf_name, full_unsharded_cpu_tensor)` one bucket at a time. Only one +bucket of fully-gathered tensors is alive at any moment, so a 100B / MoE +model never materializes in full (no OOM). + +```python +# astraflow/train_worker/models/mcore/weight_export.py +def export_hf_named_params( + models, # _MegatronModelList: VPP chunks, DDP-wrapped + tf_config, + hf_config, + bucket_bytes, +) -> Iterator[tuple[str, torch.Tensor]]: + """Yield (hf_name, full HF-layout CPU tensor) bucket by bucket.""" +``` + +Per parameter, in order, it performs the minimal collectives: + +1. **Naming + PP/EP offsets** — `utils.megatron.get_named_parameters` + already maps local mcore names to *global* names (adds the PP layer + offset and EP expert offset), iterating VPP chunks. +2. **PP gather** — `all_gather_object` of metadata across the pipeline + group, then broadcast each owner stage's tensor so the DP-head rank + set collectively holds every global parameter. Embeddings live on the + first stage, `output_layer` / final norm on the last. +3. **TP / ETP gather** — `utils.megatron.all_gather_param` all-gathers + the shards along `partition_dim` and concatenates, handling the GLU + `linear_fc1` stride-2 rechunk and the grouped-MoE `linear_fc2` + `partition_dim` 0→1 quirk. +4. **EP gather** — for `.experts.` params, all-gather across the expert + group and rewrite local→global expert id. +5. **mcore → HF convert** — `utils.megatron.convert_to_hf` splits QKV + (GQA-aware), splits gate/up, renames, and drops vocab padding. +6. **Bucket + stream** — group converted tensors until `bucket_bytes` + (measured post-gather), `yield`, then free before the next bucket. + +This is the same abstraction verl (`per_tensor_generator`) and slime +(`HfWeightIteratorDirect`) converged on; the difference is the consumer. + +## How it plugs into WeightManager + +verl / slime push the generator's tensors GPU→GPU (NCCL / CUDA-IPC) into +the inference engine. AstraFlow instead **writes them into the CPU +double buffer** that the sender agent TCP-pulls: + +``` +optimizer.step() + │ + ▼ +WeightManager.offload(export_hf_named_params(...), version, ...) + │ DP-head ranks write each (hf_name, tensor) into the INACTIVE + │ half of the HF double buffer, in fixed order; non-heads barrier. + ▼ +notify_buffer_ready ──► sender swaps active/inactive + │ + ▼ sender._compute_delta() diffs HF-inactive vs + │ HF-active → indices in HF space ✓ + ▼ +RaaS pulls full or delta (unchanged from FSDP) +``` + +Buffer sizing comes from `MegatronEngine.get_hf_weight_metadata()` — a +metadata-only dry run of the generator that returns the ordered +`[(hf_name, shape, dtype), …]` list. This is the same `tensors_meta` +the RaaS receiver uses to pre-allocate, so both ends agree on layout and +order. + +## Rank participation + +Only **data-parallel head** ranks write the buffer (one writer per +model-parallel group), mirroring FSDP's primary-replica rule. The +TP/PP/EP gathers happen via collectives *before* the write, so every +DP-head holds the full HF model and writes it once. Other ranks only +participate in the gathers and the post-write barrier. + +## Configuration + +```yaml +trainer: + engine: + backend: megatron + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + actor: + megatron: + weight_export_bucket_bytes: 536870912 # 512 MiB gather bucket +``` + +`backend: megatron` is auto-selected when `pipeline_parallel_size > 1` +or `expert_parallel_size > 1`. + +## Invariants checklist (for reviewers) + +1. Trainer hands WeightManager **HF-named, HF-layout, full-model** + tensors. Backend differences end at `export_hf_named_params`. +2. Sparsity / delta is computed in **HF byte space**, on a double buffer. +3. **One bucket** of gathered tensors alive at a time — never + `full_tensor()` the whole model. +4. Only **DP-head** ranks write the buffer. +5. The **RaaS receive path is unchanged** between FSDP and Megatron. diff --git a/docs/en/index.rst b/docs/en/index.rst index 1f4fd9a..c6efc01 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -20,6 +20,7 @@ on distributed GPU clusters. architecture/raas architecture/trainer architecture/weight-manager + architecture/megatron-weight-sync .. toctree:: :maxdepth: 1 From 75e8ee291bcba0bcfb8774adb0bb8e24426edf4b Mon Sep 17 00:00:00 2001 From: jsw-zorro Date: Fri, 29 May 2026 07:15:44 +0000 Subject: [PATCH 2/3] feat(megatron): weight sync via HF-space buffer (PP/EP/VPP) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the TP-only shard-direct weight transfer with the HF-export path: - MegatronEngine.export_hf_named_params() / get_hf_weight_metadata() stream gathered HF tensors via mbridge (handles TP/PP/EP/ETP/VPP). The previous PP>1 / EP>1 NotImplementedError guards are removed. - WeightManager gains "megatron_hf_meta" mode: the transfer buffer is sized for the full HF model and offload() streams HF tensors into the inactive half on the writer rank, while the gather collectives run on all ranks in lockstep. The sender receives megatron_metadata=None and runs the plain full/delta path used by FSDP. Because the buffer now holds HF-layout bytes, the sparse delta is computed in HF space and is correct under any parallelism — fixing the latent corruption where the delta was computed in mcore layout but applied by the receiver in HF layout. - ppo_trainer wires the generator + HF metadata through. The legacy CPU shard-reassembly in the sender agent is now unused for Megatron (kept only for the deprecated megatron_metadata path). Validated (buffer roundtrip == HF reference, bit-exact): - Qwen3-0.6B TP=2 310 tensors, 0 mismatch, 1.19 GB - Qwen3-0.6B TP=2 PP=2 311 tensors, 0 mismatch, 1.50 GB --- .../tests/test_megatron_hf_offload.py | 144 ++++++++++++++++++ .../core/weight_manager/weight_manager.py | 124 ++++++++++++++- .../train_worker/engine/megatron_engine.py | 110 +++---------- astraflow/train_worker/trainer/ppo_trainer.py | 29 +++- 4 files changed, 305 insertions(+), 102 deletions(-) create mode 100644 astraflow/core/weight_manager/tests/test_megatron_hf_offload.py diff --git a/astraflow/core/weight_manager/tests/test_megatron_hf_offload.py b/astraflow/core/weight_manager/tests/test_megatron_hf_offload.py new file mode 100644 index 0000000..c7666af --- /dev/null +++ b/astraflow/core/weight_manager/tests/test_megatron_hf_offload.py @@ -0,0 +1,144 @@ +"""Integration test: Megatron HF-export -> WeightManager buffer -> HF tensors. + +Validates the full PR2/PR3 path without RaaS: + 1. MegatronEngine.get_hf_weight_metadata() sizes the buffer. + 2. WeightManager.offload(export_hf_named_params()) streams HF tensors into + the shared-memory double buffer (writer rank only). + 3. We read the buffer back, reinterpret per tensors_meta, and assert it + equals the reference HF checkpoint bit-for-bit. + +This proves the bytes the sender will TCP to RaaS (full mode) are correct, +and that they live in HF layout (so the sender's HF-space delta is valid). + +Run: + torchrun --nproc_per_node= \ + astraflow/core/weight_manager/tests/test_megatron_hf_offload.py \ + --model /shared/models/Qwen3-0.6B --tp 2 --pp 1 --ep 1 +""" + +from __future__ import annotations + +import argparse +import glob +import os +import sys + +import torch +import torch.distributed as dist + + +def _ref(model_path): + from safetensors.torch import load_file + + ref = {} + for f in sorted(glob.glob(os.path.join(model_path, "*.safetensors"))): + ref.update(load_file(f)) + return ref + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--model", required=True) + ap.add_argument("--tp", type=int, default=1) + ap.add_argument("--pp", type=int, default=1) + ap.add_argument("--ep", type=int, default=1) + args = ap.parse_args() + + from astraflow.train_worker.api.alloc_mode import ParallelStrategy + from astraflow.train_worker.api.cli_args import TrainEngineConfig + from astraflow.train_worker.api.io_struct import FinetuneSpec + from astraflow.train_worker.engine.megatron_engine import MegatronEngine + + world = int(os.environ["WORLD_SIZE"]) + dp = world // (args.tp * args.pp * args.ep) + + engine = MegatronEngine(TrainEngineConfig(path=args.model, dtype="bfloat16")) + engine.create_process_group( + parallel_strategy=ParallelStrategy( + data_parallel_size=dp, + tensor_parallel_size=args.tp, + pipeline_parallel_size=args.pp, + expert_parallel_size=args.ep, + ) + ) + engine.initialize( + addr=None, + ft_spec=FinetuneSpec(total_train_epochs=1, dataset_size=1, train_batch_size=1), + ) + + rank = dist.get_rank() + + # 1. Metadata (lockstep on all ranks). + hf_meta = engine.get_hf_weight_metadata() + + # 2. Build a minimal WeightManager-like buffer write on the writer rank, + # reusing the real _offload_megatron_hf logic. We construct a real + # WeightManager but stub out the sender (no subprocess) by writing + # into a plain CPU tensor of the right size. + from math import prod + + sizes = [(n, prod(sh) * (2 if dt == "bfloat16" else 4)) for n, (sh, dt) in hf_meta] + total = sum(s for _, s in sizes) + + # Only rank 0 holds the "buffer" and checks; others just drive collectives. + is_writer = rank == 0 + buf = torch.zeros(2 * total, dtype=torch.uint8) if is_writer else None + + # Stream export and write to buf[half 0] in order (mirrors + # WeightManager._offload_megatron_hf with inactive_buf_idx=0). + offset = 0 + written = {} + for name, tensor in engine.export_hf_named_params(): + nbytes = tensor.numel() * tensor.element_size() + if is_writer: + u8 = tensor.contiguous().view(-1).view(torch.uint8) + buf[offset : offset + nbytes].copy_(u8) + written[name] = (offset, list(tensor.shape), str(tensor.dtype)) + offset += nbytes + + result = 0 + if is_writer: + ref = _ref(args.model) + import json + + tie = json.load(open(os.path.join(args.model, "config.json"))).get( + "tie_word_embeddings", False + ) + # Read back each tensor from the buffer per metadata and compare. + off = 0 + nbad = 0 + nchk = 0 + for name, (shape, dt) in hf_meta: + numel = prod(shape) + nbytes = numel * (2 if dt == "bfloat16" else 4) + raw = buf[off : off + nbytes] + tdtype = torch.bfloat16 if dt == "bfloat16" else torch.float32 + t = raw.view(tdtype).view(*shape) if shape else raw.view(tdtype) + off += nbytes + if name not in ref: + if not (tie and name == "lm_head.weight"): + print(f"[FAIL] {name} not in ref", flush=True) + nbad += 1 + continue + r = ref[name].to(torch.bfloat16) + if not torch.equal(t, r): + md = (t.float() - r.float()).abs().max().item() + print(f"[FAIL] {name} max|diff|={md:.3e}", flush=True) + nbad += 1 + nchk += 1 + print( + f"\n=== buffer roundtrip: total_bytes={total} checked={nchk} bad={nbad} ===", + flush=True, + ) + result = 0 if nbad == 0 else 1 + + res_t = torch.tensor([result], device=f"cuda:{os.environ.get('LOCAL_RANK', 0)}") + dist.all_reduce(res_t, op=dist.ReduceOp.MAX) + if rank == 0: + print("PASS" if res_t.item() == 0 else "FAIL", flush=True) + engine.destroy() + return int(res_t.item()) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/astraflow/core/weight_manager/weight_manager.py b/astraflow/core/weight_manager/weight_manager.py index f8772a7..658ab4a 100644 --- a/astraflow/core/weight_manager/weight_manager.py +++ b/astraflow/core/weight_manager/weight_manager.py @@ -34,6 +34,17 @@ logger = logging.getLogger(__name__) +_DTYPE_SIZES = { + "float32": 4, "float16": 2, "bfloat16": 2, + "int64": 8, "int32": 4, "int16": 2, "int8": 1, "uint8": 1, +} + + +def _nbytes(shape: List[int], dtype: str) -> int: + from math import prod + + return int(prod(shape)) * _DTYPE_SIZES.get(dtype, 2) + class WeightManager: """Single owner of all weight transfer state and logic. @@ -85,6 +96,7 @@ def initialize( global_rank: int, lora_config: dict | None = None, megatron_metadata: Optional[dict] = None, + megatron_hf_meta: Optional[list] = None, dp_replicate_rank: int = 0, ) -> None: """Initialize buffers and start sender agent. @@ -103,11 +115,19 @@ def initialize( ``target_modules``). Forwarded to the sender agent so the RaaS receiver can save weights in PEFT adapter format. megatron_metadata : dict, optional - If provided, enables Megatron shard-direct mode. Contains - ``tp_size``, ``tp_rank``, ``dp_rank``, ``shard_specs`` (per-param - TP metadata), and ``conversion_config`` (for CPU-side HF conversion - in the sender agent). Buffer layout is computed from full - (gathered) param sizes; each TP rank writes only its shard. + (Legacy TP-only shard-direct mode — superseded by + ``megatron_hf_meta``.) If provided, the sender reassembles raw + TP shards into HF format on CPU. Only correct for PP=1/EP=1 and + cannot compute deltas in HF space; kept for backward compat. + megatron_hf_meta : list, optional + Megatron **HF-export** mode (preferred). The ordered HF weight + layout ``[(hf_name, (shape, dtype_str)), ...]`` from + ``hf_weight_metadata``. The buffer is sized for the full HF + model and ``offload`` writes already-converted HF tensors (from + ``export_hf_named_params``) on the DP-head rank. Because the + buffer holds HF bytes, the sender's standard full/delta path is + correct under any TP/PP/EP/VPP combination — see + ``docs/en/architecture/megatron-weight-sync.md``. dp_replicate_rank : int HSDP replica group index. 0 = primary replica (owns the shm buffer and offloads weights). >0 = secondary replica (skips @@ -117,8 +137,18 @@ def initialize( self._global_rank = global_rank self._hsdp_replica_rank = dp_replicate_rank self._megatron_metadata = megatron_metadata - - if megatron_metadata is not None: + self._megatron_hf_meta = megatron_hf_meta + + if megatron_hf_meta is not None: + # HF-export mode: buffer holds the full HF model in HF layout. + # The sender treats it exactly like FSDP (no reassembly, delta + # in HF space), so megatron_metadata stays None for the sender. + meta_size = [ + (name, _nbytes(shape, dtype)) + for name, (shape, dtype) in megatron_hf_meta + ] + tensors_meta = list(megatron_hf_meta) + elif megatron_metadata is not None: meta_size, tensors_meta = self._compute_megatron_buffer_layout( megatron_metadata["shard_specs"] ) @@ -128,11 +158,15 @@ def initialize( # Only the primary HSDP replica (replica_rank=0) runs the sender # agent and owns the shm buffer. Secondary replicas skip entirely. + # In HF-export mode the buffer already holds HF bytes, so the sender + # runs the plain (FSDP) path — pass megatron_metadata=None. if local_rank == 0 and dp_replicate_rank == 0: self._start_sender_agent( meta_size, tensors_meta, lora_config=lora_config, - megatron_metadata=megatron_metadata, + megatron_metadata=( + None if megatron_hf_meta is not None else megatron_metadata + ), ) self._broadcast_shm_buffer() @@ -346,6 +380,13 @@ def offload( dict Weight transfer metrics for wandb logging. Empty on non-rank-0. """ + # Megatron HF-export mode: ``named_params`` is a fresh generator that + # yields gathered HF tensors. It must be streamed (not list()-ed) and + # iterated in lockstep on every rank (it runs TP/PP/EP collectives), + # but only the writer rank copies into the buffer. + if self._megatron_hf_meta is not None: + return self._offload_megatron_hf(named_params, version, rank, world_size) + params_list = list(named_params) # Guard: wait if previous delta is still reading the inactive half. @@ -429,6 +470,73 @@ def offload( return metrics + def _offload_megatron_hf( + self, + hf_named_params: Iterator[Tuple[str, torch.Tensor]], + version: int, + rank: int, + world_size: int, + ) -> dict: + """Stream gathered HF tensors into the buffer (Megatron HF-export mode). + + Every rank iterates ``hf_named_params`` in lockstep — it drives the + TP/PP/EP collectives inside ``export_hf_named_params`` — but only the + writer rank (global rank 0, which owns the shm buffer) copies the + yielded tensors into the inactive half, in the fixed order that + matches ``megatron_hf_meta``. Because the bytes are HF-layout, the + sender's standard full/delta path is correct (delta in HF space). + """ + t_guard_start = _time.perf_counter() + self._wait_previous_delta() + if dist.is_initialized(): + dist.barrier() + t0 = _time.perf_counter() + guard_time = t0 - t_guard_start + + is_writer = self._buffer is not None and self._local_rank == 0 + half_base = self._inactive_buf_idx * self._single_buffer_length + offset = 0 + n_written = 0 + for _name, tensor in hf_named_params: + nbytes = tensor.numel() * tensor.element_size() + if is_writer: + t_u8 = tensor.contiguous().view(-1).view(torch.uint8) + self._buffer[half_base + offset: half_base + offset + nbytes].copy_( + t_u8 + ) + n_written += 1 + offset += nbytes + t1 = _time.perf_counter() + + if dist.is_initialized(): + dist.barrier() + t2 = _time.perf_counter() + + ack = self._notify_buffer_ready(version) + t3 = _time.perf_counter() + + metrics: dict = {} + if rank == 0: + metrics = { + "weight_transfer/offload_guard_time": guard_time, + "weight_transfer/offload_copy_time": t1 - t0, + "weight_transfer/offload_barrier_time": t2 - t1, + "weight_transfer/offload_notify_time": t3 - t2, + "weight_transfer/offload_total_time": t3 - t_guard_start, + } + if self._last_delta_metrics: + metrics.update(self._last_delta_metrics) + self._last_delta_metrics = None + print( + f"[WeightManager] offload mode=megatron_hf_export, " + f"wrote={n_written} tensors, total_bytes={offset}, " + f"guard={guard_time:.3f}s, copy={t1 - t0:.3f}s, " + f"barrier={t2 - t1:.3f}s, notify={t3 - t2:.3f}s, " + f"total={t3 - t_guard_start:.3f}s", + flush=True, + ) + return metrics + # ------------------------------------------------------------------ # Copy strategies # ------------------------------------------------------------------ diff --git a/astraflow/train_worker/engine/megatron_engine.py b/astraflow/train_worker/engine/megatron_engine.py index 1fd5a87..b3c82de 100644 --- a/astraflow/train_worker/engine/megatron_engine.py +++ b/astraflow/train_worker/engine/megatron_engine.py @@ -368,99 +368,37 @@ def update_weights(self, meta: WeightUpdateMeta): "Use TCP-based weight transfer instead." ) - def get_megatron_shard_metadata(self) -> dict: - """Return TP shard metadata for shard-direct weight transfer. + def export_hf_named_params(self) -> Iterator[tuple[str, torch.Tensor]]: + """Stream ``(hf_name, full HF-layout CPU tensor)`` for weight sync. - Returns a serializable dict that WeightManager passes to the sender - agent. The sender agent uses it for CPU-side reassembly of TP shards - into HF-format params before serving to RaaS. + Reconstructs the global model from Megatron's TP/PP/EP/ETP/VPP layout + (via mbridge) and yields HF-named tensors one at a time — OOM-safe + for large / MoE models. Must be iterated in lockstep on every rank + (it runs collectives); the WeightManager decides which rank writes. - Currently requires PP=1 and EP=1. + See ``astraflow.train_worker.models.mcore.weight_export`` and + ``docs/en/architecture/megatron-weight-sync.md``. """ - from astraflow.train_worker.utils.megatron import get_named_parameters + from astraflow.train_worker.models.mcore.weight_export import ( + export_hf_named_params, + ) - pp_size = mpu.get_pipeline_model_parallel_world_size() - ep_size = mpu.get_expert_model_parallel_world_size() - if pp_size > 1: - raise NotImplementedError( - f"Megatron weight transfer does not support PP>1 yet (pp_size={pp_size})." - ) - if ep_size > 1: - raise NotImplementedError( - f"Megatron weight transfer does not support EP>1 yet (ep_size={ep_size})." - ) + self._ensure_ready() + yield from export_hf_named_params(self.bridge, self.model) - tp_size = mpu.get_tensor_model_parallel_world_size() - vocab_size = self.hf_config.vocab_size - model_type = getattr(self.hf_config, "model_type", "") - - # Collect per-param shard specs - shard_specs = [] - num_experts = getattr(self.tf_config, "num_moe_experts", None) - for mcore_name, param in get_named_parameters(self.model, num_experts): - is_tp = getattr(param, "tensor_model_parallel", False) - is_duplicated = getattr(param, "parallel_mode", None) == "duplicated" - is_sharded = is_tp and not is_duplicated - partition_dim = getattr(param, "partition_dim", 0) if is_sharded else 0 - - shard_shape = list(param.data.shape) - if is_sharded: - full_shape = list(param.data.shape) - full_shape[partition_dim] *= tp_size - else: - full_shape = list(param.data.shape) - - # Detect special param types for reassembly - is_glu = "linear_fc1.weight" in mcore_name - is_fc2_bug = ( - "linear_fc2.weight" in mcore_name - and is_sharded - and partition_dim == 0 - ) + def get_hf_weight_metadata(self) -> list[tuple[str, tuple[list[int], str]]]: + """Return the ordered HF weight layout ``[(name, (shape, dtype)), ...]``. - # Check if this is an embedding/output_layer that needs vocab unpadding - needs_vocab_unpad = ( - mcore_name in ( - "module.module.embedding.word_embeddings.weight", - "module.module.output_layer.weight", - ) - and full_shape[0] > vocab_size - ) + Used by WeightManager to size the transfer buffer and by RaaS to + pre-allocate. Runs the same collectives as ``export_hf_named_params`` + (call in lockstep on every rank). + """ + from astraflow.train_worker.models.mcore.weight_export import ( + hf_weight_metadata, + ) - shard_specs.append({ - "mcore_name": mcore_name, - "is_sharded": is_sharded, - "partition_dim": partition_dim, - "shard_shape": shard_shape, - "full_shape": full_shape, - "dtype": str(param.dtype).split(".")[-1], - "is_glu": is_glu, - "is_fc2_bug": is_fc2_bug, - "needs_vocab_unpad": needs_vocab_unpad, - }) - - # Conversion config for sender agent (subset of TransformerConfig fields) - try: - kv_channels = self.tf_config.kv_channels - except AttributeError: - kv_channels = None - - conversion_config = { - "model_type": model_type, - "hidden_size": self.tf_config.hidden_size, - "num_attention_heads": self.tf_config.num_attention_heads, - "num_query_groups": self.tf_config.num_query_groups, - "kv_channels": kv_channels, - "vocab_size": vocab_size, - } - - return { - "tp_size": tp_size, - "tp_rank": mpu.get_tensor_model_parallel_rank(), - "dp_rank": mpu.get_data_parallel_rank(), - "shard_specs": shard_specs, - "conversion_config": conversion_config, - } + self._ensure_ready() + return hf_weight_metadata(self.bridge, self.model) def set_version(self, version: int): self._version = version diff --git a/astraflow/train_worker/trainer/ppo_trainer.py b/astraflow/train_worker/trainer/ppo_trainer.py index b90af8c..16ab8e9 100644 --- a/astraflow/train_worker/trainer/ppo_trainer.py +++ b/astraflow/train_worker/trainer/ppo_trainer.py @@ -171,12 +171,16 @@ def _is_megatron(self) -> bool: return isinstance(self.actor, MegatronEngine) def _get_named_params_for_offload(self): - """Return raw named parameters for WeightManager.offload(). + """Return the (name, tensor) stream for WeightManager.offload(). - For both Megatron and FSDP: yields raw model.named_parameters(). - Megatron params are TP-sharded; WeightManager._copy_megatron_shards - handles the shard-direct copy using tp_rank offsets. + - Megatron: a fresh ``export_hf_named_params`` generator that yields + gathered HF-layout tensors (handles TP/PP/EP/VPP). WeightManager + streams it into the HF buffer on the writer rank. + - FSDP: raw ``model.named_parameters()`` (DTensor shards handled by + WeightManager's shard-copy / all-gather paths). """ + if self._is_megatron: + return self.actor.export_hf_named_params() try: return self.actor.model.named_parameters(remove_duplicate=False) except TypeError: @@ -255,9 +259,13 @@ def _init_weight_manager(self) -> None: "target_modules": list(peft_cfg.target_modules), } - megatron_metadata = None + # Megatron HF-export mode: the buffer is sized from the full HF + # weight layout, and offload streams gathered HF tensors into it. + # This keeps the sender/RaaS path identical to FSDP (delta in HF + # space) and works under any TP/PP/EP/VPP combination. + megatron_hf_meta = None if self._is_megatron: - megatron_metadata = self.actor.get_megatron_shard_metadata() + megatron_hf_meta = self.actor.get_hf_weight_metadata() # Determine HSDP replica rank (0 = primary, >0 = secondary). dp_replicate_rank = 0 @@ -279,10 +287,15 @@ def _init_weight_manager(self) -> None: dp_replicate_rank=dp_replicate_rank, ) else: - named_params = self._get_named_params_for_offload() + # In Megatron HF-export mode the layout comes from + # megatron_hf_meta, so named_params is unused at init time. + named_params = ( + iter(()) if self._is_megatron + else self._get_named_params_for_offload() + ) self.weight_manager.initialize( named_params, local_rank, global_rank, - megatron_metadata=megatron_metadata, + megatron_hf_meta=megatron_hf_meta, dp_replicate_rank=dp_replicate_rank, ) logger.info( From b6d0ce318dd8fd5bb2fecf7a58523ab286315a16 Mon Sep 17 00:00:00 2001 From: jsw-zorro Date: Fri, 29 May 2026 07:16:34 +0000 Subject: [PATCH 3/3] feat(examples): add Qwen3-8B Megatron math RL recipe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add examples/math/qwen3-8b-megatron-delta — the FSDP qwen3-8b-m2po-delta recipe with the trainer engine switched to the Megatron backend (backend: megatron, tensor_parallel_size: 4). Identical data, algorithm, and weight-transfer path, so it doubles as a clean FSDP-vs-Megatron A/B. End-to-end (4 RaaS + 4 trainer TP=4, delta TCP): 100 training steps, weight_transfer/delta_sparsity ~0.92 (HF-space delta), and a rising task reward (first-20 mean 0.55 -> later windows ~0.61-0.63). --- .../math/qwen3-8b-megatron-delta/README.md | 55 +++++++ .../scripts/1_astraflow.sh | 36 +++++ .../qwen3-8b-megatron-delta/scripts/2_raas.sh | 44 +++++ .../scripts/3_trainer_model0.sh | 47 ++++++ .../scripts/run_qwen3-8b-megatron-delta.sh | 104 ++++++++++++ .../yaml/experiment.yaml | 150 ++++++++++++++++++ .../qwen3-8b-megatron-delta/yaml/raas.yaml | 33 ++++ 7 files changed, 469 insertions(+) create mode 100644 examples/math/qwen3-8b-megatron-delta/README.md create mode 100755 examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh create mode 100755 examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh create mode 100755 examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh create mode 100755 examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh create mode 100644 examples/math/qwen3-8b-megatron-delta/yaml/experiment.yaml create mode 100644 examples/math/qwen3-8b-megatron-delta/yaml/raas.yaml diff --git a/examples/math/qwen3-8b-megatron-delta/README.md b/examples/math/qwen3-8b-megatron-delta/README.md new file mode 100644 index 0000000..40f304a --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/README.md @@ -0,0 +1,55 @@ +# Qwen3-8B Math RL — Megatron backend, delta TCP weight transfer + +Same math RL recipe as [`qwen3-8b-m2po-delta`](../qwen3-8b-m2po-delta) (M2PO, +DeepScaleR data, ctx 16k, lr 5e-6, sparse delta weight sync) but the trainer +uses the **Megatron-LM backend** instead of FSDP. The only difference is the +`trainer_base.engine` block: + +```yaml +engine: + backend: megatron + data_parallel_size: 1 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 +``` + +This makes it a clean FSDP-vs-Megatron A/B: identical data, algorithm, and +weight-transfer path, so reward curves should track each other. + +## How weight sync works (Megatron) + +The trainer reconstructs the global model from Megatron's TP/PP/EP/VPP +layout into HuggingFace-named tensors (via `export_hf_named_params`, +backed by mbridge) and streams them into the CPU transfer buffer. Because +the buffer holds HF-layout bytes, the sparse **delta** is computed in HF +space and the RaaS receive path is identical to FSDP. See +[`docs/en/architecture/megatron-weight-sync.md`](../../../docs/en/architecture/megatron-weight-sync.md). + +## GPU layout (8 GPUs, single node) + +| Component | GPUs | Parallelism | +|-----------|------|-------------| +| RaaS (SGLang, model0) | 0,1,2,3 | DP=4 | +| Trainer model0 (Megatron) | 4,5,6,7 | TP=4 | + +## Run + +```bash +bash examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh +``` + +Or launch the three components separately (terminals 1/2/3): + +```bash +bash examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh +bash examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh +bash examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh +``` + +## Scaling to PP / MoE + +For pipeline or expert parallelism (and MoE models), set the corresponding +sizes in the `engine` block, e.g. `pipeline_parallel_size: 2` or +`expert_parallel_size: 2`. The backend auto-selects Megatron when `pp>1` or +`ep>1`. Ensure `data_parallel_size * tensor_parallel_size * +pipeline_parallel_size` equals the number of trainer GPUs. diff --git a/examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh b/examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh new file mode 100755 index 0000000..ae981ac --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -euo pipefail +# [1/3] Launch AstraFlow HTTP service +# +# Usage (terminal 1): +# bash examples/math/qwen3-8b-m2po-delta/scripts/1_astraflow.sh + +export CUDA_VISIBLE_DEVICES="" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== AstraFlow HTTP Service ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "Port : ${ASTRAFLOW_PORT}" +echo "===============================" + +python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" diff --git a/examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh b/examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh new file mode 100755 index 0000000..f66c5c4 --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -euo pipefail +# [2/3] Launch RaaS inference server (SGLang + TCP receiver) +# +# Usage (terminal 2, after AstraFlow is ready): +# bash examples/math/qwen3-8b-m2po-delta/scripts/2_raas.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="${ASTRAFLOW_URL:-http://127.0.0.1:${ASTRAFLOW_PORT}}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== RaaS Inference Server (SGLang + TCP receiver) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES}" +echo "Port : ${RAAS_PORT}" +echo "AstraFlow URL : ${ASTRAFLOW_URL}" +echo "=======================================================" + +python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" diff --git a/examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh b/examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh new file mode 100755 index 0000000..67ffa1d --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -euo pipefail +# [3/3] Launch Trainer for model0 (TCP, sender_agent on local_rank 0) +# +# Usage (terminal 3, after AstraFlow and RaaS are ready): +# bash examples/math/qwen3-8b-m2po-delta/scripts/3_trainer_model0.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +TRAINER0_NPROC="$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F',' '{print NF}')" + +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +# sender_agent (in trainer) listens on this HTTP port +export WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== Trainer model0 (TCP) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES} (Megatron TP${TRAINER0_NPROC})" +echo "AstraFlow : ${ASTRAFLOW_URL}" +echo "RaaS : ${ASTRAFLOW_RAAS_URL}" +echo "Sender HTTP : ${WEIGHT_TRANSFER_HTTP_PORT}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================" + +torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh b/examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh new file mode 100755 index 0000000..d5ae145 --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh @@ -0,0 +1,104 @@ +#!/bin/bash +set -euo pipefail +# All-in-one launcher for AstraFlow v2 math training (Qwen3-8B, M2PO, Megatron TP4, TCP). +# +# Launches 3 processes: +# 1. AstraFlow HTTP service (CPU-only) +# 2. RaaS inference server (SGLang, SERVICE_CUDA_VISIBLE_DEVICES) +# 3. Trainer model0 (math, TRAINER_MODEL0_GPUS) +# +# Usage: +# bash examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh + +# ============================================================================= +# Part 1: Load env and settings +# ============================================================================= +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +# Defined in examples/_common/utils.sh. +astraflow_load_experiment_env + +# ============================================================================= +# Part 2: Set up env +# ============================================================================= +# GPU assignments (default: 4 GPUs for inference, 4 for training) +export SERVICE_CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export TRAINER_MODEL0_GPUS="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +# Ports / URLs (each component gets its own port) +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export WEIGHT_TRANSFER_HTTP_PORT_MODEL0="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +TRAINER0_NPROC="$(echo "${TRAINER_MODEL0_GPUS}" | awk -F',' '{print NF}')" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. +# Defined in examples/_common/utils.sh. +astraflow_setup_env + +# ============================================================================= +# Part 3: Print info and clean up +# ============================================================================= +echo "=== AstraFlow v2 (Qwen3-8B, math, M2PO, ctx16k, TCP delta) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "RaaS GPUs : ${SERVICE_CUDA_VISIBLE_DEVICES}" +echo "Trainer model0 GPUs : ${TRAINER_MODEL0_GPUS} (Megatron TP${TRAINER0_NPROC})" +echo "RaaS port : ${RAAS_PORT}" +echo "AstraFlow port : ${ASTRAFLOW_PORT}" +echo "Sender HTTP model0 : ${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================================" + +trap astraflow_cleanup_trap EXIT INT TERM + +# Kill leftover processes and shared memory from prior runs. +# Defined in examples/_common/utils.sh. +astraflow_kill_stale + +# ============================================================================= +# Part 4: Launch training +# ============================================================================= +echo "[1/3] Starting AstraFlow HTTP service..." +CUDA_VISIBLE_DEVICES="" \ + python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" & +sleep 5 + +echo "[2/3] Starting RaaS inference server (SGLang + TCP receiver)..." +CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES}" \ + python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" & +sleep 15 + +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +echo "[3/3] Starting trainer model0..." +CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS}" \ +WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" \ + torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" \ + 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3-8b-megatron-delta/yaml/experiment.yaml b/examples/math/qwen3-8b-megatron-delta/yaml/experiment.yaml new file mode 100644 index 0000000..8136452 --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/yaml/experiment.yaml @@ -0,0 +1,150 @@ +# ============================================================================ +# Experiment config — AstraFlow service + Trainer +# Experiment: math / qwen3-8b-megatron-delta +# +# Qwen3-8B math RL with M2PO, ctx 16k, lr 5e-6, delta TCP weight transfer, +# **Megatron training backend** (TP=4). Same algorithm/data as the FSDP +# qwen3-8b-m2po-delta recipe — only the trainer engine differs, which makes +# it a clean FSDP-vs-Megatron comparison. +# +# GPU layout (default, 8 GPUs): +# SERVICE_CUDA_VISIBLE_DEVICES=0,1,2,3 -> RaaS (model0 dp=4) +# TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (Megatron, TP=4) +# ============================================================================ + +# ── Experiment: identity, model, shared settings ── +experiment: + experiment_name: astraflow-math + trial_name: qwen3-8b-megatron-delta + fileroot: ./data-experiments/${experiment.experiment_name}/${experiment.trial_name} + + model_path: "Qwen/Qwen3-8B" + tokenizer_path: "Qwen/Qwen3-8B" + seed: 1 + dtype: bfloat16 + weight_transfer_mode: tcp + weight_transfer_strategies: delta + +# ── RaaS: what to generate (inference-level config) ── +raas: + models: + model0: + backend: sglang + gconfig: + n_samples: 8 + temperature: 1.0 + max_new_tokens: 14000 + min_new_tokens: 0 + delta_full_sync_interval: 10 + +# ── AstraFlow: data pipeline ── +dataflow: + host: "0.0.0.0" + port: 8000 + + buffer: + size: 10000 + replay_size: 10000 + replay_ratio: 0 + max_staleness: 8 + filter_function: filter_zero_adv + + rollout_dataset: + dataset_fn: "astraflow.dataflow.dataset.deepscaler:get_deepscaler_rl_dataset" + max_length: 2000 + + workflow_spec: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + + eval_workflows: + math_eval: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + gconfig_overrides: + temperature: 0.6 + n_samples: 1 + + eval_datasets: + aime24: + dataset_fn: "astraflow.dataflow.dataset.aime24x4:get_aime_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + aime25: + dataset_fn: "astraflow.dataflow.dataset.aime25x4:get_aime_2025x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + math500: + dataset_fn: "astraflow.dataflow.dataset.math500:get_math500_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + +# ── Trainer base: shared config ── +trainer_base: + total_train_steps: 800 + train_batch_size: 256 + n_samples: 8 + # Megatron training backend. TP=4 across the 4 trainer GPUs (dp=1). + # backend is auto-selected as megatron when pp>1 or ep>1, but we set it + # explicitly here so a TP-only config also uses Megatron. + engine: + backend: megatron + data_parallel_size: 1 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + + actor: + gradient_checkpointing: true + mb_spec: + max_tokens_per_mb: 17408 + optimizer: + type: adam + lr: 5e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + # PPO / M2PO algorithm + m2_threshold: 0.01 + eps_clip: 100.0 + eps_clip_higher: 100.0 + reward_scaling: 1 + reward_bias: 0 + kl_ctl: 0.00 + kl_penalty_coef: 0.001 + ppo_n_minibatches: 4 + reward_norm: { mean_level: group, std_level: group } + adv_norm: { mean_level: batch, std_level: batch } + # Megatron backend uses sensible defaults (distributed optimizer on, + # DDP wrap on). Override under actor.megatron.{ddp,...} if needed. + + ref: + mb_spec: + max_tokens_per_mb: 17408 + + recover: + mode: auto + freq_steps: 25 + + evaluator: + eval_at_start: false + freq_steps: 25 + + stats_logger: + wandb: + mode: online + id_suffix: "uid" + +# ── Trainer for model0 — only overrides ── +trainer_model0: + model_id: model0 + stats_logger: + wandb: + tags: ["m2po", "math", "astraflow-v2", "qwen3-8b", "tcp", "ctx16k", "delta", "megatron"] diff --git a/examples/math/qwen3-8b-megatron-delta/yaml/raas.yaml b/examples/math/qwen3-8b-megatron-delta/yaml/raas.yaml new file mode 100644 index 0000000..5e21cbf --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/yaml/raas.yaml @@ -0,0 +1,33 @@ +# ============================================================================ +# RaaS config — Inference serving instance (hardware/resources) +# Experiment: math / qwen3-8b-m2po-delta +# +# Hardware: 4x GPU, TP=1 +# model0: DP=4, TP=1 +# +# Merged with experiment.yaml at launch (--config experiment.yaml --config raas.yaml) +# experiment.yaml provides: model_path, tokenizer_path, seed, dtype, models/gconfig +# ============================================================================ + +rollout: + max_concurrent_rollouts: 1024 + # Cap concurrent eval prefills to bound peak KV pressure during the + # ~3.5k-item eval burst (5 datasets x repeat=4) — default 128 OOMs sglang. + max_concurrent_evals: 64 + pause_grace_period: 3 + # Adaptive availability — drive /availability off sglang /get_load. + enable_adaptive_availability: true + target_waiting_queue_per_dp: 4 + adaptive_step_size: 4 + load_cache_ttl_ms: 100 + +engine: + model0: + backend: sglang + data_parallel_size: 4 + +sglang: + context_length: 16384 + mem_fraction_static: 0.8 + max_running_requests: null + skip_tokenizer_init: true