Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/bake-gcp-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ concurrency:

env:
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
DEFAULT_BASE_IMAGE: areal-cicd-test-20260317-3
DEFAULT_BASE_IMAGE: areal-cicd-test-20260330-386

jobs:
bake:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-areal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ concurrency:
env:
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
RUNNER_VERSION: '2.332.0'
GCP_OS_IMAGE: areal-cicd-test-20260317-3
GCP_OS_IMAGE: areal-cicd-test-20260330-386

jobs:
determine-variants:
Expand Down
77 changes: 77 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,75 @@ def __post_init__(self):
)


@dataclass
class ArchonFP8Config:
    """Configuration block controlling FP8 precision for Archon training.

    The default configuration leaves FP8 fully disabled; setting
    ``mode="blockwise"`` turns on blockwise FP8 matmuls, and the remaining
    fields tune which modules participate and which GEMM backend is used.
    Construction validates the field combination via ``__post_init__``.
    """

    # Master switch for FP8 precision. Validated against the same choice
    # set declared in the metadata below.
    mode: str = field(
        default="disabled",
        metadata={
            "help": "FP8 precision mode. "
            "'disabled': FP8 training off (default). "
            "'blockwise': blockwise 128x128 FP8 e4m3fn matmuls (requires Hopper GPU).",
            "choices": ["disabled", "blockwise"],
        },
    )

    # Substring match against module FQNs; matching nn.Linear modules stay BF16.
    exclude_modules: list[str] = field(
        default_factory=lambda: ["output", "router", "score"],
        metadata={
            "help": (
                "FQN substrings of nn.Linear modules to keep in BF16 (not converted to FP8). "
                "Any module whose fully-qualified name contains one of these strings is skipped. "
                "Meaningful values for Archon models: "
                "'output' (LM head, logit precision sensitive), "
                "'router' (MoE router gate, routing stability sensitive), "
                "'score' (critic head, value precision sensitive). "
                "Note: nn.Embedding modules (e.g. tok_embeddings) are never converted "
                "regardless of this list. "
                "WARNING: Setting this in YAML replaces the entire default list "
                "(does not extend it). Include ALL modules you want to keep in BF16."
            )
        },
    )

    # Opt-in: MoE expert weights are left in BF16 unless this is set.
    include_experts: bool = field(
        default=False,
        metadata={
            "help": "Apply FP8 to MoE expert computation. "
            "Uses per-expert blockwise FP8 matmuls via torchao."
        },
    )

    # Must remain True while FP8 is enabled — enforced in __post_init__.
    use_triton: bool = field(
        default=True,
        metadata={
            "help": (
                "Use Triton GEMM kernel for FP8 blockwise matmuls instead of cuBLAS. "
                "Currently must be True: torchao's blockwise FP8 is a prototype that uses "
                "mixed per-operand scaling (1x128 activations + 128x128 weights), which "
                "torch._scaled_mm does not support. The Triton kernel "
                "(triton_fp8_gemm_1x128_128x128) handles this natively. "
                "Revisit when torchao stabilizes mixed-mode cuBLAS dispatch."
            ),
        },
    )

    def __post_init__(self):
        """Reject invalid field combinations at construction time."""
        allowed_modes = {"disabled", "blockwise"}
        if self.mode not in allowed_modes:
            raise ValueError(
                f"fp8_config.mode must be one of {allowed_modes}, got {self.mode!r}"
            )
        # Nothing further to check when FP8 is off.
        if self.mode == "disabled":
            return
        if not self.use_triton:
            raise ValueError(
                "fp8_config.use_triton must be True when FP8 is enabled. "
                "torchao blockwise FP8 uses mixed per-operand scaling "
                "(1x128 activations + 128x128 weights) which "
                "torch._scaled_mm does not support."
            )


@dataclass
class ArchonEngineConfig:
"""Configuration for Archon Engine training backend."""
Expand Down Expand Up @@ -552,6 +621,14 @@ class ArchonEngineConfig:
},
)

# FP8 Training
fp8_config: ArchonFP8Config = field(
default_factory=ArchonFP8Config,
metadata={
"help": "FP8 training configuration. Set mode='blockwise' to enable."
},
)

