Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,11 @@ Check the [AI Coding Assistant Guide](docs/reference/ai_assisted_dev.md) and

### Training Backends

| Backend | DP | Tensor Parallel | Sequence Parallel within TP | Context Parallel | Pipeline Parallel | Expert Parallel | 1D Sequence Packing | LoRA |
| ------------------ | ----------- | --------------- | --------------------------- | ---------------- | ----------------- | --------------- | ------------------- | ---- |
| **Megatron** | ✅ (ZeRO-1) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| **PyTorch FSDP** | ✅ (FSDP2) | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ |
| **PyTorch Archon** | ✅ (FSDP2) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| Backend | DP | Tensor Parallel | Sequence Parallel within TP | Context Parallel | Pipeline Parallel | Expert Parallel | 1D Sequence Packing | LoRA |
| ------------------ | ----------- | --------------- | --------------------------- | ---------------- | ----------------- | --------------- | ------------------- | -------------------------------- |
| **Megatron** | ✅ (ZeRO-1) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (with vLLM inference backend) |
| **PyTorch FSDP** | ✅ (FSDP2) | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ |
| **PyTorch Archon** | ✅ (FSDP2) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |

### Inference Backends

Expand Down Expand Up @@ -252,6 +252,7 @@ Check the [AI Coding Assistant Guide](docs/reference/ai_assisted_dev.md) and
### Reference

- [CLI Configurations](docs/en/cli_reference.md)
- [LoRA RL](docs/en/reference/lora.md)
- [Checkpointing](docs/en/reference/checkpointing.md)
- [Metrics Tracking](docs/en/reference/metrics_tracking.md)
- [Allocation Mode](docs/en/reference/alloc_mode.md)
Expand Down
30 changes: 28 additions & 2 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,7 +1968,13 @@ def __post_init__(self):
class WandBConfig:
"""Configuration for Weights & Biases experiment tracking."""

mode: str = "disabled"
mode: str = field(
default="disabled",
metadata={
"help": "Tracking mode. One of 'online', 'offline', 'disabled', or 'shared'.",
"choices": ["online", "offline", "disabled", "shared"],
},
)
wandb_base_url: str = ""
wandb_api_key: str = ""
entity: str | None = None
Expand All @@ -1981,6 +1987,14 @@ class WandBConfig:
config: dict | None = None
id_suffix: str | None = "train"

def __post_init__(self):
"""Validate WandB configuration."""
valid_modes = ("online", "offline", "disabled", "shared")
if self.mode not in valid_modes:
raise ValueError(
f"Invalid wandb mode: '{self.mode}'. Must be one of: {', '.join(valid_modes)}."
)


@dataclass
class SwanlabConfig:
Expand All @@ -1990,11 +2004,23 @@ class SwanlabConfig:
name: str | None = None
config: dict | None = None
logdir: str | None = None
mode: str | None = "disabled"
mode: str = field(
default="disabled",
metadata={
"help": "Tracking mode. One of 'cloud', 'local', 'disabled', or 'offline'.",
"choices": ["cloud", "local", "disabled", "offline"],
},
)
# Defaults to None so the API key is never leaked into generated docs.
api_key: str | None = None

def __post_init__(self):
"""Validate SwanLab configuration."""
valid_modes = ("cloud", "local", "disabled", "offline")
if self.mode not in valid_modes:
raise ValueError(
f"Invalid swanlab mode: '{self.mode}'. Must be one of: {', '.join(valid_modes)}."
)
if self.api_key is None:
self.api_key = os.getenv("SWANLAB_API_KEY")

Expand Down
8 changes: 8 additions & 0 deletions areal/api/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,19 @@ def from_megatron_xccl(
cls,
gen_allocation: ModelAllocation,
weight_chunked_mem_mb: int = 1024,
use_lora: bool = False,
lora_name: str = "",
lora_int_id: int = 1,
base_model_name: str = "",
):
return cls(
type="xccl",
gen_allocation=gen_allocation,
weight_chunked_mem_mb=weight_chunked_mem_mb,
use_lora=use_lora,
lora_name=lora_name,
lora_int_id=lora_int_id,
base_model_name=base_model_name,
)

@classmethod
Expand Down
132 changes: 116 additions & 16 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
import torch.distributed as dist
from megatron.bridge import AutoBridge as MegatronBridgeAutoBridge
from megatron.bridge.peft.lora import LoRA as MegatronBridgeLoRA
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core.distributed import DistributedDataParallel as DDP
Expand Down Expand Up @@ -59,6 +60,7 @@
get_named_parameters,
remove_padding,
)
from areal.engine.megatron_utils.megatron_lora import get_vllm_lora_target_modules
from areal.engine.megatron_utils.packed_context_parallel import (
packed_context_parallel_forward,
)
Expand Down Expand Up @@ -169,6 +171,7 @@ def __init__(self, config: TrainEngineConfig):
)
self.quantization_config: dict[str, int | str | list[str]] | None = None
self.bridge_cls: str = getattr(self.mcore_config, "bridge_type", "mbridge")
self.bridge_lora: MegatronBridgeLoRA | None = None

