def _ref(model_path):
    """Load the reference HF checkpoint tensors from ``model_path``.

    Every ``*.safetensors`` file directly inside ``model_path`` is read in
    sorted filename order and merged into a single dict. If the same tensor
    name appears in more than one file, the later file wins.

    Args:
        model_path: Directory that holds the reference safetensors shards.

    Returns:
        dict mapping tensor name to the tensor returned by ``load_file``
        (no dtype conversion is applied here).

    Raises:
        FileNotFoundError: If ``model_path`` contains no ``*.safetensors``
            file. Without this check the caller would compare against an
            empty reference and report every tensor as missing.
    """
    shards = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
    if not shards:
        raise FileNotFoundError(f"no .safetensors in {model_path}")

    # Imported after the path check so a bad --model path is reported as
    # such, and lazily so importing this test module stays cheap.
    from safetensors.torch import load_file

    ref = {}
    for shard in shards:
        ref.update(load_file(shard))
    return ref
(args.tp * args.pp * args.ep) + + engine = MegatronEngine(TrainEngineConfig(path=args.model, dtype="bfloat16")) + engine.create_process_group( + parallel_strategy=ParallelStrategy( + data_parallel_size=dp, + tensor_parallel_size=args.tp, + pipeline_parallel_size=args.pp, + expert_parallel_size=args.ep, + ) + ) + engine.initialize( + addr=None, + ft_spec=FinetuneSpec(total_train_epochs=1, dataset_size=1, train_batch_size=1), + ) + + rank = dist.get_rank() + + # 1. Metadata (lockstep on all ranks). + hf_meta = engine.get_hf_weight_metadata() + + # 2. Build a minimal WeightManager-like buffer write on the writer rank, + # reusing the real _offload_megatron_hf logic. We construct a real + # WeightManager but stub out the sender (no subprocess) by writing + # into a plain CPU tensor of the right size. + from math import prod + + sizes = [(n, prod(sh) * (2 if dt == "bfloat16" else 4)) for n, (sh, dt) in hf_meta] + total = sum(s for _, s in sizes) + + # Only rank 0 holds the "buffer" and checks; others just drive collectives. + is_writer = rank == 0 + buf = torch.zeros(2 * total, dtype=torch.uint8) if is_writer else None + + # Stream export and write to buf[half 0] in order (mirrors + # WeightManager._offload_megatron_hf with inactive_buf_idx=0). + offset = 0 + written = {} + for name, tensor in engine.export_hf_named_params(): + nbytes = tensor.numel() * tensor.element_size() + if is_writer: + u8 = tensor.contiguous().view(-1).view(torch.uint8) + buf[offset : offset + nbytes].copy_(u8) + written[name] = (offset, list(tensor.shape), str(tensor.dtype)) + offset += nbytes + + result = 0 + if is_writer: + ref = _ref(args.model) + import json + + tie = json.load(open(os.path.join(args.model, "config.json"))).get( + "tie_word_embeddings", False + ) + # Read back each tensor from the buffer per metadata and compare. 
# Bytes per element, keyed by dtype name without the ``torch.`` prefix.
_DTYPE_SIZES = {
    "float64": 8, "float32": 4, "float16": 2, "bfloat16": 2,
    "float8_e4m3fn": 1, "float8_e5m2": 1,
    "int64": 8, "int32": 4, "int16": 2, "int8": 1, "uint8": 1,
    "bool": 1,
}


def _nbytes(shape: List[int], dtype: str) -> int:
    """Return the byte size of a dense tensor with ``shape`` and ``dtype``.

    Args:
        shape: Tensor dimensions. An empty list denotes a scalar (one
            element); any zero dimension yields zero bytes.
        dtype: Dtype name as used for the keys of ``_DTYPE_SIZES``
            (e.g. ``"bfloat16"``).

    Returns:
        ``prod(shape) * itemsize`` in bytes.

    Raises:
        ValueError: If ``dtype`` is not listed in ``_DTYPE_SIZES``. Guessing
            an element size here would mis-size the transfer buffer, so an
            unknown dtype is reported instead of silently assumed.
    """
    from math import prod

    try:
        itemsize = _DTYPE_SIZES[dtype]
    except KeyError:
        raise ValueError(
            f"unsupported dtype {dtype!r}; known dtypes: {sorted(_DTYPE_SIZES)}"
        ) from None
    return int(prod(shape)) * itemsize
