diff --git a/README.md b/README.md index 92a714e031..2956718233 100644 --- a/README.md +++ b/README.md @@ -202,11 +202,11 @@ Check the [AI Coding Assistant Guide](docs/reference/ai_assisted_dev.md) and ### Training Backends -| Backend | DP | Tensor Parallel | Sequence Parallel within TP | Context Parallel | Pipeline Parallel | Expert Parallel | 1D Sequence Packing | LoRA | -| ------------------ | ----------- | --------------- | --------------------------- | ---------------- | ----------------- | --------------- | ------------------- | ---- | -| **Megatron** | ✅ (ZeRO-1) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| **PyTorch FSDP** | ✅ (FSDP2) | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | -| **PyTorch Archon** | ✅ (FSDP2) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| Backend | DP | Tensor Parallel | Sequence Parallel within TP | Context Parallel | Pipeline Parallel | Expert Parallel | 1D Sequence Packing | LoRA | +| ------------------ | ----------- | --------------- | --------------------------- | ---------------- | ----------------- | --------------- | ------------------- | -------------------------------- | +| **Megatron** | ✅ (ZeRO-1) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (with vLLM inference backend) | +| **PyTorch FSDP** | ✅ (FSDP2) | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | +| **PyTorch Archon** | ✅ (FSDP2) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ### Inference Backends @@ -252,6 +252,7 @@ Check the [AI Coding Assistant Guide](docs/reference/ai_assisted_dev.md) and ### Reference - [CLI Configurations](docs/en/cli_reference.md) +- [LoRA RL](docs/en/reference/lora.md) - [Checkpointing](docs/en/reference/checkpointing.md) - [Metrics Tracking](docs/en/reference/metrics_tracking.md) - [Allocation Mode](docs/en/reference/alloc_mode.md) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 7701dab081..e08c852ec4 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1968,7 +1968,13 @@ def __post_init__(self): class WandBConfig: """Configuration for Weights & Biases experiment tracking.""" - mode: str = "disabled" + 
def __post_init__(self):
    """Reject a ``mode`` that is not one of the accepted tracking modes.

    Raises:
        ValueError: If ``self.mode`` is not ``online``, ``offline``,
            ``disabled`` or ``shared``.
    """
    accepted = ("online", "offline", "disabled", "shared")
    if self.mode in accepted:
        return
    raise ValueError(
        f"Invalid wandb mode: '{self.mode}'. Must be one of: {', '.join(accepted)}."
    )
def _apply_megatron_bridge_lora(self) -> None:
    """Wrap ``self.model`` with Megatron-Bridge LoRA and log parameter counts.

    Reads ``target_modules``, ``lora_rank`` and ``lora_alpha`` from
    ``self.config``. An empty target list, or one containing ``"all-linear"``,
    is replaced by the four explicit Megatron linear module names. The LoRA
    object is stored on ``self.bridge_lora`` and ``self.model`` is rebound to
    the wrapped model list.
    """
    assert self.model is not None, "Model must be initialized before applying LoRA."
    assert self.bridge_cls == "megatron-bridge"

    requested = list(self.config.target_modules or [])
    if requested and "all-linear" not in requested:
        target_modules = requested
    else:
        # Expand all-linear to explicit Megatron-Bridge linear module targets.
        target_modules = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]

    lora = MegatronBridgeLoRA(
        target_modules=target_modules,
        dim=self.config.lora_rank,
        alpha=self.config.lora_alpha,
        dropout=0.0,
    )
    self.bridge_lora = lora
    self.model = _MegatronModelList(lora(self.model, training=True))
    lora.set_params_to_save(self.model)

    # Count total and trainable elements in a single pass over the parameters.
    total_params = 0
    trainable_params = 0
    for param in self.model.parameters():
        numel = param.numel()
        total_params += numel
        if param.requires_grad:
            trainable_params += numel

    self.logger.info(
        "Applied Megatron Bridge LoRA: target_modules=%s, rank=%s, alpha=%s, trainable=%s/%s (%.4f%%)",
        target_modules,
        self.config.lora_rank,
        self.config.lora_alpha,
        trainable_params,
        total_params,
        100.0 * trainable_params / max(total_params, 1),
    )