def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
if parallel_strategy is None:
Expand Down Expand Up @@ -210,6 +213,42 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None
)
self.process_group_initialized = True

def _apply_megatron_bridge_lora(self) -> None:
    """Wrap ``self.model`` with Megatron-Bridge LoRA adapters and log trainable-parameter stats.

    Requires the model to already be built and ``bridge_cls`` to be
    'megatron-bridge'; mutates ``self.model`` and ``self.bridge_lora`` in place.
    """
    assert self.model is not None, "Model must be initialized before applying LoRA."
    assert self.bridge_cls == "megatron-bridge"

    modules = list(self.config.target_modules or [])
    if not modules or "all-linear" in modules:
        # Expand all-linear to explicit Megatron-Bridge linear module targets.
        modules = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]

    self.bridge_lora = MegatronBridgeLoRA(
        target_modules=modules,
        dim=self.config.lora_rank,
        alpha=self.config.lora_alpha,
        dropout=0.0,
    )
    self.model = _MegatronModelList(self.bridge_lora(self.model, training=True))
    self.bridge_lora.set_params_to_save(self.model)

    # Tally parameter counts in a single pass for the summary log line.
    total = 0
    trainable = 0
    for param in self.model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    self.logger.info(
        "Applied Megatron Bridge LoRA: target_modules=%s, rank=%s, alpha=%s, trainable=%s/%s (%.4f%%)",
        modules,
        self.config.lora_rank,
        self.config.lora_alpha,
        trainable,
        total,
        100.0 * trainable / max(total, 1),
    )

def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
try:
self.seed = get_seed()
Expand Down Expand Up @@ -238,6 +277,12 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
)
self.engine_lock = DistributedLock("train_engine_lock")

if self.config.use_lora and self.bridge_cls != "megatron-bridge":
raise NotImplementedError(
"MegatronEngine LoRA POC currently only supports bridge_type='megatron-bridge'. "
"mbridge does not support LoRA in this path."
)

self.tokenizer = load_hf_tokenizer(self.config.path)

with patch_bridge_for_tree_training(
Expand Down Expand Up @@ -270,10 +315,14 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
bridge=self.bridge,
bridge_type=self.bridge_cls,
is_critic=self.config.is_critic,
use_lora=self.config.use_lora,
)

self.model = _MegatronModelList(models)

if self.config.use_lora:
self._apply_megatron_bridge_lora()

with self.device:
self._load_model_from_hf(self.config.path)

Expand Down Expand Up @@ -552,6 +601,12 @@ def save(self, meta: SaveLoadMeta):
base_model_path=meta.base_model_path,
)
elif meta.weight_format == "dcp":
if self.checkpointer is None:
raise NotImplementedError(
"DCP checkpoint save is not available for this Megatron configuration "
"(e.g., LoRA path without distributed optimizer support). "
"Please use weight_format='hf' for adapter/full-model export."
)
self.checkpointer.save_checkpoint(meta.path, with_optimizer=meta.with_optim)
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
Expand All @@ -564,6 +619,12 @@ def load(self, meta: SaveLoadMeta):
)
self._load_model_from_hf(meta.path)
elif meta.weight_format == "dcp":
if self.checkpointer is None:
raise NotImplementedError(
"DCP checkpoint load is not available for this Megatron configuration "
"(e.g., LoRA path without distributed optimizer support). "
"Please use weight_format='hf' for adapter/full-model load."
)
self.checkpointer.load_checkpoint(meta.path, with_optimizer=meta.with_optim)
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
Expand Down Expand Up @@ -967,6 +1028,12 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
return
assert self.model is not None and len(self.model) > 0

use_distributed_optimizer = (
False
if self.config.use_lora
else self.mcore_config.ddp.use_distributed_optimizer
)

assert self.optimizer_config.type in [
"adam",
"sgd",
Expand All @@ -987,7 +1054,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
adam_beta1=self.optimizer_config.beta1,
adam_beta2=self.optimizer_config.beta2,
adam_eps=self.optimizer_config.eps,
use_distributed_optimizer=self.mcore_config.ddp.use_distributed_optimizer,
use_distributed_optimizer=use_distributed_optimizer,
params_dtype=self.dtype,
clip_grad=self.optimizer_config.gradient_clipping,
fp8_recipe=(self.fp8_config.recipe if self.enable_fp8 else None),
Expand Down Expand Up @@ -1028,14 +1095,16 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
)
self.lr_scheduler = lr_scheduler