@@ -85,6 +96,7 @@ def initialize( global_rank: int, lora_config: dict | None = None, megatron_metadata: Optional[dict] = None, + megatron_hf_meta: Optional[list] = None, dp_replicate_rank: int = 0, ) -> None: """Initialize buffers and start sender agent. @@ -103,11 +115,19 @@ def initialize( ``target_modules``). Forwarded to the sender agent so the RaaS receiver can save weights in PEFT adapter format. megatron_metadata : dict, optional - If provided, enables Megatron shard-direct mode. Contains - ``tp_size``, ``tp_rank``, ``dp_rank``, ``shard_specs`` (per-param - TP metadata), and ``conversion_config`` (for CPU-side HF conversion - in the sender agent). Buffer layout is computed from full - (gathered) param sizes; each TP rank writes only its shard. + (Legacy TP-only shard-direct mode — superseded by + ``megatron_hf_meta``.) If provided, the sender reassembles raw + TP shards into HF format on CPU. Only correct for PP=1/EP=1 and + cannot compute deltas in HF space; kept for backward compat. + megatron_hf_meta : list, optional + Megatron **HF-export** mode (preferred). The ordered HF weight + layout ``[(hf_name, (shape, dtype_str)), ...]`` from + ``hf_weight_metadata``. The buffer is sized for the full HF + model and ``offload`` writes already-converted HF tensors (from + ``export_hf_named_params``) on the DP-head rank. Because the + buffer holds HF bytes, the sender's standard full/delta path is + correct under any TP/PP/EP/VPP combination — see + ``docs/en/architecture/megatron-weight-sync.md``. dp_replicate_rank : int HSDP replica group index. 0 = primary replica (owns the shm buffer and offloads weights). 
>0 = secondary replica (skips @@ -117,8 +137,18 @@ def initialize( self._global_rank = global_rank self._hsdp_replica_rank = dp_replicate_rank self._megatron_metadata = megatron_metadata - - if megatron_metadata is not None: + self._megatron_hf_meta = megatron_hf_meta + + if megatron_hf_meta is not None: + # HF-export mode: buffer holds the full HF model in HF layout. + # The sender treats it exactly like FSDP (no reassembly, delta + # in HF space), so megatron_metadata stays None for the sender. + meta_size = [ + (name, _nbytes(shape, dtype)) + for name, (shape, dtype) in megatron_hf_meta + ] + tensors_meta = list(megatron_hf_meta) + elif megatron_metadata is not None: meta_size, tensors_meta = self._compute_megatron_buffer_layout( megatron_metadata["shard_specs"] ) @@ -128,11 +158,15 @@ def initialize( # Only the primary HSDP replica (replica_rank=0) runs the sender # agent and owns the shm buffer. Secondary replicas skip entirely. + # In HF-export mode the buffer already holds HF bytes, so the sender + # runs the plain (FSDP) path — pass megatron_metadata=None. if local_rank == 0 and dp_replicate_rank == 0: self._start_sender_agent( meta_size, tensors_meta, lora_config=lora_config, - megatron_metadata=megatron_metadata, + megatron_metadata=( + None if megatron_hf_meta is not None else megatron_metadata + ), ) self._broadcast_shm_buffer() @@ -346,6 +380,13 @@ def offload( dict Weight transfer metrics for wandb logging. Empty on non-rank-0. """ + # Megatron HF-export mode: ``named_params`` is a fresh generator that + # yields gathered HF tensors. It must be streamed (not list()-ed) and + # iterated in lockstep on every rank (it runs TP/PP/EP collectives), + # but only the writer rank copies into the buffer. + if self._megatron_hf_meta is not None: + return self._offload_megatron_hf(named_params, version, rank, world_size) + params_list = list(named_params) # Guard: wait if previous delta is still reading the inactive half. 
def _offload_megatron_hf(
    self,
    hf_named_params: Iterator[Tuple[str, torch.Tensor]],
    version: int,
    rank: int,
    world_size: int,
) -> dict:
    """Stream gathered HF tensors into the inactive buffer half.

    Every rank iterates ``hf_named_params`` to completion (the stream is
    expected to run collectives, so ranks must advance in lockstep), but
    only a writer rank copies the yielded tensors. A writer is a rank that
    holds ``self._buffer`` and has ``self._local_rank == 0``. Tensors are
    packed back to back, in iteration order, starting at the base of the
    inactive half.

    Args:
        hf_named_params: Stream of ``(hf_name, tensor)`` pairs.
        version: Weight version forwarded to ``_notify_buffer_ready``.
        rank: Caller's rank; only rank 0 builds metrics and prints the
            summary line.
        world_size: Unused by this method.

    Returns:
        Timing metrics (plus any pending delta metrics) on rank 0, an empty
        dict on every other rank.

    Raises:
        RuntimeError: If the stream would write past the end of one buffer
            half, or yields a different number of tensors than
            ``self._megatron_hf_meta`` describes. Both checks run on every
            rank, not only the writer, so all ranks fail together instead
            of leaving peers blocked in the barrier.
    """
    t_guard_start = _time.perf_counter()
    self._wait_previous_delta()
    if dist.is_initialized():
        dist.barrier()
    t0 = _time.perf_counter()
    guard_time = t0 - t_guard_start

    is_writer = self._buffer is not None and self._local_rank == 0
    half_len = self._single_buffer_length
    half_base = self._inactive_buf_idx * half_len
    offset = 0
    n_seen = 0
    n_written = 0
    for name, tensor in hf_named_params:
        nbytes = tensor.numel() * tensor.element_size()
        end = offset + nbytes
        if end > half_len:
            # Writing on would spill into the other half (or past the
            # buffer), corrupting bytes the sender may still be reading.
            raise RuntimeError(
                f"[WeightManager] HF export overflows the transfer buffer at "
                f"tensor {name!r}: needs {end} bytes, half holds {half_len}"
            )
        if is_writer:
            t_u8 = tensor.contiguous().view(-1).view(torch.uint8)
            self._buffer[half_base + offset: half_base + end].copy_(t_u8)
            n_written += 1
        n_seen += 1
        offset = end

    expected_meta = self._megatron_hf_meta
    if expected_meta is not None and n_seen != len(expected_meta):
        # The receiver interprets the buffer using the metadata layout; a
        # different tensor count means the bytes would be misread.
        raise RuntimeError(
            f"[WeightManager] HF export yielded {n_seen} tensors but the "
            f"buffer layout describes {len(expected_meta)}"
        )
    t1 = _time.perf_counter()

    if dist.is_initialized():
        dist.barrier()
    t2 = _time.perf_counter()

    ack = self._notify_buffer_ready(version)
    t3 = _time.perf_counter()

    metrics: dict = {}
    if rank == 0:
        metrics = {
            "weight_transfer/offload_guard_time": guard_time,
            "weight_transfer/offload_copy_time": t1 - t0,
            "weight_transfer/offload_barrier_time": t2 - t1,
            "weight_transfer/offload_notify_time": t3 - t2,
            "weight_transfer/offload_total_time": t3 - t_guard_start,
        }
        if self._last_delta_metrics:
            metrics.update(self._last_delta_metrics)
            self._last_delta_metrics = None
        print(
            f"[WeightManager] offload mode=megatron_hf_export, "
            f"wrote={n_written} tensors, total_bytes={offset}, "
            f"guard={guard_time:.3f}s, copy={t1 - t0:.3f}s, "
            f"barrier={t2 - t1:.3f}s, notify={t3 - t2:.3f}s, "
            f"total={t3 - t_guard_start:.3f}s",
            flush=True,
        )
    return metrics