# Deterministic mode
use_deterministic_algorithms: bool = field(
default=False,
Expand Down
18 changes: 15 additions & 3 deletions areal/engine/fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,10 +992,22 @@ def _ensure_ready(self) -> None:
if self.parallel_helper.sp_size > 1:
set_ulysses_sequence_parallel_group(self.sp_group)

def _get_model_name_parameters(self) -> Iterator[tuple[str, nn.Parameter]]:
def _get_model_name_parameters(
self, meta: WeightUpdateMeta
) -> Iterator[tuple[str, nn.Parameter]]:
name_params_iterator = self.model.named_parameters()
if self.is_vision_model and is_qwen_vl_model(self.model_config.model_type):
for name, value in name_params_iterator:
if meta.gen_allocation.backend == "sglang":
# SGLang 0.5.9 branch
# LLM part: "model.language_model.norm.weight" -> "model.norm.weight"
# Vision part: "model.visual.blocks.5.mlp.gate_proj.weight" -> "visual.blocks.5.mlp.gate_proj.weight"
new_name = name.replace("language_model.", "", 1)
if new_name.startswith("model.visual."):
new_name = new_name.replace("model.", "", 1)
yield new_name, value
continue
# vLLM 0.17.0 branch
new_name = name.replace("model.", "", 1)
if new_name.startswith("language_model."):
new_name = new_name.replace(
Expand Down Expand Up @@ -1186,12 +1198,12 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
# For LoRA, only iterate over trainable LoRA parameters
param_iterator = (
(name, param)
for name, param in self._get_model_name_parameters()
for name, param in self._get_model_name_parameters(meta)
if param.requires_grad
)
else:
# For full model, iterate over all parameters
param_iterator = self._get_model_name_parameters()
param_iterator = self._get_model_name_parameters(meta)

try:
for name, param in param_iterator:
Expand Down
24 changes: 21 additions & 3 deletions areal/experimental/engine/archon_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,15 @@ def load_model_from_hf(engine: ArchonEngine, path: str) -> None:

engine.logger.info(f"Loading HF checkpoint from {path}")

# Get model state dict structure
from areal.experimental.models.archon.fp8_checkpoint import (
_get_scale_inv_keys,
_prepare_fp8_state_dict,
dequant_fp8_state_dict,
)

_fp8_scale_keys = _get_scale_inv_keys(path)
_is_fp8_ckpt = len(_fp8_scale_keys) > 0

options = StateDictOptions(full_state_dict=False, cpu_offload=True)
state_dict = _get_merged_state_dict(engine, options)

Expand All @@ -353,12 +361,23 @@ def load_model_from_hf(engine: ArchonEngine, path: str) -> None:
if embed_key not in hf_state_dict:
hf_state_dict[embed_key] = torch.empty_like(state_dict["output.weight"])

if _is_fp8_ckpt:
hf_state_dict = _prepare_fp8_state_dict(
hf_state_dict, path, _cached_keys=_fp8_scale_keys
)

# Load using DCP with HuggingFaceStorageReader
dcp.load(
hf_state_dict,
storage_reader=engine.state_dict_adapter.get_hf_storage_reader(path),
)

if _is_fp8_ckpt:
hf_state_dict = dequant_fp8_state_dict(
hf_state_dict,
target_dtype=getattr(torch, engine.config.dtype),
)

# Convert back to Archon format
archon_state_dict = engine.state_dict_adapter.from_hf(hf_state_dict)

Expand Down Expand Up @@ -392,8 +411,7 @@ def load_model_from_hf(engine: ArchonEngine, path: str) -> None:
f"Unexpected extra keys in checkpoint: {unexpected_keys}"
)

# Load into model(s)
load_options = StateDictOptions(strict=False)
load_options = StateDictOptions(strict=False, full_state_dict=False)
if engine.parallel_dims.pp_enabled:
for model_part in engine.model_parts:
set_model_state_dict(
Expand Down
32 changes: 32 additions & 0 deletions areal/experimental/engine/archon_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,30 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):

self.param_dtype = getattr(torch, self.config.dtype)

# FP8 conversion -- must run on meta device, before parallelism is applied.
# This assertion covers the training path (Phase 1A): blockwise FP8 matmuls
# require BF16 master weights. Loading an FP8 checkpoint into a BF16 model
# (Phase 1B, archon_checkpoint.py) is a separate path and may relax this.
if self.config.archon.fp8_config.mode != "disabled":
if self.config.dtype != "bfloat16":
raise ValueError(
f"FP8 training requires dtype=bfloat16 (master weights), "
f"got {self.config.dtype}"
)
from areal.experimental.models.archon.fp8 import (
enable_fp8_experts,
enable_fp8_linear,
)

fp8_cfg = self.config.archon.fp8_config
enable_fp8_linear(
self.model,
exclude_fqns=set(fp8_cfg.exclude_modules),
use_triton=fp8_cfg.use_triton,
)
if fp8_cfg.include_experts:
enable_fp8_experts(self.model, use_triton=fp8_cfg.use_triton)

# NOTE: may mutate self.config.pad_to_maximum and set env vars
# (CUBLAS_WORKSPACE_CONFIG, NCCL_ALGO, TORCH_COMPILE_DETERMINISTIC).
ac_config, enable_compile = prepare_training_config(
Expand All @@ -318,6 +342,14 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
f"Applied parallelism in {time.perf_counter() - tik:.2f} seconds"
)

if self.config.archon.fp8_config.mode != "disabled":
from areal.experimental.models.archon.fp8 import (
validate_fp8_shard_alignment,
)

parts = self.model_parts if self.parallel_dims.pp_enabled else [self.model]
validate_fp8_shard_alignment(parts)

self._materialize_and_load_weights()
self._create_optimizer(ft_spec)

Expand Down
14 changes: 13 additions & 1 deletion areal/experimental/engine/archon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def validate_zero_bubble_compatibility(
Zero-bubble schedules (split backward with retain_graph=True) conflict with
torch.compile, donated_buffer (MoE), op-level selective AC, and memory_budget AC.

Returns updated enable_compile flag.
Returns updated ``enable_compile`` flag.
"""
if get_schedule_class(pp_schedule) not in _ZERO_BUBBLE_SCHEDULES:
return enable_compile
Expand Down Expand Up @@ -314,6 +314,11 @@ def prepare_training_config(

Returns (ac_config, enable_compile). May mutate ``config.pad_to_maximum``
and set deterministic env vars.

Note: the returned ``enable_compile`` may differ from
``config.archon.enable_compile`` (zero-bubble or FP8 can disable it).
``config.archon.enable_compile`` is **not** written back — callers
must use the returned value.
"""
ac_config = build_ac_config(config, logger)
enable_compile = config.archon.enable_compile
Expand All @@ -325,6 +330,13 @@ def prepare_training_config(
ac_config=ac_config,
logger=logger,
)
if config.archon.fp8_config.mode != "disabled" and enable_compile:
logger.warning(
"FP8 blockwise training is incompatible with torch.compile. "
"Disabling torch.compile."
)
enable_compile = False

if config.archon.use_deterministic_algorithms:
setup_deterministic_mode(ac_config, enable_compile, logger)
force_pad_to_maximum(
Expand Down
66 changes: 55 additions & 11 deletions areal/experimental/inference_service/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

from __future__ import annotations

import os
import sys
import time
import traceback
import uuid
from collections.abc import AsyncGenerator
from threading import Lock
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -42,7 +44,7 @@ class GatewayInferenceController:
directly over HTTP — no engine creation or RPC calls on workers.

The inference backend is determined from ``config.backend``
(currently ``"sglang"`` is supported; ``"vllm"`` is planned).
(``"sglang"`` and ``"vllm"`` are supported).
"""

# Worker role suffix for RPCGuard workers
Expand Down Expand Up @@ -235,7 +237,7 @@ async def _async_initialize(
if server_args:
for k, v in server_args.items():
if hasattr(sglang_config, k):
object.__setattr__(sglang_config, k, v)
setattr(sglang_config, k, v)
else:
logger.warning(
"SGLangConfig has no attribute %r, ignoring "
Expand All @@ -254,10 +256,29 @@ def _build_launch_cmd(host: str, port: int) -> list[str]:
)

elif inf_backend == "vllm":
raise NotImplementedError(
"vLLM backend is not yet supported by the gateway "
"rollout controller."
)
from areal.api.cli_args import vLLMConfig

vllm_config = vLLMConfig(model=cfg.model_path or cfg.tokenizer_path)
for k, v in (server_args or {}).items():
if hasattr(vllm_config, k):
setattr(vllm_config, k, v)
else:
logger.warning(
"vLLMConfig has no attribute %r, ignoring "
"server_args entry (value=%r)",
k,
v,
)

def _build_launch_cmd(host: str, port: int) -> list[str]:
return vLLMConfig.build_cmd(
vllm_config=vllm_config,
tp_size=tp_size,
pp_size=alloc.parallel.pp_size,
host=host,
port=port,
)

else:
raise ValueError(f"Unsupported inference backend: {inf_backend!r}")

Expand All @@ -279,13 +300,34 @@ def _build_launch_cmd(host: str, port: int) -> list[str]:

cmd = _build_launch_cmd(inf_host, inf_port)

fork_payload: dict[str, Any] = {
"role": "inf-server",
"worker_index": rank,
"raw_cmd": cmd,
}
if inf_backend == "vllm":
from areal.infra.utils.launcher import (
TRITON_CACHE_PATH as _TRITON_CACHE,
)
from areal.infra.utils.launcher import (
VLLM_CACHE_ROOT as _VLLM_CACHE,
)

fork_payload["env"] = {
"TRITON_CACHE_PATH": os.path.join(
os.environ.get("TRITON_CACHE_PATH", _TRITON_CACHE),
str(uuid.uuid4()),
),
"VLLM_CACHE_ROOT": os.path.join(
os.environ.get("VLLM_CACHE_ROOT", _VLLM_CACHE),
str(uuid.uuid4()),
),
"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True",
}

resp = requests.post(
f"{guard_addr}/fork",
json={
"role": "inf-server",
"worker_index": rank,
"raw_cmd": cmd,
},
json=fork_payload,
timeout=30,
)
resp.raise_for_status()
Expand Down Expand Up @@ -359,6 +401,8 @@ def _build_launch_cmd(host: str, port: int) -> list[str]:
data_proxy_cmd = data_proxy_base_cmd + [
"--backend-addr",
self._inf_addrs[rank],
"--backend-type",
inf_backend or "sglang",
]
data_proxy_host, data_proxy_port = self._fork_on_guard(
guard_addr=guard_addr,
Expand Down
Loading
Loading