self.checkpointer = MegatronCheckpointManager(
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
use_distributed_optimizer=self.mcore_config.ddp.use_distributed_optimizer,
use_checkpoint_opt_param_scheduler=self.mcore_config.use_checkpoint_opt_param_scheduler,
async_save=self.mcore_config.async_save,
)
# MegatronCheckpointManager currently only supports the distributed optimizer, which LoRA does not support.
if not self.config.use_lora:
self.checkpointer = MegatronCheckpointManager(
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
use_distributed_optimizer=use_distributed_optimizer,
use_checkpoint_opt_param_scheduler=self.mcore_config.use_checkpoint_opt_param_scheduler,
async_save=self.mcore_config.async_save,
)

def _check_rollout_engine_connected(self) -> None:
"""Validate that rollout engine has been connected via connect_engine()."""
Expand Down Expand Up @@ -1072,6 +1141,16 @@ def _update_bucket_weights_from_distributed(
for name, tensor in converted_named_tensors
]

if self.config.use_lora:
meta.peft_config = {
"r": self.config.lora_rank,
"lora_alpha": self.config.lora_alpha,
"target_modules": get_vllm_lora_target_modules(
list(self.config.target_modules or [])
),
"bias": "none",
}

fut = self.rollout_engine.update_weights_from_distributed(meta, param_specs)

handles = []
Expand Down Expand Up @@ -1160,10 +1239,14 @@ def _impl_update_weight_from_distributed(
self._update_bucket_weights_from_distributed(meta, converted_named_tensors)
buffer_size = 0

model_name = self.hf_config.model_type
if self.config.use_lora:
model_name = f"{model_name}_lora"

converted_named_tensors.extend(
convert_to_hf(
self.tf_config,
self.hf_config.model_type,
model_name,
name,
param,
quantization_config=self.quantization_config,
Expand Down Expand Up @@ -1324,6 +1407,10 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None:
for name, param in get_named_parameters(self.model, num_moe_experts):
if ".experts." in name:
continue
if self.config.use_lora and (
".adapter." not in name or not getattr(param, "requires_grad", False)
):
continue
buffer_size = self._impl_update_weight_from_distributed(
meta,
name,
Expand All @@ -1336,14 +1423,19 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None:
# Only pipeline parallel heads CAN contain named tensors here
if converted_named_tensors:
self._update_bucket_weights_from_distributed(meta, converted_named_tensors)
elif self.is_pipeline_parallel_head() and not self.config.use_lora:
self.logger.warning(
"No tensors were collected for distributed update at version %s.",
meta.version,
)

dist.barrier(group=self.cpu_group)

buffer_size = 0
named_tensors = []

for name, param in get_named_parameters(self.model, num_moe_experts):
if ".experts." not in name:
if ".experts." not in name or self.config.use_lora:
continue
buffer_size = self._impl_update_expert_weight_from_distributed(
meta,
Expand Down Expand Up @@ -1406,11 +1498,19 @@ def _save_model_to_hf(
raise ValueError(
"Saving critic model is not supported with megatron-bridge."
)
self.bridge.save_hf_pretrained(
self.model,
path,
source_path=base_model_path,
)
if self.config.use_lora:
self.bridge.save_hf_adapter(
self.model,
path=path,
peft_config=self.bridge_lora,
base_model_name_or_path=base_model_path or self.config.path,
)
else:
self.bridge.save_hf_pretrained(
self.model,
path,
source_path=base_model_path,
)
else:
save_weights_to_hf_with_mbridge_fast(
bridge=self.bridge,
Expand Down
11 changes: 8 additions & 3 deletions areal/engine/megatron_utils/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
get_block_size_from_config,
quantize_params,
)
from areal.engine.megatron_utils.megatron_lora import convert_qwen3_lora_to_hf


def _all_gather_and_concat(
Expand Down Expand Up @@ -96,11 +97,13 @@ def all_gather_param(
if "expert_bias" in name:
return param

if not hasattr(param, "tensor_model_parallel"):
raise ValueError(f"{name} does not have tensor_model_parallel attribute")

param_is_fp8 = is_float8tensor(param)

if not hasattr(param, "tensor_model_parallel"):
if param_is_fp8 and fp8_direct_convert:
return param
return param.data

# Check if this param is truly NOT TP-sharded.
# NOTE: TE unconditionally sets tensor_model_parallel=True on all Linear
# weights, even for modules with parallel_mode='duplicated'. The original
Expand Down Expand Up @@ -756,6 +759,8 @@ def convert_bailingmoe_to_hf(
# Adapted from slime
# A registry for conversion functions is more extensible.
_CONVERSION_FN_REGISTRY = {
"qwen3_lora": convert_qwen3_lora_to_hf,
"qwen2_lora": convert_qwen3_lora_to_hf,
"qwen3_moe": convert_qwen3moe_to_hf,
"qwen2": convert_qwen2_to_hf,
"qwen3": convert_qwen2_to_hf,
Expand Down
Loading
Loading