""" - from astraflow.train_worker.utils.megatron import get_named_parameters + from astraflow.train_worker.models.mcore.weight_export import ( + export_hf_named_params, + ) - pp_size = mpu.get_pipeline_model_parallel_world_size() - ep_size = mpu.get_expert_model_parallel_world_size() - if pp_size > 1: - raise NotImplementedError( - f"Megatron weight transfer does not support PP>1 yet (pp_size={pp_size})." - ) - if ep_size > 1: - raise NotImplementedError( - f"Megatron weight transfer does not support EP>1 yet (ep_size={ep_size})." - ) + self._ensure_ready() + yield from export_hf_named_params(self.bridge, self.model) - tp_size = mpu.get_tensor_model_parallel_world_size() - vocab_size = self.hf_config.vocab_size - model_type = getattr(self.hf_config, "model_type", "") - - # Collect per-param shard specs - shard_specs = [] - num_experts = getattr(self.tf_config, "num_moe_experts", None) - for mcore_name, param in get_named_parameters(self.model, num_experts): - is_tp = getattr(param, "tensor_model_parallel", False) - is_duplicated = getattr(param, "parallel_mode", None) == "duplicated" - is_sharded = is_tp and not is_duplicated - partition_dim = getattr(param, "partition_dim", 0) if is_sharded else 0 - - shard_shape = list(param.data.shape) - if is_sharded: - full_shape = list(param.data.shape) - full_shape[partition_dim] *= tp_size - else: - full_shape = list(param.data.shape) - - # Detect special param types for reassembly - is_glu = "linear_fc1.weight" in mcore_name - is_fc2_bug = ( - "linear_fc2.weight" in mcore_name - and is_sharded - and partition_dim == 0 - ) + def get_hf_weight_metadata(self) -> list[tuple[str, tuple[list[int], str]]]: + """Return the ordered HF weight layout ``[(name, (shape, dtype)), ...]``. 
- # Check if this is an embedding/output_layer that needs vocab unpadding - needs_vocab_unpad = ( - mcore_name in ( - "module.module.embedding.word_embeddings.weight", - "module.module.output_layer.weight", - ) - and full_shape[0] > vocab_size - ) + Used by WeightManager to size the transfer buffer and by RaaS to + pre-allocate. Runs the same collectives as ``export_hf_named_params`` + (call in lockstep on every rank). + """ + from astraflow.train_worker.models.mcore.weight_export import ( + hf_weight_metadata, + ) - shard_specs.append({ - "mcore_name": mcore_name, - "is_sharded": is_sharded, - "partition_dim": partition_dim, - "shard_shape": shard_shape, - "full_shape": full_shape, - "dtype": str(param.dtype).split(".")[-1], - "is_glu": is_glu, - "is_fc2_bug": is_fc2_bug, - "needs_vocab_unpad": needs_vocab_unpad, - }) - - # Conversion config for sender agent (subset of TransformerConfig fields) - try: - kv_channels = self.tf_config.kv_channels - except AttributeError: - kv_channels = None - - conversion_config = { - "model_type": model_type, - "hidden_size": self.tf_config.hidden_size, - "num_attention_heads": self.tf_config.num_attention_heads, - "num_query_groups": self.tf_config.num_query_groups, - "kv_channels": kv_channels, - "vocab_size": vocab_size, - } - - return { - "tp_size": tp_size, - "tp_rank": mpu.get_tensor_model_parallel_rank(), - "dp_rank": mpu.get_data_parallel_rank(), - "shard_specs": shard_specs, - "conversion_config": conversion_config, - } + self._ensure_ready() + return hf_weight_metadata(self.bridge, self.model) def set_version(self, version: int): self._version = version diff --git a/astraflow/train_worker/models/mcore/tests/test_hf_export_equiv.py b/astraflow/train_worker/models/mcore/tests/test_hf_export_equiv.py new file mode 100644 index 0000000..fc7fa8e --- /dev/null +++ b/astraflow/train_worker/models/mcore/tests/test_hf_export_equiv.py @@ -0,0 +1,152 @@ +"""Equivalence test for Megatron -> HF weight export. 
def _load_reference_hf(model_path: str) -> dict[str, torch.Tensor]:
    """Collect every tensor stored in the checkpoint's safetensors shards.

    Shards directly under ``model_path`` are visited in sorted filename
    order; when a tensor name occurs in several shards the last one read
    is kept. Tensors are returned exactly as ``load_file`` produces them.

    Raises:
        FileNotFoundError: If ``model_path`` has no ``*.safetensors`` file.
    """
    import glob

    from safetensors.torch import load_file

    shard_paths = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
    if not shard_paths:
        raise FileNotFoundError(f"no .safetensors in {model_path}")

    return {
        tensor_name: tensor
        for shard_path in shard_paths
        for tensor_name, tensor in load_file(shard_path).items()
    }
