Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
/areal/api/ @garrett4wade
/areal/engine/ @rchardx
/areal/experimental/inference_service @nuzant
/areal/experimental/agent_service/ @CormickKneey
/areal/infra/ @garrett4wade
/areal/models/ @rchardx
/areal/trainer/ @garrett4wade
Expand Down
1 change: 1 addition & 0 deletions GOVERNANCE.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ project.
| Zhiyu Mei | AReaL Team, Ant Group | @nuzant |
| Xujie Shen | AReaL Team, Ant Group | @fishcrap |
| Tongkai Yang | AReaL Team, Ant Group | @fredy12 |
| Han Jiang | AReaL Team, Ant Group | @CormickKneey |

### Lead Maintainer (BDFL)

Expand Down
4 changes: 2 additions & 2 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,7 +1456,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None:
converted_named_tensors = []

for name, param in get_named_parameters(self.model, num_moe_experts):
if ".experts." in name:
if ".experts." in name and not self.config.use_lora:
continue
if self.config.use_lora and (
".adapter." not in name or not getattr(param, "requires_grad", False)
Expand All @@ -1474,7 +1474,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None:
# Only pipeline parallel heads CAN contain named tensors here
if converted_named_tensors:
self._update_bucket_weights_from_distributed(meta, converted_named_tensors)
elif self.is_pipeline_parallel_head() and not self.config.use_lora:
elif self.config.use_lora and self.is_pipeline_parallel_head():
self.logger.warning(
"No tensors were collected for distributed update at version %s.",
meta.version,
Expand Down
6 changes: 5 additions & 1 deletion areal/engine/megatron_utils/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
get_block_size_from_config,
quantize_params,
)
from areal.engine.megatron_utils.megatron_lora import convert_qwen3_lora_to_hf
from areal.engine.megatron_utils.megatron_lora import (
convert_qwen3_lora_to_hf,
convert_qwen3_moe_lora_to_hf,
)


def _all_gather_and_concat(
Expand Down Expand Up @@ -763,6 +766,7 @@ def convert_bailingmoe_to_hf(
_CONVERSION_FN_REGISTRY = {
"qwen3_lora": convert_qwen3_lora_to_hf,
"qwen2_lora": convert_qwen3_lora_to_hf,
"qwen3_moe_lora": convert_qwen3_moe_lora_to_hf,
"qwen3_moe": convert_qwen3moe_to_hf,
"qwen2": convert_qwen2_to_hf,
"qwen3": convert_qwen2_to_hf,
Expand Down
102 changes: 101 additions & 1 deletion areal/engine/megatron_utils/megatron_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,103 @@ def convert_qwen3_lora_to_hf(
return []


def _moe_expert_adapter_to_hf(
    module_name: str,
    adapter_part: str,
    base_prefix: str,
    tensor: torch.Tensor,
) -> list[tuple[str, torch.Tensor]]:
    """Map one Megatron MoE expert LoRA adapter tensor to HF/PEFT entries.

    Args:
        module_name: Megatron projection name, ``linear_fc1`` (fused gate/up)
            or ``linear_fc2`` (down projection).
        adapter_part: ``linear_in`` (PEFT ``lora_A``) or ``linear_out``
            (PEFT ``lora_B``).
        base_prefix: HF key prefix up to and including the expert index.
        tensor: The adapter weight for this projection/part.

    Returns:
        ``(hf_key, tensor)`` pairs for the PEFT state dict.
    """
    if module_name == "linear_fc2":
        # down_proj maps 1:1; linear_in is lora_A, linear_out is lora_B.
        suffix = (
            "lora_A.default.weight"
            if adapter_part == "linear_in"
            else "lora_B.default.weight"
        )
        return [(f"{base_prefix}.down_proj.{suffix}", tensor)]

    # linear_fc1 fuses HF gate_proj and up_proj.
    gate_base = f"{base_prefix}.gate_proj"
    up_base = f"{base_prefix}.up_proj"
    if adapter_part == "linear_in":
        # lora_A is shared between gate and up in the fused layout.
        return [
            (f"{gate_base}.lora_A.default.weight", tensor),
            (f"{up_base}.lora_A.default.weight", tensor),
        ]
    # lora_B stacks gate then up along dim 0; split into the two HF halves.
    gate_b, up_b = tensor.chunk(2, dim=0)
    return [
        (f"{gate_base}.lora_B.default.weight", gate_b.contiguous()),
        (f"{up_base}.lora_B.default.weight", up_b.contiguous()),
    ]


def convert_qwen3_moe_lora_to_hf(
    tf_config,
    name: str,
    tensor: torch.Tensor,
) -> list[tuple[str, torch.Tensor]]:
    """Convert a Megatron Qwen3-MoE LoRA adapter parameter to HF/PEFT keys.

    Handles two expert-weight layouts:
    grouped (``...experts.linear_fcX.adapter.<part>.weight``, one tensor
    shared across all experts) and per-expert
    (``...weight<expert_idx>``, one tensor per expert).

    Args:
        tf_config: Transformer config; ``num_moe_experts`` (or
            ``num_experts``) is required to expand grouped tensors.
        name: Megatron parameter name.
        tensor: Parameter value.

    Returns:
        ``(hf_key, tensor)`` pairs, or an empty list when *name* is not a
        recognized LoRA adapter parameter.
    """
    # Reuse non-MoE conversion for attention and dense MLP paths.
    converted = convert_qwen3_lora_to_hf(tf_config, name, tensor)
    if converted:
        return converted

    grouped_expert_pattern = (
        r"(?:^|.*\.)decoder\.layers\.(\d+)\.mlp\.experts\."
        r"(linear_fc1|linear_fc2)\.adapter\.(linear_in|linear_out)\.weight$"
    )
    match = re.match(grouped_expert_pattern, name)
    if match is not None:
        layer_idx, module_name, adapter_part = match.groups()
        num_experts = getattr(tf_config, "num_moe_experts", None)
        if num_experts is None:
            num_experts = getattr(tf_config, "num_experts", None)
        if num_experts is None:
            # Cannot expand a grouped tensor without knowing the expert count.
            return []

        # Grouped layout: the same adapter tensor is replicated to every
        # expert's HF entry.
        outputs: list[tuple[str, torch.Tensor]] = []
        for expert_idx in range(num_experts):
            base_prefix = (
                f"base_model.model.model.layers.{layer_idx}.mlp.experts.{expert_idx}"
            )
            outputs.extend(
                _moe_expert_adapter_to_hf(
                    module_name, adapter_part, base_prefix, tensor
                )
            )
        return outputs

    expert_pattern = (
        r"(?:^|.*\.)decoder\.layers\.(\d+)\.mlp\.experts\."
        r"(linear_fc1|linear_fc2)\.adapter\.(linear_in|linear_out)\.weight(\d+)$"
    )
    match = re.match(expert_pattern, name)
    if match is None:
        return []

    layer_idx, module_name, adapter_part, expert_idx = match.groups()
    base_prefix = f"base_model.model.model.layers.{layer_idx}.mlp.experts.{expert_idx}"
    return _moe_expert_adapter_to_hf(module_name, adapter_part, base_prefix, tensor)


def _infer_target_modules_from_adapter_weights(weight_keys: Iterable[str]) -> list[str]:
"""
Infer PEFT target_modules from adapter weight parameter names.
Expand Down Expand Up @@ -235,7 +332,10 @@ def save_hf_adapter(
# Export adapter weights
adapter_state: dict[str, torch.Tensor] = {}
for name, tensor in self.export_adapter_weights(
model, cpu=True, show_progress=show_progress
# cpu=True may reduce memory pressure, but it can hang for MoE models when launched via Slurm
model,
cpu=False,
show_progress=False,
):
adapter_state[f"base_model.model.{name}"] = tensor.clone().float()

Expand Down
4 changes: 2 additions & 2 deletions areal/engine/vllm_ext/vllm_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def areal_update_weight_lora_xccl(self):
async_op=False,
)

received_weights[name] = tensor
received_weights[name] = tensor.cpu()

logger.info(f"Received {len(received_weights)} LoRA parameters via XCCL")

Expand Down Expand Up @@ -259,7 +259,7 @@ def areal_update_weight_lora_xccl(self):
lora_model_id=self.areal_lora_int_id,
tensors=merged_weights,
peft_helper=peft_helper,
device=self.model_runner.device,
device="cpu",
dtype=self.model_runner.lora_manager.lora_config.lora_dtype,
model_vocab_size=model_vocab_size,
weights_mapper=getattr(
Expand Down
7 changes: 6 additions & 1 deletion areal/infra/rpc/ray_rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def _get_device(self):

return current_platform.current_device()

def _get_device_type(self) -> str:
    """Return the current platform's device type as a string (e.g. ``"cpu"``).

    NOTE(review): unlike ``_get_device`` this reads ``device_type`` directly,
    presumably so the check works without instantiating a device handle —
    confirm against ``current_platform``'s implementation.
    """
    from areal.infra.platforms import current_platform

    return current_platform.device_type

def _should_broadcast_payload(
self,
engine: TrainEngine | InferenceEngine,
Expand Down Expand Up @@ -181,7 +186,7 @@ def call(
if (
isinstance(engine, TrainEngine)
and engine.initialized
and self._get_device().type != "cpu"
and self._get_device_type() != "cpu"
):
from areal.infra.platforms import current_platform

Expand Down
37 changes: 37 additions & 0 deletions areal/infra/rpc/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,26 @@ def to_image(self) -> "ImageObject":
return image


class SerializedRayObjectRef(BaseModel):
    """Transport-safe wrapper for a ``ray.ObjectRef``.

    The reference is serialized with ``ray.cloudpickle`` and carried as
    base64 text so it can travel inside a JSON/pydantic payload.
    """

    # Discriminator consumed by deserialize_value to recognize this payload.
    type: Literal["ray_object_ref"] = Field(default="ray_object_ref")
    # Base64 text of the cloudpickled ObjectRef.
    data: str

    @classmethod
    def from_object_ref(cls, ref: Any) -> "SerializedRayObjectRef":
        """Encode *ref* into a new ``SerializedRayObjectRef``."""
        import ray.cloudpickle

        blob = ray.cloudpickle.dumps(ref)
        encoded = base64.b64encode(blob).decode("utf-8")
        return cls(data=encoded)

    def to_object_ref(self) -> Any:
        """Decode and return the original ``ray.ObjectRef``."""
        import ray.cloudpickle

        raw = base64.b64decode(self.data.encode("utf-8"))
        return ray.cloudpickle.loads(raw)


class SerializedDataclass(BaseModel):
"""Pydantic model for serialized dataclass with metadata.

Expand Down Expand Up @@ -569,6 +589,13 @@ def serialize_value(value: Any) -> Any:
if ImageObject is not None and isinstance(value, ImageObject):
return SerializedPILImage.from_image(value).model_dump()

# Handle Ray object references when HTTP RPC needs to carry RTensor shard
# handles across processes.
import ray

if isinstance(value, ray.ObjectRef):
return SerializedRayObjectRef.from_object_ref(value).model_dump()

# Handle dataclass instances (check before dict, as dataclasses can be dict-like)
# Note: is_dataclass returns True for both classes and instances, so check it's not a type
if is_dataclass(value) and not isinstance(value, type):
Expand Down Expand Up @@ -698,6 +725,16 @@ def deserialize_value(value: Any) -> Any:
f"Failed to deserialize PIL image, treating as regular dict: {e}"
)

# Check for SerializedRayObjectRef marker
if value.get("type") == "ray_object_ref":
try:
serialized_ref = SerializedRayObjectRef.model_validate(value)
return serialized_ref.to_object_ref()
except Exception as e:
logger.warning(
f"Failed to deserialize ray.ObjectRef, treating as regular dict: {e}"
)

# Check for SerializedTensor marker
if value.get("type") == "tensor":
try:
Expand Down
17 changes: 11 additions & 6 deletions docs/en/reference/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@ The current LoRA support matrix in AReaL is:
| Megatron | ✅ | ❌ |
| Archon | ❌ | ❌ |

Example scripts:
**Example scripts:**

| Engine | Example script |
| -------- | --------------------------------------------- |
| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` |
| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` |
| Engine | Example script |
| ------------ | ------------------------------------------------- |
| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` |
| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` |
| Megatron MoE | `examples/math/gsm8k_grpo_megatron_lora_moe.yaml` |

For Megatron + vLLM, AReaL now supports:

- LoRA fine-tuning on MoE architectures such as Qwen3 MoE with XCCL-based LoRA weight updates.
- Cross-node LoRA training when the Megatron and rollout groups span multiple nodes.

## Core LoRA Parameters

Expand All @@ -50,5 +56,4 @@ Example scripts:

- Start with `r=16` or `r=32` for most models, then tune upward only if needed.
- Keep `target_modules` consistent with your model architecture naming.
- Currently only dense models (non MoE) are supported.
- For Megatron backend, LoRA requires `megatron-bridge` instead of `mbridge`.
21 changes: 13 additions & 8 deletions docs/zh/reference/lora.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# LoRA 参考

LoRA 是一种参数高效的微调技术,会在预训练权重中注入可训练的低秩矩阵, 通常作用在线性层附近。与全参数微调相比,LoRA 可以显著降低显存占用和 计算开销,从而让大模型的
LoRA 是一种参数高效的微调技术,会在预训练权重中注入可训练的低秩矩阵, 通常作用在线性层附近。与全参数微调相比,LoRA 可以显著降低显存占用和计算开销, 从而让大模型的
RL 微调在硬件资源有限的条件下也更具可行性。

在 AReaL 中,LoRA 尤其适用于以下场景:

- 在相对有限的硬件条件下进行超大模型的强化学习训练,例如使用 8 x 80 GB GPU 训练 70B+ 规模模型,
- 由于显存压力更低,可以支持更大的 batch size,
- 模型迁移与部署更加简单,因为只需要保存和分发 LoRA adapter,
- \[Future\] 更高效地并行微调多个 LoRA adapter,以提升硬件利用率 (参见 RFC
- \[Future\] 更高效地并行微调多个 LoRA adapter,以提升硬件利用率(参见 RFC
[#609](https://github.com/inclusionAI/AReaL/issues/609))。

本文档说明如何在 RL 训练中启用 LoRA,并配置相关参数。
Expand All @@ -23,12 +23,18 @@ AReaL 当前的 LoRA 支持矩阵如下:
| Megatron | ✅ | ❌ |
| Archon | ❌ | ❌ |

示例脚本:
**示例脚本:**

| Engine | Example script |
| -------- | --------------------------------------------- |
| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` |
| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` |
| Engine | Example script |
| ------------ | ------------------------------------------------- |
| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` |
| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` |
| Megatron MoE | `examples/math/gsm8k_grpo_megatron_lora_moe.yaml` |

对于 Megatron + vLLM,AReaL 现在支持:

- 在 Qwen3 MoE 等 MoE 架构上进行 LoRA 微调,并通过 XCCL 更新 LoRA 权重。
- 当 Megatron 与 rollout group 横跨多个节点时进行跨节点 LoRA 训练。

## 核心 LoRA 参数

Expand All @@ -44,5 +50,4 @@ AReaL 当前的 LoRA 支持矩阵如下:

- 可先从 `r=16` 或 `r=32` 开始,再按效果和资源逐步调参。
- `target_modules` 需与具体模型的层命名保持一致。
- 当前仅支持 dense 模型(非 MoE)。
- 对于 Megatron 后端,LoRA 需要使用 `megatron-bridge`,而不是 `mbridge`。
Loading
Loading