+ ) + self.tokenizer = load_hf_tokenizer(self.config.path) with patch_bridge_for_tree_training( @@ -270,10 +315,14 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): bridge=self.bridge, bridge_type=self.bridge_cls, is_critic=self.config.is_critic, + use_lora=self.config.use_lora, ) self.model = _MegatronModelList(models) + if self.config.use_lora: + self._apply_megatron_bridge_lora() + with self.device: self._load_model_from_hf(self.config.path) @@ -552,6 +601,12 @@ def save(self, meta: SaveLoadMeta): base_model_path=meta.base_model_path, ) elif meta.weight_format == "dcp": + if self.checkpointer is None: + raise NotImplementedError( + "DCP checkpoint save is not available for this Megatron configuration " + "(e.g., LoRA path without distributed optimizer support). " + "Please use weight_format='hf' for adapter/full-model export." + ) self.checkpointer.save_checkpoint(meta.path, with_optimizer=meta.with_optim) else: raise ValueError(f"Unknown weight format {meta.weight_format}. ") @@ -564,6 +619,12 @@ def load(self, meta: SaveLoadMeta): ) self._load_model_from_hf(meta.path) elif meta.weight_format == "dcp": + if self.checkpointer is None: + raise NotImplementedError( + "DCP checkpoint load is not available for this Megatron configuration " + "(e.g., LoRA path without distributed optimizer support). " + "Please use weight_format='hf' for adapter/full-model load." + ) self.checkpointer.load_checkpoint(meta.path, with_optimizer=meta.with_optim) else: raise ValueError(f"Unknown weight format {meta.weight_format}. 
") @@ -967,6 +1028,12 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: return assert self.model is not None and len(self.model) > 0 + use_distributed_optimizer = ( + False + if self.config.use_lora + else self.mcore_config.ddp.use_distributed_optimizer + ) + assert self.optimizer_config.type in [ "adam", "sgd", @@ -987,7 +1054,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: adam_beta1=self.optimizer_config.beta1, adam_beta2=self.optimizer_config.beta2, adam_eps=self.optimizer_config.eps, - use_distributed_optimizer=self.mcore_config.ddp.use_distributed_optimizer, + use_distributed_optimizer=use_distributed_optimizer, params_dtype=self.dtype, clip_grad=self.optimizer_config.gradient_clipping, fp8_recipe=(self.fp8_config.recipe if self.enable_fp8 else None), @@ -1028,14 +1095,16 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: ) self.lr_scheduler = lr_scheduler - self.checkpointer = MegatronCheckpointManager( - model=self.model, - optimizer=self.optimizer, - lr_scheduler=self.lr_scheduler, - use_distributed_optimizer=self.mcore_config.ddp.use_distributed_optimizer, - use_checkpoint_opt_param_scheduler=self.mcore_config.use_checkpoint_opt_param_scheduler, - async_save=self.mcore_config.async_save, - ) + # MegatronCheckpointManager now only support distributed optimizer which lora does not support + if not self.config.use_lora: + self.checkpointer = MegatronCheckpointManager( + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + use_distributed_optimizer=use_distributed_optimizer, + use_checkpoint_opt_param_scheduler=self.mcore_config.use_checkpoint_opt_param_scheduler, + async_save=self.mcore_config.async_save, + ) def _check_rollout_engine_connected(self) -> None: """Validate that rollout engine has been connected via connect_engine().""" @@ -1072,6 +1141,16 @@ def _update_bucket_weights_from_distributed( for name, tensor in converted_named_tensors ] + if self.config.use_lora: + 
meta.peft_config = { + "r": self.config.lora_rank, + "lora_alpha": self.config.lora_alpha, + "target_modules": get_vllm_lora_target_modules( + list(self.config.target_modules or []) + ), + "bias": "none", + } + fut = self.rollout_engine.update_weights_from_distributed(meta, param_specs) handles = [] @@ -1160,10 +1239,14 @@ def _impl_update_weight_from_distributed( self._update_bucket_weights_from_distributed(meta, converted_named_tensors) buffer_size = 0 + model_name = self.hf_config.model_type + if self.config.use_lora: + model_name = f"{model_name}_lora" + converted_named_tensors.extend( convert_to_hf( self.tf_config, - self.hf_config.model_type, + model_name, name, param, quantization_config=self.quantization_config, @@ -1324,6 +1407,10 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: for name, param in get_named_parameters(self.model, num_moe_experts): if ".experts." in name: continue + if self.config.use_lora and ( + ".adapter." not in name or not getattr(param, "requires_grad", False) + ): + continue buffer_size = self._impl_update_weight_from_distributed( meta, name, @@ -1336,6 +1423,11 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: # Only pipeline parallel heads CAN contain named tensors here if converted_named_tensors: self._update_bucket_weights_from_distributed(meta, converted_named_tensors) + elif self.is_pipeline_parallel_head() and not self.config.use_lora: + self.logger.warning( + "No tensors were collected for distributed update at version %s.", + meta.version, + ) dist.barrier(group=self.cpu_group) @@ -1343,7 +1435,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: named_tensors = [] for name, param in get_named_parameters(self.model, num_moe_experts): - if ".experts." not in name: + if ".experts." 
not in name or self.config.use_lora: continue buffer_size = self._impl_update_expert_weight_from_distributed( meta, @@ -1406,11 +1498,19 @@ def _save_model_to_hf( raise ValueError( "Saving critic model is not supported with megatron-bridge." ) - self.bridge.save_hf_pretrained( - self.model, - path, - source_path=base_model_path, - ) + if self.config.use_lora: + self.bridge.save_hf_adapter( + self.model, + path=path, + peft_config=self.bridge_lora, + base_model_name_or_path=base_model_path or self.config.path, + ) + else: + self.bridge.save_hf_pretrained( + self.model, + path, + source_path=base_model_path, + ) else: save_weights_to_hf_with_mbridge_fast( bridge=self.bridge, diff --git a/areal/engine/megatron_utils/megatron.py b/areal/engine/megatron_utils/megatron.py index 5a57bbf5fa..85b445b0ac 100644 --- a/areal/engine/megatron_utils/megatron.py +++ b/areal/engine/megatron_utils/megatron.py @@ -15,6 +15,7 @@ get_block_size_from_config, quantize_params, ) +from areal.engine.megatron_utils.megatron_lora import convert_qwen3_lora_to_hf def _all_gather_and_concat( @@ -96,11 +97,13 @@ def all_gather_param( if "expert_bias" in name: return param - if not hasattr(param, "tensor_model_parallel"): - raise ValueError(f"{name} does not have tensor_model_parallel attribute") - param_is_fp8 = is_float8tensor(param) + if not hasattr(param, "tensor_model_parallel"): + if param_is_fp8 and fp8_direct_convert: + return param + return param.data + # Check if this param is truly NOT TP-sharded. # NOTE: TE unconditionally sets tensor_model_parallel=True on all Linear # weights, even for modules with parallel_mode='duplicated'. The original @@ -756,6 +759,8 @@ def convert_bailingmoe_to_hf( # Adapted from slime # A registry for conversion functions is more extensible. 
"""LoRA helpers for the Megatron engine.

Contains the mapping from Megatron-Bridge LoRA target names to HF/vLLM
projection names, the conversion of Megatron adapter parameters to PEFT-style
``(name, tensor)`` pairs, and a fallback ``AutoBridge.save_hf_adapter``.
"""

import re
from collections.abc import Iterable
from pathlib import Path

import torch
import torch.distributed as dist

# Targets used when the caller passes no targets or asks for "all-linear".
_ALL_LINEAR_BRIDGE_TARGETS = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")

# Megatron fused module name -> HF projection names it corresponds to.
_BRIDGE_TO_VLLM_TARGETS = {
    "linear_qkv": ["q_proj", "k_proj", "v_proj"],
    "linear_proj": ["o_proj"],
    "linear_fc1": ["gate_proj", "up_proj"],
    "linear_fc2": ["down_proj"],
}

# Megatron adapter parameter name: layer index, wrapped module, adapter half.
_MEGATRON_LORA_PARAM_RE = re.compile(
    r"(?:^|.*\.)decoder\.layers\.(\d+)\."
    r"(self_attention\.linear_qkv|self_attention\.linear_proj|mlp\.linear_fc1|mlp\.linear_fc2)\."
    r"adapter\.(linear_in|linear_out)\.weight$"
)

# Megatron modules that map one-to-one onto a single HF projection.
_ONE_TO_ONE_HF_MODULES = {
    "self_attention.linear_proj": "self_attn.o_proj",
    "mlp.linear_fc2": "mlp.down_proj",
}

# HF LoRA key tail, with or without an adapter name between lora_X and weight
# (``q_proj.lora_A.weight`` and ``q_proj.lora_A.default.weight``).
_HF_LORA_KEY_RE = re.compile(r"([^.]+)\.lora_[AB](?:\.[^.]+)?\.weight$")


def get_vllm_lora_target_modules(target_modules: list[str]) -> list[str]:
    """Translate Megatron-Bridge LoRA targets into HF-style projection names.

    Args:
        target_modules: Megatron module names such as ``"linear_qkv"``. An
            empty list, or one containing ``"all-linear"``, selects all four
            supported modules.

    Returns:
        Sorted, de-duplicated projection names (``q_proj``, ``o_proj``, ...).

    Raises:
        NotImplementedError: If a name has no known mapping.
    """
    if not target_modules or "all-linear" in target_modules:
        target_modules = list(_ALL_LINEAR_BRIDGE_TARGETS)

    targets: set[str] = set()
    for module_name in target_modules:
        mapped = _BRIDGE_TO_VLLM_TARGETS.get(module_name)
        if mapped is None:
            raise NotImplementedError(
                f"LoRA target module '{module_name}' is not supported in MegatronEngine yet."
            )
        targets.update(mapped)
    return sorted(targets)


def _split_fused_qkv_lora_b(
    tf_config, tensor: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused QKV ``linear_out`` weight into q, k and v row blocks.

    The rows are regrouped as ``(query_groups, heads, head_dim, rank)``; within
    each group the first ``num_attention_heads // num_query_groups`` heads go
    to q and the last two go to k and v respectively.
    """
    if getattr(tf_config, "kv_channels", None) is not None:
        head_dim = tf_config.kv_channels
    else:
        head_dim = tf_config.hidden_size // tf_config.num_attention_heads

    # An unset num_query_groups is treated as one group per attention head
    # instead of silently dropping the q/k/v weights.
    num_query_groups = getattr(tf_config, "num_query_groups", None)
    if num_query_groups is None:
        num_query_groups = tf_config.num_attention_heads
    q_heads_per_group = tf_config.num_attention_heads // num_query_groups

    # reshape (not view) so a non-contiguous input does not raise.
    grouped = tensor.reshape(num_query_groups, -1, head_dim, tensor.shape[1])
    q_b, k_b, v_b = torch.split(grouped, [q_heads_per_group, 1, 1], dim=1)
    return (
        q_b.reshape(-1, q_b.shape[-1]).contiguous(),
        k_b.reshape(-1, k_b.shape[-1]).contiguous(),
        v_b.reshape(-1, v_b.shape[-1]).contiguous(),
    )


def convert_qwen3_lora_to_hf(
    tf_config,
    name: str,
    tensor: torch.Tensor,
) -> list[tuple[str, torch.Tensor]]:
    """Convert one Megatron LoRA adapter parameter to PEFT-style entries.

    ``linear_in`` weights become ``lora_A.default.weight`` and ``linear_out``
    weights become ``lora_B.default.weight``. For the fused modules
    (``linear_qkv``, ``linear_fc1``) the ``linear_in`` tensor is emitted
    unchanged once per HF projection, while the ``linear_out`` tensor is split
    along dim 0 into one piece per projection.

    Args:
        tf_config: Object exposing ``num_attention_heads``, ``hidden_size`` and
            optionally ``kv_channels`` / ``num_query_groups``. It is only read
            for the fused QKV ``linear_out`` case.
        name: Megatron parameter name.
        tensor: The parameter value.

    Returns:
        ``(hf_name, tensor)`` pairs, or an empty list when ``name`` is not a
        recognised adapter weight.
    """
    match = _MEGATRON_LORA_PARAM_RE.match(name)
    if match is None:
        return []

    layer_idx, module_name, adapter_part = match.groups()
    base_prefix = f"base_model.model.model.layers.{layer_idx}"
    is_lora_a = adapter_part == "linear_in"
    suffix = "lora_A.default.weight" if is_lora_a else "lora_B.default.weight"

    one_to_one = _ONE_TO_ONE_HF_MODULES.get(module_name)
    if one_to_one is not None:
        return [(f"{base_prefix}.{one_to_one}.{suffix}", tensor)]

    if module_name == "mlp.linear_fc1":
        hf_modules = ("mlp.gate_proj", "mlp.up_proj")
        if is_lora_a:
            pieces = (tensor, tensor)
        else:
            gate_b, up_b = tensor.chunk(2, dim=0)
            pieces = (gate_b.contiguous(), up_b.contiguous())
    else:  # self_attention.linear_qkv is the only remaining regex alternative
        hf_modules = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj")
        if is_lora_a:
            pieces = (tensor, tensor, tensor)
        else:
            pieces = _split_fused_qkv_lora_b(tf_config, tensor)

    return [
        (f"{base_prefix}.{hf_module}.{suffix}", piece)
        for hf_module, piece in zip(hf_modules, pieces)
    ]


def _infer_target_modules_from_adapter_weights(weight_keys: Iterable[str]) -> list[str]:
    """Infer PEFT ``target_modules`` from adapter weight parameter names.

    The module name is the dotted component immediately before ``lora_A`` /
    ``lora_B``. Keys with and without an adapter name are recognised:

    - ``base_model.model.layers.0.self_attn.q_proj.lora_A.weight`` -> ``q_proj``
    - ``...mlp.gate_proj.lora_B.default.weight`` -> ``gate_proj``

    Keys that do not end in a LoRA weight suffix are ignored.
    """
    target_modules: set[str] = set()
    for key in weight_keys:
        # Remove PEFT prefix
        stripped = key.replace("base_model.model.", "")
        found = _HF_LORA_KEY_RE.search(stripped)
        if found is not None:
            target_modules.add(found.group(1))
    return sorted(target_modules)


def _build_adapter_config_dict(
    peft_config,
    target_modules: list[str],
    base_model_name_or_path: str,
) -> dict:
    """Build the dictionary written to PEFT's ``adapter_config.json``.

    Args:
        peft_config: Object exposing ``dim``, ``alpha`` and ``dropout``.
        target_modules: HF module names the adapter applies to.
        base_model_name_or_path: Identifier or path of the base model.
    """
    return {
        "base_model_name_or_path": base_model_name_or_path,
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "inference_mode": False,
        "r": peft_config.dim,
        "lora_alpha": peft_config.alpha,
        "lora_dropout": peft_config.dropout,
        "target_modules": target_modules,
        "bias": "none",
        "fan_in_fan_out": False,
        "modules_to_save": None,
        "init_lora_weights": True,
        "layers_to_transform": None,
        "layers_pattern": None,
    }


def _monkey_patch_save_hf_adapter():
    """Add save_hf_adapter to AutoBridge when megatron-bridge does not provide it.

    Does nothing when megatron-bridge is not importable. This module is
    imported for its pure conversion helpers too, and those must stay usable
    without megatron-bridge installed.
    """
    try:
        from megatron.bridge import AutoBridge
    except ImportError:
        return

    if hasattr(AutoBridge, "save_hf_adapter"):
        # Already exists, no need to patch
        return

    def save_hf_adapter(
        self,
        model,
        path: str | Path,
        peft_config,
        base_model_name_or_path: str | None = None,
        show_progress: bool = True,
    ) -> None:
        """
        Save LoRA adapter weights as a HuggingFace PEFT-compatible directory.

        The output directory contains adapter_config.json and
        adapter_model.safetensors.

        Args:
            model: Megatron model instance or list of instances.
            path: Directory path where the adapter files will be saved.
            peft_config: The LoRA config used during training (provides dim,
                alpha, dropout).
            base_model_name_or_path: HuggingFace model identifier or local path
                of the base model. If None, read from ``self.hf_pretrained``
                (``model_name_or_path``, then ``name_or_path``).
            show_progress: Forwarded to ``export_adapter_weights``.

        Raises:
            RuntimeError: If no adapter weights are exported.

        Note:
            This method is collective -- all ranks must call it. Only rank 0
            writes files.
        """
        import json

        from safetensors.torch import save_file

        distributed = dist.is_available() and dist.is_initialized()

        # Synchronize at start
        if distributed:
            dist.barrier()

        # Export adapter weights
        adapter_state: dict[str, torch.Tensor] = {}
        for name, tensor in self.export_adapter_weights(
            model, cpu=True, show_progress=show_progress
        ):
            adapter_state[f"base_model.model.{name}"] = tensor.clone().float()

        if not adapter_state:
            raise RuntimeError(
                "No adapter weights were found on the model. "
                "Ensure the model has PEFT adapters applied before calling save_hf_adapter()."
            )

        # Only rank 0 writes files
        if not distributed or dist.get_rank() == 0:
            save_dir = Path(path)
            save_dir.mkdir(parents=True, exist_ok=True)

            # Infer base model path if not provided
            if base_model_name_or_path is None:
                base_model_name_or_path = str(
                    getattr(self.hf_pretrained, "model_name_or_path", "")
                    or getattr(self.hf_pretrained, "name_or_path", "")
                )

            # Build adapter config
            target_modules = _infer_target_modules_from_adapter_weights(
                adapter_state.keys()
            )
            adapter_config = _build_adapter_config_dict(
                peft_config,
                target_modules=target_modules,
                base_model_name_or_path=base_model_name_or_path,
            )

            # Save adapter config
            config_path = save_dir / "adapter_config.json"
            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(adapter_config, f, indent=2)

            # Save adapter weights
            weights_path = save_dir / "adapter_model.safetensors"
            save_file(adapter_state, str(weights_path))

            print(f"✓ Saved LoRA adapter to {save_dir}")
            print(f" - Config: {config_path}")
            print(f" - Weights: {weights_path} ({len(adapter_state)} parameters)")

        # Synchronize at end
        if distributed:
            dist.barrier()

    # Attach the method to the class
    AutoBridge.save_hf_adapter = save_hf_adapter


# Current: This monkey patch is needed as the current megatron-bridge 0.3.0 does not have a built-in method
# to save LoRA adapters in HuggingFace PEFT format, which is required for our use case.
# Future: This code is however present in main branch of megatron-bridge so this patch is temporary
# and can be removed later when we upgrade the megatron-bridge version.
_monkey_patch_save_hf_adapter()
def _process_tool_calls_vllm(
    text: str,
    tools: list[Any],
    tool_call_parser: str,
    reasoning_parser: str,
    finish_reason: str,
    use_responses: bool = False,
    tokenizer: Any = None,
) -> tuple[
    list[ChatCompletionMessageFunctionToolCall | ResponseFunctionToolCall] | None,
    str,
    str,
]:
    """Extract tool calls from ``text`` using vLLM's parsers.

    Args:
        text: Raw model output.
        tools: Tool definitions, forwarded to the parser via a request stub.
        tool_call_parser: Parser name; translated through
            ``_SGLANG_TO_VLLM_TOOL_PARSER`` when an entry exists.
        reasoning_parser: vLLM reasoning parser name; empty disables the
            reasoning/content split.
        finish_reason: Incoming finish reason.
        use_responses: Build ``ResponseFunctionToolCall`` objects instead of
            ``ChatCompletionMessageFunctionToolCall``.
        tokenizer: Tokenizer handed to the vLLM parser constructors.

    Returns:
        ``(tool_calls, text, finish_reason)``. When nothing is parsed, or on
        any parsing failure, this is ``(None, <original text>, <original
        finish_reason>)``. On success the text is the reasoning segment plus
        the parser's remaining content, and ``"stop"`` becomes ``"tool_calls"``.

    Raises:
        ModuleNotFoundError: If the vLLM modules cannot be imported.
    """
    from vllm.reasoning import ReasoningParserManager
    from vllm.tool_parsers import ToolParserManager

    # Use vllm's reasoning parser to get the think start/end tokens,
    # mirroring the sglang path which uses ReasoningParser.detector tokens.
    reasoning_text, content_text = "", text
    if tokenizer is not None and reasoning_parser:
        try:
            reasoning_parser_cls = ReasoningParserManager.get_reasoning_parser(
                reasoning_parser
            )
            reasoning_parser_inst = reasoning_parser_cls(tokenizer)
            if hasattr(reasoning_parser_inst, "start_token") and hasattr(
                reasoning_parser_inst, "end_token"
            ):
                reasoning_text, content_text = _detect_think_and_return_ori_think(
                    text,
                    reasoning_parser_inst.start_token,
                    reasoning_parser_inst.end_token,
                )
        except Exception as e:
            logger.warning(
                "Failed to initialize vLLM reasoning parser '%s': %s. "
                "Skipping reasoning extraction.",
                reasoning_parser,
                e,
            )
            reasoning_text, content_text = "", text

    vllm_name = _SGLANG_TO_VLLM_TOOL_PARSER.get(tool_call_parser, tool_call_parser)
    try:
        tool_parser_cls = ToolParserManager.get_tool_parser(vllm_name)
    except KeyError:
        logger.warning(
            "vLLM tool parser '%s' (mapped from '%s') not found; skipping tool call parsing.",
            vllm_name,
            tool_call_parser,
        )
        return None, text, finish_reason

    if tokenizer is None:
        logger.warning(
            "vLLM tool parser requires a tokenizer but none was provided; skipping tool call parsing."
        )
        return None, text, finish_reason

    request = SimpleNamespace(
        tools=tools,
        tool_choice=None,
        skip_special_tokens=True,
    )

    try:
        # Construction sits inside the try so that a parser which cannot be
        # built for this tokenizer degrades to "no tool calls", the same as
        # an extraction failure, instead of failing the whole request.
        tool_parser = tool_parser_cls(tokenizer)
        tool_call_info = tool_parser.extract_tool_calls(content_text, request)
    except Exception as e:
        logger.error("vLLM tool call parsing error: %s", e)
        traceback.print_exc()
        return None, text, finish_reason

    if not tool_call_info.tools_called:
        return None, text, finish_reason

    if finish_reason == "stop":
        finish_reason = "tool_calls"

    remaining_content = tool_call_info.content or ""

    if use_responses:
        result_tool_calls = [
            ResponseFunctionToolCall(
                type="function_call",
                id=f"fc-{uuid.uuid4().hex[:24]}",
                call_id=f"call_{uuid.uuid4().hex[:24]}",
                name=tc.function.name,
                arguments=tc.function.arguments,
                status="completed",
            )
            for tc in tool_call_info.tool_calls
        ]
    else:
        result_tool_calls = [
            ChatCompletionMessageFunctionToolCall(
                type="function",
                id=f"call_{uuid.uuid4().hex[:24]}",
                function=Function(
                    name=tc.function.name,
                    arguments=tc.function.arguments,
                ),
            )
            for tc in tool_call_info.tool_calls
        ]

    return result_tool_calls, reasoning_text + remaining_content, finish_reason


def process_tool_calls(
    text: str,
    tools: list[Any],
    tool_call_parser: str,
    reasoning_parser: str,
    finish_reason: str,
    use_responses: bool = False,
    tokenizer: Any = None,
) -> tuple[
    list[ChatCompletionMessageFunctionToolCall | ResponseFunctionToolCall] | None,
    str,
    str,
]:
    """Process tool calls in the response.

    Tries the sglang-based implementation first and falls back to the
    vLLM-based one when the former raises ``ModuleNotFoundError``. ``tokenizer``
    is only forwarded to the vLLM implementation. If both raise
    ``ModuleNotFoundError`` a warning is logged and
    ``(None, text, finish_reason)`` is returned unchanged.
    """
    try:
        return _process_tool_calls_sglang(
            text,
            tools,
            tool_call_parser,
            reasoning_parser,
            finish_reason,
            use_responses,
        )
    except ModuleNotFoundError:
        pass

    try:
        return _process_tool_calls_vllm(
            text,
            tools,
            tool_call_parser,
            reasoning_parser,
            finish_reason,
            use_responses,
            tokenizer=tokenizer,
        )
    except ModuleNotFoundError:
        pass

    logger.warning(
        "Neither sglang nor vllm is installed; skipping tool call parsing. Install one of them for tool call support."
    )
    return None, text, finish_reason
def _validate_cfg(self):
    """Reject incompatible backend/feature combinations up front.

    Intended to run before weight initialization so that an invalid setup
    fails before workers are spawned and models are loaded.

    Raises:
        ValueError: If ``return_routed_experts`` is enabled with a vLLM
            rollout backend, or if a Megatron actor with LoRA is paired with
            an SGLang rollout backend.
    """
    rollout_backend = self.rollout_alloc.backend
    actor_backend = self.actor_alloc.backend

    if rollout_backend == "vllm":
        if self.config.rollout.return_routed_experts:
            raise ValueError(
                "return_routed_experts is only supported with SGLang backend. "
                "Please disable return_routed_experts or switch to SGLang backend."
            )

    # The remaining check only concerns Megatron actors that train with LoRA.
    if actor_backend != "megatron":
        return
    if not self.config.actor.use_lora:
        return
    if rollout_backend == "sglang":
        raise ValueError(
            "Megatron actor with LoRA is not supported with SGLang rollout in "
            "RL trainer. Please use vLLM rollout backend, or disable LoRA, or "
            "switch actor backend from Megatron."
        )
Supported by FSDP and Megatron (Megatron requires `megatron.bridge_type=megatron-bridge`). For rollout engines, enable LoRA in vLLM/SGLang as well. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | @@ -420,7 +420,7 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Supported by FSDP and Megatron (Megatron requires `megatron.bridge_type=megatron-bridge`). For rollout engines, enable LoRA in vLLM/SGLang as well. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | @@ -460,7 +460,7 @@ Core configuration for model training, including optimization and backend settin | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Supported by FSDP and Megatron (Megatron requires `megatron.bridge_type=megatron-bridge`). For rollout engines, enable LoRA in vLLM/SGLang as well. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. 
| @@ -759,14 +759,14 @@ Configuration for experiment statistics logging and tracking services. Configuration for SwanLab experiment tracking and monitoring. -| Parameter | Type | Default | Description | -| --------- | -------------- | ------------ | ----------- | -| `project` | string \| None | `None` | - | -| `name` | string \| None | `None` | - | -| `config` | `dict` \| None | `None` | - | -| `logdir` | string \| None | `None` | - | -| `mode` | string \| None | `"disabled"` | - | -| `api_key` | string \| None | `None` | - | +| Parameter | Type | Default | Description | +| --------- | -------------- | ------------ | ---------------------------------------------------------------------------------------------------------------------- | +| `project` | string \| None | `None` | - | +| `name` | string \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | +| `logdir` | string \| None | `None` | - | +| `mode` | string | `"disabled"` | Tracking mode. One of 'cloud', 'local', 'disabled', or 'offline'. **Choices:** `cloud`, `local`, `disabled`, `offline` | +| `api_key` | string \| None | `None` | - | (section-tensor-board)= @@ -805,20 +805,20 @@ See: https://github.com/gradio-app/trackio Configuration for Weights & Biases experiment tracking. 
-| Parameter | Type | Default | Description | -| ---------------- | ---------------------- | ------------ | ----------- | -| `mode` | string | `"disabled"` | - | -| `wandb_base_url` | string | `""` | - | -| `wandb_api_key` | string | `""` | - | -| `entity` | string \| None | `None` | - | -| `project` | string \| None | `None` | - | -| `name` | string \| None | `None` | - | -| `job_type` | string \| None | `None` | - | -| `group` | string \| None | `None` | - | -| `notes` | string \| None | `None` | - | -| `tags` | list of string \| None | `None` | - | -| `config` | `dict` \| None | `None` | - | -| `id_suffix` | string \| None | `"train"` | - | +| Parameter | Type | Default | Description | +| ---------------- | ---------------------- | ------------ | -------------------------------------------------------------------------------------------------------------------------- | +| `mode` | string | `"disabled"` | Tracking mode. One of 'online', 'offline', 'disabled', or 'shared'. **Choices:** `online`, `offline`, `disabled`, `shared` | +| `wandb_base_url` | string | `""` | - | +| `wandb_api_key` | string | `""` | - | +| `entity` | string \| None | `None` | - | +| `project` | string \| None | `None` | - | +| `name` | string \| None | `None` | - | +| `job_type` | string \| None | `None` | - | +| `group` | string \| None | `None` | - | +| `notes` | string \| None | `None` | - | +| `tags` | list of string \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | +| `id_suffix` | string \| None | `"train"` | - | (section-archon-engine)= @@ -1067,7 +1067,7 @@ Configuration class: TeacherConfig | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. 
| +| `use_lora` | boolean | `False` | Whether to use LoRA. Supported by FSDP and Megatron (Megatron requires `megatron.bridge_type=megatron-bridge`). For rollout engines, enable LoRA in vLLM/SGLang as well. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | diff --git a/docs/en/reference/lora.md b/docs/en/reference/lora.md new file mode 100644 index 0000000000..99e7dfd1ae --- /dev/null +++ b/docs/en/reference/lora.md @@ -0,0 +1,54 @@ +# LoRA Reference + +LoRA is a parameter-efficient fine-tuning technique that injects trainable low-rank +matrices into pre-trained weights, typically around linear layers. Compared with +full-parameter fine-tuning, this reduces memory usage and compute cost substantially, +making RL fine-tuning of large models much more practical on limited hardware. + +In AReaL, this is especially useful for: + +- reinforcement learning with very large models, including 70B+ models, on relatively + modest hardware such as 8 x 80 GB GPUs, +- enabling larger batch sizes because LoRA reduces training memory pressure, +- simplifying transfer and deployment because only the LoRA adapters need to be saved + and shipped, +- \[Future\] fine-tune multiple LoRA adapters more efficiently in parallel for better + hardware utilization (see RFC + [#609](https://github.com/inclusionAI/AReaL/issues/609)). + +This guide explains how to enable LoRA in RL training and configure the related +parameters. 
+ +## Backend Support + +The current LoRA support matrix in AReaL is: + +| Engine | vLLM | SGLang | +| -------- | ---- | ------ | +| FSDP2 | ✅ | ✅ | +| Megatron | ✅ | ❌ | +| Archon | ❌ | ❌ | + +Example scripts: + +| Engine | Example script | +| -------- | --------------------------------------------- | +| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` | +| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` | + +## Core LoRA Parameters + +| Parameter | What it controls | Typical values | +| ----------------- | ------------------------------------------------------------------------------------------------------- | --------------------- | +| `use_lora` | Enables LoRA fine-tuning mode. | `true` / `false` | +| `lora_rank` (`r`) | Rank of the low-rank adapters. Higher rank increases capacity and memory/compute cost. | `8`, `16`, `32`, `64` | +| `lora_alpha` | LoRA scaling factor. Effective adapter scale is commonly thought of as proportional to `alpha / r`. | `16`, `32`, `64` | +| `target_modules` | Which model submodules receive LoRA adapters. This is the most important architecture-specific setting. | e.g. \[`all-linear`\] | +| `peft_type` | PEFT method type. In AReaL configs, this is LoRA. | `lora` | + +## Practical Notes + +- Start with `r=16` or `r=32` for most models, then tune upward only if needed. +- Keep `target_modules` consistent with your model architecture naming. +- Currently only dense models (non-MoE) are supported. +- For the Megatron backend, LoRA requires `megatron-bridge` instead of `mbridge`.
diff --git a/docs/zh/_toc.yml b/docs/zh/_toc.yml index 995e0b4f18..1a33f9849e 100644 --- a/docs/zh/_toc.yml +++ b/docs/zh/_toc.yml @@ -42,6 +42,7 @@ parts: - file: reference/checkpointing - file: reference/metrics_tracking - file: reference/alloc_mode + - file: reference/lora - file: reference/bridge_backend - file: reference/tree_training - file: reference/rollout_workflow diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index 21fb56f489..f68accdcb5 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -351,7 +351,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Supported by FSDP and Megatron (Megatron requires `megatron.bridge_type=megatron-bridge`). For rollout engines, enable LoRA in vLLM/SGLang as well. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | @@ -418,7 +418,7 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Supported by FSDP and Megatron (Megatron requires `megatron.bridge_type=megatron-bridge`). 
For rollout engines, enable LoRA in vLLM/SGLang as well. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | @@ -458,7 +458,7 @@ Core configuration for model training, including optimization and backend settin | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Supported by FSDP and Megatron (Megatron requires `megatron.bridge_type=megatron-bridge`). For rollout engines, enable LoRA in vLLM/SGLang as well. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | @@ -757,14 +757,14 @@ Configuration for experiment statistics logging and tracking services. Configuration for SwanLab experiment tracking and monitoring. 
-| Parameter | Type | Default | Description | -| --------- | -------------- | ------------ | ----------- | -| `project` | string \| None | `None` | - | -| `name` | string \| None | `None` | - | -| `config` | `dict` \| None | `None` | - | -| `logdir` | string \| None | `None` | - | -| `mode` | string \| None | `"disabled"` | - | -| `api_key` | string \| None | `None` | - | +| Parameter | Type | Default | Description | +| --------- | -------------- | ------------ | ---------------------------------------------------------------------------------------------------------------------- | +| `project` | string \| None | `None` | - | +| `name` | string \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | +| `logdir` | string \| None | `None` | - | +| `mode` | string | `"disabled"` | Tracking mode. One of 'cloud', 'local', 'disabled', or 'offline'. **Choices:** `cloud`, `local`, `disabled`, `offline` | +| `api_key` | string \| None | `None` | - | (section-tensor-board)= @@ -803,20 +803,20 @@ See: https://github.com/gradio-app/trackio Configuration for Weights & Biases experiment tracking. 
-| Parameter | Type | Default | Description | -| ---------------- | ---------------------- | ------------ | ----------- | -| `mode` | string | `"disabled"` | - | -| `wandb_base_url` | string | `""` | - | -| `wandb_api_key` | string | `""` | - | -| `entity` | string \| None | `None` | - | -| `project` | string \| None | `None` | - | -| `name` | string \| None | `None` | - | -| `job_type` | string \| None | `None` | - | -| `group` | string \| None | `None` | - | -| `notes` | string \| None | `None` | - | -| `tags` | list of string \| None | `None` | - | -| `config` | `dict` \| None | `None` | - | -| `id_suffix` | string \| None | `"train"` | - | +| Parameter | Type | Default | Description | +| ---------------- | ---------------------- | ------------ | -------------------------------------------------------------------------------------------------------------------------- | +| `mode` | string | `"disabled"` | Tracking mode. One of 'online', 'offline', 'disabled', or 'shared'. **Choices:** `online`, `offline`, `disabled`, `shared` | +| `wandb_base_url` | string | `""` | - | +| `wandb_api_key` | string | `""` | - | +| `entity` | string \| None | `None` | - | +| `project` | string \| None | `None` | - | +| `name` | string \| None | `None` | - | +| `job_type` | string \| None | `None` | - | +| `group` | string \| None | `None` | - | +| `notes` | string \| None | `None` | - | +| `tags` | list of string \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | +| `id_suffix` | string \| None | `"train"` | - | (section-archon-engine)= @@ -1065,7 +1065,7 @@ Configuration class: TeacherConfig | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. 
| +| `use_lora` | boolean | `False` | Whether to use LoRA. Supported by FSDP and Megatron (Megatron requires `megatron.bridge_type=megatron-bridge`). For rollout engines, enable LoRA in vLLM/SGLang as well. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | diff --git a/docs/zh/reference/lora.md b/docs/zh/reference/lora.md new file mode 100644 index 0000000000..71b8344ef6 --- /dev/null +++ b/docs/zh/reference/lora.md @@ -0,0 +1,48 @@ +# LoRA 参考 + +LoRA 是一种参数高效的微调技术,会在预训练权重中注入可训练的低秩矩阵, 通常作用在线性层附近。与全参数微调相比,LoRA 可以显著降低显存占用和 计算开销,从而让大模型的 +RL 微调在硬件资源有限的条件下也更具可行性。 + +在 AReaL 中,LoRA 尤其适用于以下场景: + +- 在相对有限的硬件条件下进行超大模型的强化学习训练,例如使用 8 x 80 GB GPU 训练 70B+ 规模模型, +- 由于显存压力更低,可以支持更大的 batch size, +- 模型迁移与部署更加简单,因为只需要保存和分发 LoRA adapter, +- \[Future\] 更高效地并行微调多个 LoRA adapter,以提升硬件利用率 (参见 RFC + [#609](https://github.com/inclusionAI/AReaL/issues/609))。 + +本文档说明如何在 RL 训练中启用 LoRA,并配置相关参数。 + +## 后端支持 + +AReaL 当前的 LoRA 支持矩阵如下: + +| Engine | vLLM | SGLang | +| -------- | ---- | ------ | +| FSDP2 | ✅ | ✅ | +| Megatron | ✅ | ❌ | +| Archon | ❌ | ❌ | + +示例脚本: + +| Engine | Example script | +| -------- | --------------------------------------------- | +| FSDP2 | `examples/math/gsm8k_grpo_lora.yaml` | +| Megatron | `examples/math/gsm8k_grpo_megatron_lora.yaml` | + +## 核心 LoRA 参数 + +| 参数 | 作用 | 常见取值 | +| ----------------- | ------------------------------------------------------------------ | --------------------- | +| `use_lora` | 是否启用 LoRA 微调模式。 | `true` / `false` | +| `lora_rank` (`r`) | 低秩适配器的秩。`r` 越大,表达能力越强,但显存与计算开销更高。 | `8`, `16`, `32`, `64` | +| `lora_alpha` | LoRA 缩放系数。通常可理解为有效缩放与 `alpha / r` 成正比。 | `16`, `32`, `64` | +| `target_modules` | 指定注入 LoRA 的目标子模块。这是最关键、且与模型结构强相关的配置。 | 例如 \[`all-linear`\] | +| `peft_type` | PEFT 方法类型。在 AReaL 配置中为 LoRA。 | `lora` | + +## 实践建议 + +- 可先从 `r=16` 或 `r=32` 开始,再按效果和资源逐步调参。 +- `target_modules` 需与具体模型的层命名保持一致。 +- 当前仅支持 dense 模型(非 MoE)。 +- 对于 
Megatron 后端,LoRA 需要使用 `megatron-bridge`,而不是 `mbridge`。 diff --git a/examples/experimental/inference_service/human_in_the_loop_demo.py b/examples/experimental/inference_service/human_in_the_loop_demo.py index 84cb2dea7b..45485176af 100644 --- a/examples/experimental/inference_service/human_in_the_loop_demo.py +++ b/examples/experimental/inference_service/human_in_the_loop_demo.py @@ -36,6 +36,7 @@ DEFAULT_REQUEST_TIMEOUT = 3600 DEFAULT_GATEWAY_WAIT_SECS = 600 DEFAULT_QUESTION = "how many r's are in the word strawberry?" +DEFAULT_INFERENCE_BACKEND = "sglang" CORRECT_ANSWER_RE = re.compile(r"\b3\b|three", re.IGNORECASE) BATCH_SIZE = 4 ROLLOUT_COMPLETE_WAIT_SECS = 60 @@ -207,6 +208,12 @@ def main() -> None: parser.add_argument( "--question", default=DEFAULT_QUESTION, help="Question for each HITL round" ) + parser.add_argument( + "--inference-backend", + choices=("sglang", "vllm"), + default=DEFAULT_INFERENCE_BACKEND, + help="Inference backend used by online_rollout.py", + ) args = parser.parse_args() online_rollout = ( @@ -248,6 +255,8 @@ def cleanup(signum=None, frame=None): "--config", str(config_yaml), f"actor.path={args.actor_path}", + f"rollout.backend={args.inference_backend}:d1", + f"rollout.openai.admin_api_key={args.admin_key}", f"rollout.request_timeout={args.request_timeout}", ], stdout=log_fh, diff --git a/examples/experimental/inference_service/online_rollout.py b/examples/experimental/inference_service/online_rollout.py index 96cfa14dd9..90f1bd077b 100644 --- a/examples/experimental/inference_service/online_rollout.py +++ b/examples/experimental/inference_service/online_rollout.py @@ -14,6 +14,7 @@ def main(args: list[str]) -> None: if str(repo_root) not in sys.path: sys.path.insert(0, str(repo_root)) + from areal.api.alloc_mode import ModelAllocation from areal.api.cli_args import PPOConfig, load_expr_config from areal.experimental.inference_service.controller.config import ( GatewayControllerConfig, @@ -37,14 +38,6 @@ def main(args: list[str]) -> None: 
raise NotImplementedError( "online_rollout.py requires single-controller execution (for example: scheduler.type=local)." ) - from areal.api.alloc_mode import ModelAllocation - - rollout_alloc = ModelAllocation.from_str(config.rollout.backend) - if rollout_alloc.backend == "vllm": - raise NotImplementedError( - "online_rollout.py currently supports only the SGLang generation backend." - ) - from areal.infra.scheduler.local import LocalScheduler from areal.infra.scheduler.slurm import SlurmScheduler @@ -74,12 +67,19 @@ def main(args: list[str]) -> None: request_timeout=config.rollout.request_timeout, openai=openai_cfg, ) + rollout_alloc = ModelAllocation.from_str(config.rollout.backend, name="rollout") + if rollout_alloc.backend == "sglang": + server_args = asdict(config.sglang) + elif rollout_alloc.backend == "vllm": + server_args = asdict(config.vllm) + else: + raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler) try: ctrl.initialize( role="rollout", - server_args=asdict(config.sglang), + server_args=server_args, ) logger.info("Proxy gateway available at %s", ctrl.proxy_gateway_addr) diff --git a/examples/experimental/inference_service/tau2_rollout.py b/examples/experimental/inference_service/tau2_rollout.py index 067aca548b..d7c6828216 100644 --- a/examples/experimental/inference_service/tau2_rollout.py +++ b/examples/experimental/inference_service/tau2_rollout.py @@ -18,6 +18,7 @@ from datasets import Dataset +from areal.api.alloc_mode import ModelAllocation from areal.api.cli_args import ( BaseExperimentConfig, GenerationHyperparameters, @@ -210,12 +211,18 @@ def main(argv: list[str]) -> None: raise NotImplementedError(f"Unknown scheduler type: {sched_type}") # --- Controller --- - sglang_args = asdict(config.sglang) + rollout_alloc = ModelAllocation.from_str(config.rollout.backend, name="rollout") + if rollout_alloc.backend == "sglang": + server_args = 
asdict(config.sglang) + elif rollout_alloc.backend == "vllm": + server_args = asdict(config.vllm) + else: + raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler) ctrl.initialize( role="rollout", - server_args=sglang_args, + server_args=server_args, ) # --- Workflow kwargs (identical to examples/tau2/train.py) --- diff --git a/examples/math/gsm8k_grpo_megatron_lora.yaml b/examples/math/gsm8k_grpo_megatron_lora.yaml new file mode 100644 index 0000000000..3dfff966ac --- /dev/null +++ b/examples/math/gsm8k_grpo_megatron_lora.yaml @@ -0,0 +1,191 @@ +experiment_name: gsm8k-grpo-megatron-lora +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + + +scheduler: + type: null + +rollout: + backend: "vllm:d4p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + use_lora: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + lora_name: "lora-gsm8k" + +actor: + backend: "megatron:d4p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-0.6B + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 3e-5 + weight_decay: 0.003 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + 
temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behave_imp_weight_cap: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + megatron: + bridge_type: megatron-bridge + weight_update_mode: xccl + use_lora: ${rollout.use_lora} + peft_type: lora + lora_rank: 16 + lora_alpha: 16 + target_modules: [linear_qkv, linear_proj, linear_fc1, linear_fc2] + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + enforce_eager: true + enable_lora: ${rollout.use_lora} + max_lora_rank: ${actor.lora_rank} + +# datasets +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: 
disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/tests/experimental/openai/test_tool_call_parser.py b/tests/experimental/openai/test_tool_call_parser.py index 4c645f223a..4d3da68501 100644 --- a/tests/experimental/openai/test_tool_call_parser.py +++ b/tests/experimental/openai/test_tool_call_parser.py @@ -1,15 +1,11 @@ -import pytest - -# This module requires sglang (tool_call_parser imports sglang internals) -pytest.importorskip("sglang", reason="sglang is required for tool_call_parser tests") -pytestmark = pytest.mark.sglang +from types import SimpleNamespace -# Tools constructed as Iterable[ChatCompletionToolParam] -from openai.types.chat import ChatCompletionToolParam # noqa: E402 +import pytest +from openai.types.chat import ChatCompletionMessageToolCall, ChatCompletionToolParam -from areal.experimental.openai.tool_call_parser import process_tool_calls # noqa: E402 +from areal.experimental.openai import tool_call_parser as parser_module -tools: list[ChatCompletionToolParam] = [ +TOOLS: list[ChatCompletionToolParam] = [ { "type": "function", "function": { @@ -55,43 +51,29 @@ }, ] +TEXT = ( + '\nOkay, so the user is asking whether the director of "Scary Movie" and the director of "The Preacher\'s Wife" are from the same country. Let me think about how to approach this.\n\n' + 'First, I need to confirm the countries of these directors. 
I remember that "Scary Movie" is directed by Joe Anderson, and "The Preacher\'s Wife" is directed by David Fincher. Wait, but actually, "The Preacher\'s Wife" is directed by Christopher Nolan and David Fincher? No, I think I mixed up. Let me check my memory. \n\n' + "Wait, no. The Preacher's Wife is a movie directed by Christopher Nolan. And David Fincher directed The Preacher. So the directors are different. Then the user is asking if they are both from the same country. The answer would be no, because the two directors are from different countries. \n\n" + 'But maybe I should verify this to be sure. Since I can use the web search function, I should use the "access" tool to get the URLs of the directors\' websites to confirm. So first, I\'ll search for "director of Scary Movie country" and "director of The Preacher\'s Wife country" to get precise data. Then, analyze the results to see if they have the same country affiliations.\n\n\n' + '\n{"name": "search", "arguments": {"query": "director of Scary Movie country"}}\n\n\n' + '\n{"name": "search", "arguments": {"query": "director of The Preacher\'s Wife country"}}\n<|im_end|>' +) -def test_process_tool_calls_qwen25_chat_completions(): - """ - Validate that process_tool_calls extracts tool calls from assistant text - using the qwen25 parser and returns ChatCompletionMessageToolCall entries - when use_responses=False. - """ - text = ( - '\nOkay, so the user is asking whether the director of "Scary Movie" and the director of "The Preacher\'s Wife" are from the same country. Let me think about how to approach this.\n\n' - 'First, I need to confirm the countries of these directors. I remember that "Scary Movie" is directed by Joe Anderson, and "The Preacher\'s Wife" is directed by David Fincher. Wait, but actually, "The Preacher\'s Wife" is directed by Christopher Nolan and David Fincher? No, I think I mixed up. Let me check my memory. \n\n' - "Wait, no. The Preacher's Wife is a movie directed by Christopher Nolan. 
And David Fincher directed The Preacher. So the directors are different. Then the user is asking if they are both from the same country. The answer would be no, because the two directors are from different countries. \n\n" - 'But maybe I should verify this to be sure. Since I can use the web search function, I should use the "access" tool to get the URLs of the directors\' websites to confirm. So first, I\'ll search for "director of Scary Movie country" and "director of The Preacher\'s Wife country" to get precise data. Then, analyze the results to see if they have the same country affiliations.\n\n\n' - '\n{"name": "search", "arguments": {"query": "director of Scary Movie country"}}\n\n\n' - '\n{"name": "search", "arguments": {"query": "director of The Preacher\'s Wife country"}}\n<|im_end|>' - ) - - tool_call_parser = "qwen25" - reasoning_parser = "qwen3" - finish_reason = "tool_calls" - use_responses = False +TEXT_WITH_TOOL_CALL_IN_THINKING = ( + '\nOkay, so the user is asking whether the director of "Scary Movie" and the director of "The Preacher\'s Wife" are from the same country. Let me think about how to approach this.\n\n' + 'First, I need to confirm the countries of these directors. I remember that "Scary Movie" is directed by Joe Anderson, and "The Preacher\'s Wife" is directed by David Fincher. Wait, but actually, "The Preacher\'s Wife" is directed by Christopher Nolan and David Fincher? No, I think I mixed up. Let me check my memory. \n\n' + 'Wait, no. The Preacher\'s Wife is a movie directed by Christopher Nolan. And David Fincher directed The Preacher. \n{"name": "search", "arguments": {"query": "aaaa"}}\n\n\n So the directors are different. Then the user is asking if they are both from the same country. The answer would be no, because the two directors are from different countries. \n\n' + 'But maybe I should verify this to be sure. 
Since I can use the web search function, I should use the "access" tool to get the URLs of the directors\' websites to confirm. So first, I\'ll search for "director of Scary Movie country" and "director of The Preacher\'s Wife country" to get precise data. Then, analyze the results to see if they have the same country affiliations.\n\n\n' + '\n{"name": "search", "arguments": {"query": "director of Scary Movie country"}}\n\n\n' + '\n{"name": "search", "arguments": {"query": "director of The Preacher\'s Wife country"}}\n<|im_end|>' +) - tool_calls, new_text, new_finish_reason = process_tool_calls( - text=text, - tools=tools, - tool_call_parser=tool_call_parser, - reasoning_parser=reasoning_parser, - finish_reason=finish_reason, - use_responses=use_responses, - ) - # Assertions +def _assert_tool_calls(tool_calls, new_text: str, new_finish_reason: str) -> None: assert new_finish_reason == "tool_calls" assert tool_calls is not None, "Tool calls should be detected and returned" assert len(tool_calls) == 2, "Two tool calls should be parsed from the text" - # Validate each parsed call is ChatCompletionMessageToolCall - from openai.types.chat import ChatCompletionMessageToolCall - assert isinstance(tool_calls[0], ChatCompletionMessageToolCall) assert isinstance(tool_calls[1], ChatCompletionMessageToolCall) assert tool_calls[0].type == "function" @@ -105,58 +87,183 @@ def test_process_tool_calls_qwen25_chat_completions(): tool_calls[1].function.arguments == '{"query": "director of The Preacher\'s Wife country"}' ) - # Ensure the returned text no longer contains raw blocks + + +def _run_process_tool_calls(text: str): + return parser_module.process_tool_calls( + text=text, + tools=TOOLS, + tool_call_parser="qwen25", + reasoning_parser="qwen3", + finish_reason="tool_calls", + use_responses=False, + tokenizer=object(), + ) + + +@pytest.mark.sglang +def test_process_tool_calls_qwen25_chat_completions_sglang(): + pytest.importorskip( + 
"sglang.srt.function_call.function_call_parser", + reason="sglang is required for sglang parser tests", + ) + pytest.importorskip( + "sglang.srt.parser.reasoning_parser", + reason="sglang is required for sglang parser tests", + ) + + tool_calls, new_text, new_finish_reason = _run_process_tool_calls(TEXT) + + _assert_tool_calls(tool_calls, new_text, new_finish_reason) assert "" not in new_text -def test_process_tool_calls_qwen25_chat_completions_with_tool_call_in_thinking(): - """ - Validate that process_tool_calls extracts tool calls from assistant text - using the qwen25 parser and returns ChatCompletionMessageToolCall entries - when use_responses=False. - """ - text = ( - '\nOkay, so the user is asking whether the director of "Scary Movie" and the director of "The Preacher\'s Wife" are from the same country. Let me think about how to approach this.\n\n' - 'First, I need to confirm the countries of these directors. I remember that "Scary Movie" is directed by Joe Anderson, and "The Preacher\'s Wife" is directed by David Fincher. Wait, but actually, "The Preacher\'s Wife" is directed by Christopher Nolan and David Fincher? No, I think I mixed up. Let me check my memory. \n\n' - 'Wait, no. The Preacher\'s Wife is a movie directed by Christopher Nolan. And David Fincher directed The Preacher. \n{"name": "search", "arguments": {"query": "aaaa"}}\n\n\n So the directors are different. Then the user is asking if they are both from the same country. The answer would be no, because the two directors are from different countries. \n\n' - 'But maybe I should verify this to be sure. Since I can use the web search function, I should use the "access" tool to get the URLs of the directors\' websites to confirm. So first, I\'ll search for "director of Scary Movie country" and "director of The Preacher\'s Wife country" to get precise data. 
Then, analyze the results to see if they have the same country affiliations.\n\n\n' - '\n{"name": "search", "arguments": {"query": "director of Scary Movie country"}}\n\n\n' - '\n{"name": "search", "arguments": {"query": "director of The Preacher\'s Wife country"}}\n<|im_end|>' +@pytest.mark.sglang +def test_process_tool_calls_qwen25_chat_completions_with_tool_call_in_thinking_sglang(): + pytest.importorskip( + "sglang.srt.function_call.function_call_parser", + reason="sglang is required for sglang parser tests", + ) + pytest.importorskip( + "sglang.srt.parser.reasoning_parser", + reason="sglang is required for sglang parser tests", + ) + + tool_calls, new_text, new_finish_reason = _run_process_tool_calls( + TEXT_WITH_TOOL_CALL_IN_THINKING ) - tool_call_parser = "qwen25" - reasoning_parser = "qwen3" - finish_reason = "tool_calls" - use_responses = False + _assert_tool_calls(tool_calls, new_text, new_finish_reason) + assert "" in new_text - tool_calls, new_text, new_finish_reason = process_tool_calls( - text=text, - tools=tools, - tool_call_parser=tool_call_parser, - reasoning_parser=reasoning_parser, - finish_reason=finish_reason, - use_responses=use_responses, + +def _raise_module_not_found(*args, **kwargs): + raise ModuleNotFoundError + + +class FakeReasoningParser: + start_token = "" + end_token = "" + + def __init__(self, tokenizer, *args, **kwargs): + pass + + +def _patch_vllm_parsers(monkeypatch): + tool_parsers = pytest.importorskip( + "vllm.tool_parsers", + reason="vllm is required for vllm parser tests", + ) + reasoning_mod = pytest.importorskip( + "vllm.reasoning", + reason="vllm is required for vllm parser tests", ) + monkeypatch.setattr( + parser_module, "_process_tool_calls_sglang", _raise_module_not_found + ) + monkeypatch.setattr( + reasoning_mod.ReasoningParserManager, + "get_reasoning_parser", + staticmethod(lambda name: FakeReasoningParser), + ) + return tool_parsers - # Assertions - assert new_finish_reason == "tool_calls" - assert tool_calls is 
not None, "Tool calls should be detected and returned" - assert len(tool_calls) == 2, "Two tool calls should be parsed from the text" - # Validate each parsed call is ChatCompletionMessageToolCall - from openai.types.chat import ChatCompletionMessageToolCall - assert isinstance(tool_calls[0], ChatCompletionMessageToolCall) - assert isinstance(tool_calls[1], ChatCompletionMessageToolCall) - assert tool_calls[0].type == "function" - assert tool_calls[0].function.name == "search" - assert tool_calls[1].function.name == "search" - assert ( - tool_calls[0].function.arguments - == '{"query": "director of Scary Movie country"}' +@pytest.mark.vllm +def test_process_tool_calls_qwen25_chat_completions_vllm( + monkeypatch: pytest.MonkeyPatch, +): + tool_parsers = _patch_vllm_parsers(monkeypatch) + + class FakeParser: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def extract_tool_calls(self, content_text, request): + assert request.skip_special_tokens is True + return SimpleNamespace( + tools_called=True, + content=content_text.replace( + '\n{"name": "search", "arguments": {"query": "director of Scary Movie country"}}\n\n\n', + "", + ).replace( + '\n{"name": "search", "arguments": {"query": "director of The Preacher\'s Wife country"}}\n<|im_end|>', + "", + ), + tool_calls=[ + SimpleNamespace( + function=SimpleNamespace( + name="search", + arguments='{"query": "director of Scary Movie country"}', + ) + ), + SimpleNamespace( + function=SimpleNamespace( + name="search", + arguments='{"query": "director of The Preacher\'s Wife country"}', + ) + ), + ], + ) + + monkeypatch.setattr( + tool_parsers.ToolParserManager, + "get_tool_parser", + staticmethod(lambda name: FakeParser), ) - assert ( - tool_calls[1].function.arguments - == '{"query": "director of The Preacher\'s Wife country"}' + + tool_calls, new_text, new_finish_reason = _run_process_tool_calls(TEXT) + + _assert_tool_calls(tool_calls, new_text, new_finish_reason) + assert "" not in new_text + + 
+@pytest.mark.vllm +def test_process_tool_calls_qwen25_chat_completions_with_tool_call_in_thinking_vllm( + monkeypatch: pytest.MonkeyPatch, +): + tool_parsers = _patch_vllm_parsers(monkeypatch) + + class FakeParser: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def extract_tool_calls(self, content_text, request): + assert request.skip_special_tokens is True + return SimpleNamespace( + tools_called=True, + content=content_text.replace( + '\n{"name": "search", "arguments": {"query": "director of Scary Movie country"}}\n\n\n', + "", + ).replace( + '\n{"name": "search", "arguments": {"query": "director of The Preacher\'s Wife country"}}\n<|im_end|>', + "", + ), + tool_calls=[ + SimpleNamespace( + function=SimpleNamespace( + name="search", + arguments='{"query": "director of Scary Movie country"}', + ) + ), + SimpleNamespace( + function=SimpleNamespace( + name="search", + arguments='{"query": "director of The Preacher\'s Wife country"}', + ) + ), + ], + ) + + monkeypatch.setattr( + tool_parsers.ToolParserManager, + "get_tool_parser", + staticmethod(lambda name: FakeParser), ) - # Ensure the returned text no longer contains raw blocks + + tool_calls, new_text, new_finish_reason = _run_process_tool_calls( + TEXT_WITH_TOOL_CALL_IN_THINKING + ) + + _assert_tool_calls(tool_calls, new_text, new_finish_reason) assert "" in new_text