TrainEngineConfig(path=args.model, dtype="bfloat16") + # No optimizer -> inference-only engine, faster init. + engine = MegatronEngine(cfg) + strategy = ParallelStrategy( + data_parallel_size=dp, + tensor_parallel_size=args.tp, + pipeline_parallel_size=args.pp, + expert_parallel_size=args.ep, + ) + engine.create_process_group(parallel_strategy=strategy) + + from astraflow.train_worker.api.io_struct import FinetuneSpec + + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=1, train_batch_size=1) + engine.initialize(addr=None, ft_spec=ft_spec) + + rank = dist.get_rank() + is_writer = rank == 0 + + ref = _load_reference_hf(args.model) if is_writer else None + + n_checked = 0 + n_mismatch = 0 + seen: set[str] = set() + for name, tensor in export_hf_named_params(engine.bridge, engine.model): + if not is_writer: + continue + seen.add(name) + if name not in ref: + print(f"[FAIL] exported tensor not in reference: {name}", flush=True) + n_mismatch += 1 + continue + r = ref[name].to(torch.bfloat16) + t = tensor.to(torch.bfloat16) + if list(t.shape) != list(r.shape): + print( + f"[FAIL] shape {name}: export {list(t.shape)} vs ref {list(r.shape)}", + flush=True, + ) + n_mismatch += 1 + continue + if args.atol == 0.0: + ok = torch.equal(t, r) + else: + ok = torch.allclose(t.float(), r.float(), atol=args.atol, rtol=0) + if not ok: + md = (t.float() - r.float()).abs().max().item() + print(f"[FAIL] values {name}: max|diff|={md:.3e}", flush=True) + n_mismatch += 1 + n_checked += 1 + + if is_writer: + import json + + with open(os.path.join(args.model, "config.json")) as f: + tie = json.load(f).get("tie_word_embeddings", False) + missing = set(ref.keys()) - seen + # Benign non-exports: + # - rotary/inv_freq buffers (not weights); + # - lm_head.weight when embeddings are tied (mbridge emits only + # embed_tokens; the inference engine ties internally). 
+ benign = {k for k in missing if "rotary" in k or "inv_freq" in k} + if tie and "lm_head.weight" in missing: + benign.add("lm_head.weight") + hard_missing = missing - benign + print( + f"\n=== export equivalence: checked={n_checked} " + f"mismatch={n_mismatch} missing={len(hard_missing)} " + f"benign_missing={len(benign)} ===", + flush=True, + ) + if hard_missing: + print( + f"[FAIL] reference keys never exported: {sorted(hard_missing)[:10]}", + flush=True, + ) + result = 0 if (n_mismatch == 0 and not hard_missing) else 1 + else: + result = 0 + + res_t = torch.tensor([result], device=f"cuda:{os.environ.get('LOCAL_RANK', 0)}") + dist.all_reduce(res_t, op=dist.ReduceOp.MAX) + if dist.get_rank() == 0: + print("PASS" if res_t.item() == 0 else "FAIL", flush=True) + engine.destroy() + return int(res_t.item()) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/astraflow/train_worker/models/mcore/weight_export.py b/astraflow/train_worker/models/mcore/weight_export.py new file mode 100644 index 0000000..5006ae3 --- /dev/null +++ b/astraflow/train_worker/models/mcore/weight_export.py @@ -0,0 +1,151 @@ +"""Streaming Megatron -> HuggingFace weight export for online weight sync. + +Reconstructs the *global* model from Megatron's sharded layout +(TP / PP / EP / ETP / VPP) and yields HuggingFace-named, HF-layout CPU +tensors **one bucket at a time**, so a large or MoE model is never +materialized in full. + +This is the single source of truth for "Megatron weights -> HF" used by the +online weight-sync path (``WeightManager.offload``). See +``docs/en/architecture/megatron-weight-sync.md`` for the design and the +HF-space delta invariant. 
# Default cap, in bytes, for one bucket of converted HF tensors produced by
# ``iter_param_buckets``. The size is measured on the tensors handed to the
# consumer, and only one bucket is meant to be alive at a time.
DEFAULT_BUCKET_BYTES = 512 << 20  # 512 MiB


def export_hf_named_params(
    bridge,
    models: list,
    *,
    to_cpu: bool = True,
) -> Iterator[tuple[str, torch.Tensor]]:
    """Stream ``(hf_name, tensor)`` pairs produced by ``bridge.export_weights``.

    Each tensor is detached and made contiguous before it is yielded; with
    ``to_cpu`` (the default) it is first moved to CPU.

    Parameters
    ----------
    bridge :
        Object exposing ``export_weights(models)``, which yields
        ``(hf_name, tensor)`` pairs.
    models :
        Passed through unchanged to ``bridge.export_weights``.
    to_cpu :
        Move every yielded tensor to CPU. Pass ``False`` to keep tensors on
        the device they were produced on.

    Yields
    ------
    tuple[str, torch.Tensor]
        The HF parameter name and its detached, contiguous tensor.

    Notes
    -----
    ``export_weights`` is expected to run collectives, so every
    participating rank should drive this generator in lockstep; the caller
    chooses which rank actually stores the results.
    """
    for hf_name, param in bridge.export_weights(models):
        detached = param.detach()
        if to_cpu:
            detached = detached.to("cpu", copy=False)
        # contiguous() guards against strided views coming out of the
        # converter; it returns the tensor itself when already contiguous.
        yield hf_name, detached.contiguous()


def iter_param_buckets(
    named_params: Iterator[tuple[str, torch.Tensor]],
    bucket_bytes: int = DEFAULT_BUCKET_BYTES,
) -> Iterator[list[tuple[str, torch.Tensor]]]:
    """Batch a ``(name, tensor)`` stream into lists of bounded byte size.

    A bucket is closed as soon as adding the next tensor would push its
    cumulative size above ``bucket_bytes``; a tensor that is larger than
    the cap on its own is emitted as a single-element bucket. Order is
    preserved and empty buckets are never yielded.
    """
    pending: list[tuple[str, torch.Tensor]] = []
    pending_bytes = 0
    for name, tensor in named_params:
        size = tensor.numel() * tensor.element_size()
        would_overflow = pending_bytes + size > bucket_bytes
        if pending and would_overflow:
            yield pending
            pending, pending_bytes = [], 0
        pending.append((name, tensor))
        pending_bytes += size
    if pending:
        yield pending


def hf_weight_metadata(
    bridge,
    models: list,
) -> list[tuple[str, tuple[list[int], str]]]:
    """Describe the exported weights as ``[(name, (shape, dtype_str)), ...]``.

    Walks ``bridge.export_weights(models)`` once, recording only each
    tensor's shape and dtype name (the dtype's string form with everything
    up to the last ``.`` removed) in iteration order. The reference to each
    tensor is dropped before the next one is requested, so this function
    itself never holds more than one exported tensor.

    Because it drives the same ``export_weights`` generator as
    ``export_hf_named_params``, it should be called in lockstep on every
    participating rank.
    """
    layout: list[tuple[str, tuple[list[int], str]]] = []
    for hf_name, param in bridge.export_weights(models):
        dtype_name = str(param.dtype).rpartition(".")[2]
        shape = list(param.shape)
        # Release the tensor before the generator produces the next one.
        del param
        layout.append((hf_name, (shape, dtype_name)))

    first_name = layout[0][0] if layout else "?"
    last_name = layout[-1][0] if layout else "?"
    logger.info(
        "[weight_export] HF metadata: %d tensors, first=%s last=%s",
        len(layout),
        first_name,
        last_name,
    )
    return layout
- Megatron params are TP-sharded; WeightManager._copy_megatron_shards - handles the shard-direct copy using tp_rank offsets. + - Megatron: a fresh ``export_hf_named_params`` generator that yields + gathered HF-layout tensors (handles TP/PP/EP/VPP). WeightManager + streams it into the HF buffer on the writer rank. + - FSDP: raw ``model.named_parameters()`` (DTensor shards handled by + WeightManager's shard-copy / all-gather paths). """ + if self._is_megatron: + return self.actor.export_hf_named_params() try: return self.actor.model.named_parameters(remove_duplicate=False) except TypeError: @@ -255,9 +259,13 @@ def _init_weight_manager(self) -> None: "target_modules": list(peft_cfg.target_modules), } - megatron_metadata = None + # Megatron HF-export mode: the buffer is sized from the full HF + # weight layout, and offload streams gathered HF tensors into it. + # This keeps the sender/RaaS path identical to FSDP (delta in HF + # space) and works under any TP/PP/EP/VPP combination. + megatron_hf_meta = None if self._is_megatron: - megatron_metadata = self.actor.get_megatron_shard_metadata() + megatron_hf_meta = self.actor.get_hf_weight_metadata() # Determine HSDP replica rank (0 = primary, >0 = secondary). dp_replicate_rank = 0 @@ -279,10 +287,15 @@ def _init_weight_manager(self) -> None: dp_replicate_rank=dp_replicate_rank, ) else: - named_params = self._get_named_params_for_offload() + # In Megatron HF-export mode the layout comes from + # megatron_hf_meta, so named_params is unused at init time. 
+ named_params = ( + iter(()) if self._is_megatron + else self._get_named_params_for_offload() + ) self.weight_manager.initialize( named_params, local_rank, global_rank, - megatron_metadata=megatron_metadata, + megatron_hf_meta=megatron_hf_meta, dp_replicate_rank=dp_replicate_rank, ) logger.info( diff --git a/docker/Dockerfile.sglang b/docker/Dockerfile.sglang index b2a1ef9..77a642f 100644 --- a/docker/Dockerfile.sglang +++ b/docker/Dockerfile.sglang @@ -33,4 +33,31 @@ RUN uv pip install -e ".[sglang]" # so install it explicitly with --no-build-isolation. RUN uv pip install "flash-attn==2.8.3" --no-build-isolation +# --- Megatron backend extras --- +# The Megatron training backend (engine/megatron_engine.py) uses +# Transformer Engine for fused LayerNorm + sequence parallelism, and benefits +# from apex's fused LayerNorm / Adam kernels. megatron-core / mbridge are +# already pulled in by the base install; these two are the heavy compiled +# deps that complete the stack. +ENV CUDA_HOME=/usr/local/cuda \ + NVTE_FRAMEWORK=pytorch \ + TORCH_CUDA_ARCH_LIST="8.0;8.9;9.0" + +# Transformer Engine — build against the already-installed torch. +RUN uv pip install --no-build-isolation "transformer-engine[pytorch]>=2.13.0,<2.14" + +# apex (optional perf; Megatron falls back to Torch Norm / torch Adam if absent). +# - apex reads APEX_CPP_EXT / APEX_CUDA_EXT env flags to select extensions. +# - The base image ships a CUDA toolkit whose minor version may differ from +# torch's CUDA; neutralize apex's strict version guard (the mismatch is a +# safe minor 12.x difference). +# - FORCE_CUDA=1 lets the CUDA extensions build without a visible GPU. +# - `|| echo` keeps an apex failure non-fatal (TE is the must-have). 
+RUN git clone --depth 1 https://github.com/NVIDIA/apex.git /tmp/apex && \ + cd /tmp/apex && \ + sed -i 's/^def check_cuda_torch_binary_vs_bare_metal(cuda_dir):/def check_cuda_torch_binary_vs_bare_metal(cuda_dir):\n return/' setup.py && \ + FORCE_CUDA=1 APEX_CPP_EXT=1 APEX_CUDA_EXT=1 \ + uv pip install -v --no-build-isolation . && \ + rm -rf /tmp/apex || echo "[apex] build failed — continuing without apex (Torch Norm fallback)" + CMD ["/bin/bash"] diff --git a/docs/en/architecture/megatron-weight-sync.md b/docs/en/architecture/megatron-weight-sync.md new file mode 100644 index 0000000..fc3a4b2 --- /dev/null +++ b/docs/en/architecture/megatron-weight-sync.md @@ -0,0 +1,156 @@ +# Megatron Weight Synchronization + +This page describes how the Megatron-LM training backend exports its +weights to RaaS, and the invariants that keep the **sparse / delta** +weight-update path correct under tensor (TP), pipeline (PP), expert +(EP), expert-tensor (ETP), virtual-pipeline (VPP), and context (CP) +parallelism. + +It complements [WeightManager](weight-manager.md) and +[Delta Weight Transfer](delta-weight-transfer.md), which describe the +backend-agnostic transport. **Read those first.** + +## The problem + +Megatron stores each parameter sharded across TP/PP/EP ranks, fused +(QKV in one `linear_qkv`, gate+up in one `linear_fc1`), and vocab-padded +— a layout that bears no resemblance to the HuggingFace checkpoint names +and byte layout that SGLang / vLLM expect (`model.layers.N.self_attn.q_proj.weight`, +…). RaaS only understands HF layout. + +The transport layer (`WeightManager` + sender agent + RaaS receiver) is +deliberately **backend-agnostic**: it moves an opaque, fixed-order CPU +byte buffer and, in delta mode, ships only the bytes that changed +between two versions of that buffer. For this to be correct, the bytes +in the buffer **must be in the same layout that RaaS applies them in**. + +For FSDP this is automatic — the buffer already holds HF-layout tensors. 
+For Megatron it is the central design constraint. + +## Design invariant + +> **The trainer always writes HF-named, HF-layout, full-model tensors into +> the transfer buffer. Sparsity / delta is always computed in HF byte +> space, over a double buffer. The RaaS receive path never sees a +> backend-specific layout.** + +Concretely, the Megatron backend reconstructs the global model from its +sharded layout and converts it to HF on the GPU, **before** anything +reaches the transfer buffer. The sender agent and RaaS receiver then +treat Megatron exactly like FSDP. + +This makes the delta correct **by construction**: both the old and new +buffer halves hold HF bytes, so a bytewise diff produces indices that +the receiver can scatter directly into its HF buffer. + +> ⚠️ **Historical bug (fixed by this design).** An earlier Megatron path +> wrote *raw mcore-layout shards* into the buffer and reassembled to HF in +> a separate, single-buffered region in the sender — but computed the +> delta over the *mcore-layout* buffer. mcore byte offsets ≠ HF byte +> offsets (fused QKV vs split, fused gate/up vs split, vocab padding), so +> applying an mcore-space delta to RaaS's HF-space base silently corrupted +> weights. Always diff in HF space. + +## The per-tensor generator + +The reconstruction is a **streaming generator** that yields +`(hf_name, full_unsharded_cpu_tensor)` one bucket at a time. Only one +bucket of fully-gathered tensors is alive at any moment, so a 100B / MoE +model never materializes in full (no OOM). + +```python +# astraflow/train_worker/models/mcore/weight_export.py +def export_hf_named_params( + models, # _MegatronModelList: VPP chunks, DDP-wrapped + tf_config, + hf_config, + bucket_bytes, +) -> Iterator[tuple[str, torch.Tensor]]: + """Yield (hf_name, full HF-layout CPU tensor) bucket by bucket.""" +``` + +Per parameter, in order, it performs the minimal collectives: + +1. 
**Naming + PP/EP offsets** — `utils.megatron.get_named_parameters` + already maps local mcore names to *global* names (adds the PP layer + offset and EP expert offset), iterating VPP chunks. +2. **PP gather** — `all_gather_object` of metadata across the pipeline + group, then broadcast each owner stage's tensor so the DP-head rank + set collectively holds every global parameter. Embeddings live on the + first stage, `output_layer` / final norm on the last. +3. **TP / ETP gather** — `utils.megatron.all_gather_param` all-gathers + the shards along `partition_dim` and concatenates, handling the GLU + `linear_fc1` stride-2 rechunk and the grouped-MoE `linear_fc2` + `partition_dim` 0→1 quirk. +4. **EP gather** — for `.experts.` params, all-gather across the expert + group and rewrite local→global expert id. +5. **mcore → HF convert** — `utils.megatron.convert_to_hf` splits QKV + (GQA-aware), splits gate/up, renames, and drops vocab padding. +6. **Bucket + stream** — group converted tensors until `bucket_bytes` + (measured post-gather), `yield`, then free before the next bucket. + +This is the same abstraction verl (`per_tensor_generator`) and slime +(`HfWeightIteratorDirect`) converged on; the difference is the consumer. + +## How it plugs into WeightManager + +verl / slime push the generator's tensors GPU→GPU (NCCL / CUDA-IPC) into +the inference engine. AstraFlow instead **writes them into the CPU +double buffer** that the sender agent TCP-pulls: + +``` +optimizer.step() + │ + ▼ +WeightManager.offload(export_hf_named_params(...), version, ...) + │ DP-head ranks write each (hf_name, tensor) into the INACTIVE + │ half of the HF double buffer, in fixed order; non-heads barrier. 
+ ▼ +notify_buffer_ready ──► sender swaps active/inactive + │ + ▼ sender._compute_delta() diffs HF-inactive vs + │ HF-active → indices in HF space ✓ + ▼ +RaaS pulls full or delta (unchanged from FSDP) +``` + +Buffer sizing comes from `MegatronEngine.get_hf_weight_metadata()` — a +metadata-only dry run of the generator that returns the ordered +`[(hf_name, shape, dtype), …]` list. This is the same `tensors_meta` +the RaaS receiver uses to pre-allocate, so both ends agree on layout and +order. + +## Rank participation + +Only **data-parallel head** ranks write the buffer (one writer per +model-parallel group), mirroring FSDP's primary-replica rule. The +TP/PP/EP gathers happen via collectives *before* the write, so every +DP-head holds the full HF model and writes it once. Other ranks only +participate in the gathers and the post-write barrier. + +## Configuration + +```yaml +trainer: + engine: + backend: megatron + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + actor: + megatron: + weight_export_bucket_bytes: 536870912 # 512 MiB gather bucket +``` + +`backend: megatron` is auto-selected when `pipeline_parallel_size > 1` +or `expert_parallel_size > 1`. + +## Invariants checklist (for reviewers) + +1. Trainer hands WeightManager **HF-named, HF-layout, full-model** + tensors. Backend differences end at `export_hf_named_params`. +2. Sparsity / delta is computed in **HF byte space**, on a double buffer. +3. **One bucket** of gathered tensors alive at a time — never + `full_tensor()` the whole model. +4. Only **DP-head** ranks write the buffer. +5. The **RaaS receive path is unchanged** between FSDP and Megatron. diff --git a/docs/en/index.rst b/docs/en/index.rst index 1f4fd9a..c6efc01 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -20,6 +20,7 @@ on distributed GPU clusters. architecture/raas architecture/trainer architecture/weight-manager + architecture/megatron-weight-sync .. 
toctree:: :maxdepth: 1 diff --git a/examples/math/qwen3-8b-megatron-delta/README.md b/examples/math/qwen3-8b-megatron-delta/README.md new file mode 100644 index 0000000..40f304a --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/README.md @@ -0,0 +1,55 @@ +# Qwen3-8B Math RL — Megatron backend, delta TCP weight transfer + +Same math RL recipe as [`qwen3-8b-m2po-delta`](../qwen3-8b-m2po-delta) (M2PO, +DeepScaleR data, ctx 16k, lr 5e-6, sparse delta weight sync) but the trainer +uses the **Megatron-LM backend** instead of FSDP. The only difference is the +`trainer_base.engine` block: + +```yaml +engine: + backend: megatron + data_parallel_size: 1 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 +``` + +This makes it a clean FSDP-vs-Megatron A/B: identical data, algorithm, and +weight-transfer path, so reward curves should track each other. + +## How weight sync works (Megatron) + +The trainer reconstructs the global model from Megatron's TP/PP/EP/VPP +layout into HuggingFace-named tensors (via `export_hf_named_params`, +backed by mbridge) and streams them into the CPU transfer buffer. Because +the buffer holds HF-layout bytes, the sparse **delta** is computed in HF +space and the RaaS receive path is identical to FSDP. See +[`docs/en/architecture/megatron-weight-sync.md`](../../../docs/en/architecture/megatron-weight-sync.md). 
+ +## GPU layout (8 GPUs, single node) + +| Component | GPUs | Parallelism | +|-----------|------|-------------| +| RaaS (SGLang, model0) | 0,1,2,3 | DP=4 | +| Trainer model0 (Megatron) | 4,5,6,7 | TP=4 | + +## Run + +```bash +bash examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh +``` + +Or launch the three components separately (terminals 1/2/3): + +```bash +bash examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh +bash examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh +bash examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh +``` + +## Scaling to PP / MoE + +For pipeline or expert parallelism (and MoE models), set the corresponding +sizes in the `engine` block, e.g. `pipeline_parallel_size: 2` or +`expert_parallel_size: 2`. The backend auto-selects Megatron when `pp>1` or +`ep>1`. Ensure `data_parallel_size * tensor_parallel_size * +pipeline_parallel_size` equals the number of trainer GPUs. diff --git a/examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh b/examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh new file mode 100755 index 0000000..ae981ac --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -euo pipefail +# [1/3] Launch AstraFlow HTTP service +# +# Usage (terminal 1): +# bash examples/math/qwen3-8b-megatron-delta/scripts/1_astraflow.sh + +export CUDA_VISIBLE_DEVICES="" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML.
+astraflow_load_experiment_env + +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== AstraFlow HTTP Service ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "Port : ${ASTRAFLOW_PORT}" +echo "===============================" + +python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" diff --git a/examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh b/examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh new file mode 100755 index 0000000..f66c5c4 --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -euo pipefail +# [2/3] Launch RaaS inference server (SGLang + TCP receiver) +# +# Usage (terminal 2, after AstraFlow is ready): +# bash examples/math/qwen3-8b-megatron-delta/scripts/2_raas.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML.
+astraflow_setup_env + +echo "=== RaaS Inference Server (SGLang + TCP receiver) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES}" +echo "Port : ${RAAS_PORT}" +echo "AstraFlow URL : ${ASTRAFLOW_URL}" +echo "=======================================================" + +python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" diff --git a/examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh b/examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh new file mode 100755 index 0000000..67ffa1d --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -euo pipefail +# [3/3] Launch Trainer for model0 (TCP, sender_agent on local_rank 0) +# +# Usage (terminal 3, after AstraFlow and RaaS are ready): +# bash examples/math/qwen3-8b-megatron-delta/scripts/3_trainer_model0.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML.
+astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +TRAINER0_NPROC="$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F',' '{print NF}')" + +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +# sender_agent (in trainer) listens on this HTTP port +export WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== Trainer model0 (TCP) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES} (Megatron TP${TRAINER0_NPROC})" +echo "AstraFlow : ${ASTRAFLOW_URL}" +echo "RaaS : ${ASTRAFLOW_RAAS_URL}" +echo "Sender HTTP : ${WEIGHT_TRANSFER_HTTP_PORT}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================" + +torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh b/examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh new file mode 100755 index 0000000..d5ae145 --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh @@ -0,0 +1,104 @@ +#!/bin/bash +set -euo pipefail +# All-in-one launcher for AstraFlow v2 math training (Qwen3-8B, M2PO, Megatron TP4, TCP). +# +# Launches 3 processes: +# 1. AstraFlow HTTP service (CPU-only) +# 2. RaaS inference server (SGLang, SERVICE_CUDA_VISIBLE_DEVICES) +# 3. 
Trainer model0 (math, TRAINER_MODEL0_GPUS) +# +# Usage: +# bash examples/math/qwen3-8b-megatron-delta/scripts/run_qwen3-8b-megatron-delta.sh + +# ============================================================================= +# Part 1: Load env and settings +# ============================================================================= +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +# Defined in examples/_common/utils.sh. +astraflow_load_experiment_env + +# ============================================================================= +# Part 2: Set up env +# ============================================================================= +# GPU assignments (default: 4 GPUs for inference, 4 for training) +export SERVICE_CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export TRAINER_MODEL0_GPUS="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +# Ports / URLs (each component gets its own port) +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export WEIGHT_TRANSFER_HTTP_PORT_MODEL0="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +TRAINER0_NPROC="$(echo "${TRAINER_MODEL0_GPUS}" | awk -F',' '{print NF}')" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. +# Defined in examples/_common/utils.sh. 
+astraflow_setup_env + +# ============================================================================= +# Part 3: Print info and clean up +# ============================================================================= +echo "=== AstraFlow v2 (Qwen3-8B, math, M2PO, ctx16k, TCP delta) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "RaaS GPUs : ${SERVICE_CUDA_VISIBLE_DEVICES}" +echo "Trainer model0 GPUs : ${TRAINER_MODEL0_GPUS} (Megatron TP${TRAINER0_NPROC})" +echo "RaaS port : ${RAAS_PORT}" +echo "AstraFlow port : ${ASTRAFLOW_PORT}" +echo "Sender HTTP model0 : ${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================================" + +trap astraflow_cleanup_trap EXIT INT TERM + +# Kill leftover processes and shared memory from prior runs. +# Defined in examples/_common/utils.sh. +astraflow_kill_stale + +# ============================================================================= +# Part 4: Launch training +# ============================================================================= +echo "[1/3] Starting AstraFlow HTTP service..." +CUDA_VISIBLE_DEVICES="" \ + python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" & +sleep 5 + +echo "[2/3] Starting RaaS inference server (SGLang + TCP receiver)..." +CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES}" \ + python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" & +sleep 15 + +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +echo "[3/3] Starting trainer model0..." 
+CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS}" \ +WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" \ + torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" \ + 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3-8b-megatron-delta/yaml/experiment.yaml b/examples/math/qwen3-8b-megatron-delta/yaml/experiment.yaml new file mode 100644 index 0000000..8136452 --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/yaml/experiment.yaml @@ -0,0 +1,150 @@ +# ============================================================================ +# Experiment config — AstraFlow service + Trainer +# Experiment: math / qwen3-8b-megatron-delta +# +# Qwen3-8B math RL with M2PO, ctx 16k, lr 5e-6, delta TCP weight transfer, +# **Megatron training backend** (TP=4). Same algorithm/data as the FSDP +# qwen3-8b-m2po-delta recipe — only the trainer engine differs, which makes +# it a clean FSDP-vs-Megatron comparison. 
+# +# GPU layout (default, 8 GPUs): +# SERVICE_CUDA_VISIBLE_DEVICES=0,1,2,3 -> RaaS (model0 dp=4) +# TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (Megatron, TP=4) +# ============================================================================ + +# ── Experiment: identity, model, shared settings ── +experiment: + experiment_name: astraflow-math + trial_name: qwen3-8b-megatron-delta + fileroot: ./data-experiments/${experiment.experiment_name}/${experiment.trial_name} + + model_path: "Qwen/Qwen3-8B" + tokenizer_path: "Qwen/Qwen3-8B" + seed: 1 + dtype: bfloat16 + weight_transfer_mode: tcp + weight_transfer_strategies: delta + +# ── RaaS: what to generate (inference-level config) ── +raas: + models: + model0: + backend: sglang + gconfig: + n_samples: 8 + temperature: 1.0 + max_new_tokens: 14000 + min_new_tokens: 0 + delta_full_sync_interval: 10 + +# ── AstraFlow: data pipeline ── +dataflow: + host: "0.0.0.0" + port: 8000 + + buffer: + size: 10000 + replay_size: 10000 + replay_ratio: 0 + max_staleness: 8 + filter_function: filter_zero_adv + + rollout_dataset: + dataset_fn: "astraflow.dataflow.dataset.deepscaler:get_deepscaler_rl_dataset" + max_length: 2000 + + workflow_spec: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + + eval_workflows: + math_eval: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + gconfig_overrides: + temperature: 0.6 + n_samples: 1 + + eval_datasets: + aime24: + dataset_fn: "astraflow.dataflow.dataset.aime24x4:get_aime_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + aime25: + dataset_fn: "astraflow.dataflow.dataset.aime25x4:get_aime_2025x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + math500: + dataset_fn: "astraflow.dataflow.dataset.math500:get_math500_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + +# ── Trainer base: shared config ── +trainer_base: + total_train_steps: 800 + train_batch_size: 256 + 
n_samples: 8 + # Megatron training backend. TP=4 across the 4 trainer GPUs (dp=1). + # backend is auto-selected as megatron when pp>1 or ep>1, but we set it + # explicitly here so a TP-only config also uses Megatron. + engine: + backend: megatron + data_parallel_size: 1 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + + actor: + gradient_checkpointing: true + mb_spec: + max_tokens_per_mb: 17408 + optimizer: + type: adam + lr: 5e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + # PPO / M2PO algorithm + m2_threshold: 0.01 + eps_clip: 100.0 + eps_clip_higher: 100.0 + reward_scaling: 1 + reward_bias: 0 + kl_ctl: 0.00 + kl_penalty_coef: 0.001 + ppo_n_minibatches: 4 + reward_norm: { mean_level: group, std_level: group } + adv_norm: { mean_level: batch, std_level: batch } + # Megatron backend uses sensible defaults (distributed optimizer on, + # DDP wrap on). Override under actor.megatron.{ddp,...} if needed. + + ref: + mb_spec: + max_tokens_per_mb: 17408 + + recover: + mode: auto + freq_steps: 25 + + evaluator: + eval_at_start: false + freq_steps: 25 + + stats_logger: + wandb: + mode: online + id_suffix: "uid" + +# ── Trainer for model0 — only overrides ── +trainer_model0: + model_id: model0 + stats_logger: + wandb: + tags: ["m2po", "math", "astraflow-v2", "qwen3-8b", "tcp", "ctx16k", "delta", "megatron"] diff --git a/examples/math/qwen3-8b-megatron-delta/yaml/raas.yaml b/examples/math/qwen3-8b-megatron-delta/yaml/raas.yaml new file mode 100644 index 0000000..5e21cbf --- /dev/null +++ b/examples/math/qwen3-8b-megatron-delta/yaml/raas.yaml @@ -0,0 +1,33 @@ +# ============================================================================ +# RaaS config — Inference serving instance (hardware/resources) +# Experiment: math / qwen3-8b-megatron-delta +# +# Hardware: 4x GPU, TP=1 +# model0: DP=4, TP=1 +# +# Merged with experiment.yaml at launch (--config experiment.yaml --config raas.yaml) +#
experiment.yaml provides: model_path, tokenizer_path, seed, dtype, models/gconfig +# ============================================================================ + +rollout: + max_concurrent_rollouts: 1024 + # Cap concurrent eval prefills to bound peak KV pressure during the + # ~3.5k-item eval burst (5 datasets x repeat=4) — default 128 OOMs sglang. + max_concurrent_evals: 64 + pause_grace_period: 3 + # Adaptive availability — drive /availability off sglang /get_load. + enable_adaptive_availability: true + target_waiting_queue_per_dp: 4 + adaptive_step_size: 4 + load_cache_ttl_ms: 100 + +engine: + model0: + backend: sglang + data_parallel_size: 4 + +sglang: + context_length: 16384 + mem_fraction_static: 0.8 + max_running_requests: null + skip_tokenizer_init: true