diff --git a/miles/backends/megatron_utils/megatron_to_hf/__init__.py b/miles/backends/megatron_utils/megatron_to_hf/__init__.py index d0c8de5839..b0696d2fe3 100644 --- a/miles/backends/megatron_utils/megatron_to_hf/__init__.py +++ b/miles/backends/megatron_utils/megatron_to_hf/__init__.py @@ -8,6 +8,7 @@ from .qwen3_5 import convert_qwen3_5_to_hf from .qwen3_next import convert_qwen3_next_to_hf from .qwen3moe import convert_qwen3moe_to_hf +from .xllm import convert_xllm_to_hf # TODO unify w/ `convert_to_hf` @@ -32,7 +33,9 @@ def convert_to_hf(args, model_name, name, param, quantization_config=None): # TODO optimize code details def _convert_to_hf_core(args, model_name, name, param): - if "glm4moelite" in model_name or "deepseekv3" in model_name: + if "xllm" in model_name: + converted_named_tensors = convert_xllm_to_hf(args, name, param) + elif "glm4moelite" in model_name or "deepseekv3" in model_name: converted_named_tensors = convert_deepseekv3_to_hf(args, name, param) elif "glm4moe" in model_name: converted_named_tensors = convert_glm4moe_to_hf(args, name, param) diff --git a/miles/backends/megatron_utils/megatron_to_hf/xllm.py b/miles/backends/megatron_utils/megatron_to_hf/xllm.py new file mode 100644 index 0000000000..76c340977f --- /dev/null +++ b/miles/backends/megatron_utils/megatron_to_hf/xllm.py @@ -0,0 +1,87 @@ +import re + +import torch + + +def convert_xllm_to_hf(args, name, param): + """Convert Megatron parameter names/tensors to HuggingFace xLLM format.""" + if name == "module.module.embedding.word_embeddings.weight": + return [("model.embed_tokens.weight", param)] + if name == "module.module.output_layer.weight": + return [("lm_head.weight", param)] + if name == "module.module.decoder.final_layernorm.weight": + return [("model.norm.weight", param)] + + try: + head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads + except AttributeError: + head_dim = args.hidden_size // args.num_attention_heads + value_num_per_group = args.num_attention_heads // args.num_query_groups + + decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)" + match = re.match(decoder_layers_pattern, name) + if match: + layer_idx, rest = match.groups() + + expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)" + match = re.match(expert_pattern, rest) + if match: + rest, expert_idx = match.groups() + if rest == "linear_fc1": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight", gate_weight), + (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", up_weight), + ] + if rest == "linear_fc2": + return [(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", param)] + raise ValueError(f"Unknown expert parameter name: {name}") + + shared_expert_pattern = r"mlp.shared_experts\.(.+)" + match = re.match(shared_expert_pattern, rest) + if match: + rest = match.groups()[0] + if rest == "linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"model.layers.{layer_idx}.mlp.shared_experts.gate_proj.weight", gate_weight), + (f"model.layers.{layer_idx}.mlp.shared_experts.up_proj.weight", up_weight), + ] + if rest == "linear_fc2.weight": + return [(f"model.layers.{layer_idx}.mlp.shared_experts.down_proj.weight", param)] + raise ValueError(f"Unknown shared expert parameter name: {name}") + + if rest == "self_attention.linear_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)] + if rest == "self_attention.linear_qkv.weight": + param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size) + q_param, k_param, v_param = torch.split(param, [value_num_per_group, 1, 1], dim=1) + q_param = q_param.reshape(-1, args.hidden_size) + k_param = k_param.reshape(-1, args.hidden_size) + v_param = v_param.reshape(-1, args.hidden_size) + return [ + (f"model.layers.{layer_idx}.self_attn.q_proj.weight", q_param), + (f"model.layers.{layer_idx}.self_attn.k_proj.weight", k_param), + (f"model.layers.{layer_idx}.self_attn.v_proj.weight", v_param), + ] + + if rest == "mlp.linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight), + (f"model.layers.{layer_idx}.mlp.up_proj.weight", up_weight), + ] + if rest == "mlp.linear_fc2.weight": + return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)] + + if rest in ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"): + return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)] + if rest in ("mlp.linear_fc1.layer_norm_weight", "pre_mlp_layernorm.weight"): + return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] + + if rest == "mlp.router.weight": + return [(f"model.layers.{layer_idx}.mlp.gate.weight", param)] + if rest == "mlp.router.expert_bias": + return [(f"model.layers.{layer_idx}.mlp.gate.bias", param)] + + raise ValueError(f"Unknown parameter name: {name}") diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index a9be3696d7..56ec82ff51 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -429,6 +429,12 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p if batch["multimodal_train_inputs"] is not None: forward_kwargs.update(batch["multimodal_train_inputs"]) + # Keep last-stage logits in model precision. Float16Module defaults + # to fp32_output=True, which upcasts the full vocab-sharded logits + # tensor before the PPO loss. For 375B runs with 250k vocab and + # packed long rollouts this can require several extra GiB and OOM. + forward_kwargs["fp32_output"] = not (args.fp16 or args.bf16) + output_tensor = model(**forward_kwargs) for m, old_stage in zip(all_replay_managers, old_stages, strict=True): @@ -471,6 +477,11 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p if valid_step: # Update parameters. + # Long packed RL batches leave large inactive activation/logprob blocks + # in the CUDA caching allocator. TransformerEngine's fused Adam lazily + # creates optimizer state on first step, so release inactive blocks here + # before tiny state allocations fail with reserved-but-free memory. + clear_memory() update_successful, grad_norm, num_zeros_in_grad = optimizer.step() # Update learning rate. @@ -695,6 +706,11 @@ def save( if should_disable_forward_pre_hook(args): disable_forward_pre_hook(model) + # Checkpointing optimizer state can transiently allocate fp32 copies inside + # TransformerEngine FusedAdam.state_dict(). Release inactive train/logprob + # blocks first so final saves do not OOM after otherwise healthy steps. + clear_memory() + if is_lora_model(model): save_checkpoint_with_lora(iteration, model, optimizer, opt_param_scheduler) else: @@ -709,6 +725,8 @@ def save( preprocess_common_state_dict_fn=None, ) + clear_memory() + if hashes is not None: save_model_hashes(args, model, iteration, hashes) if should_disable_forward_pre_hook(args): diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index 45adc86792..617068333e 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -166,13 +166,21 @@ def model_provider( else: # Define the decoder layer spec if use_te: - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=args.num_experts, - moe_grouped_gemm=args.moe_grouped_gemm, - qk_layernorm=args.qk_layernorm, - multi_latent_attention=args.multi_latent_attention, - moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, - ) + te_spec_kwargs = { + "num_experts": args.num_experts, + "moe_grouped_gemm": args.moe_grouped_gemm, + "qk_layernorm": args.qk_layernorm, + "multi_latent_attention": args.multi_latent_attention, + "moe_use_legacy_grouped_gemm": args.moe_use_legacy_grouped_gemm, + } + te_spec_params = inspect.signature(get_gpt_layer_with_transformer_engine_spec).parameters + if "fuse_layernorm_and_linear" in te_spec_params: + te_spec_kwargs["fuse_layernorm_and_linear"] = getattr(args, "layernorm_num_groups", 1) == 1 + if "remap_unfused_layernorm_checkpoint_keys" in te_spec_params: + te_spec_kwargs["remap_unfused_layernorm_checkpoint_keys"] = ( + getattr(args, "layernorm_num_groups", 1) == 1 + ) + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(**te_spec_kwargs) else: transformer_layer_spec = get_gpt_layer_local_spec( num_experts=args.num_experts, diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index e0eccde8b4..bd0386e9f4 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -48,7 +48,9 @@ def get_responses( Args: logits: Model outputs with shape `[1, T, V]` (policy) or `[1, T, 1]` - (value). Must be float32. + (value). Policy logits may stay in model precision; logprob code + casts response chunks to fp32 after slicing to avoid full-logit + fp32 materialization. args: Configuration containing `rollout_temperature` for scaling. unconcat_tokens: List of token tensors (prompt+response) per sample. total_lengths: Total sequence lengths (prompt+response) per sample. @@ -62,7 +64,6 @@ def get_responses( parallel_state = get_parallel_state() qkv_format = args.qkv_format - assert logits.dtype == torch.float32, f"{logits.dtype}" assert len(logits.shape) == 3, f"{logits.shape}" if qkv_format == "thd": diff --git a/miles/utils/ppo_utils.py b/miles/utils/ppo_utils.py index 18fb2681a2..634c9d430d 100644 --- a/miles/utils/ppo_utils.py +++ b/miles/utils/ppo_utils.py @@ -649,7 +649,11 @@ def calculate_log_probs_and_entropy( logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1, true_on_policy: bool = False ): if true_on_policy: - return _calculate_log_probs_and_entropy_true_on_policy(logits, tokens, with_entropy=with_entropy) + return _calculate_log_probs_and_entropy_true_on_policy( + logits.float() if logits.dtype != torch.float32 else logits, + tokens, + with_entropy=with_entropy, + ) logits = logits.contiguous() # TODO: not sure why we need to clone the logits here. @@ -663,16 +667,22 @@ def calculate_log_probs_and_entropy( logits_chunks = logits.chunk(num_chunks, dim=0) log_probs = [] for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + if logits_chunk.dtype != torch.float32: + logits_chunk = logits_chunk.float() log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group) log_probs.append(log_prob) log_prob = torch.cat(log_probs, dim=0) if with_entropy: entropys = [] for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + if logits_chunk.dtype != torch.float32: + logits_chunk = logits_chunk.float() entropy = compute_entropy_from_logits(logits_chunk.clone(), tp_group) entropys.append(entropy) entropy = torch.cat(entropys, dim=0) else: + if logits.dtype != torch.float32: + logits = logits.float() log_prob = compute_log_probs(logits.clone(), tokens, tp_group) if with_entropy: entropy = compute_entropy_from_logits(logits.clone(), tp_group) diff --git a/miles_plugins/mbridge/__init__.py b/miles_plugins/mbridge/__init__.py index 85e79ec0ec..f847f29b05 100644 --- a/miles_plugins/mbridge/__init__.py +++ b/miles_plugins/mbridge/__init__.py @@ -5,6 +5,7 @@ from .mimo import MimoBridge from .qwen3_5 import Qwen3_5Bridge from .qwen3_next import Qwen3NextBridge +from .xllm import XllmBridge __all__ = [ "GLM4Bridge", @@ -14,4 +15,5 @@ "Qwen3_5Bridge", "MimoBridge", "DeepseekV32Bridge", + "XllmBridge", ] diff --git a/miles_plugins/mbridge/xllm.py b/miles_plugins/mbridge/xllm.py new file mode 100644 index 0000000000..8859e94013 --- /dev/null +++ b/miles_plugins/mbridge/xllm.py @@ -0,0 +1,73 @@ +from mbridge.core import register_model +from mbridge.models import Qwen2Bridge, Qwen2MoEBridge + + +@register_model("xllm") +class XllmBridge(Qwen2MoEBridge): + """Bridge implementation for dense and MoE xLLM checkpoints.""" + + _MLP_MAPPING = { + **Qwen2MoEBridge._MLP_MAPPING, + **Qwen2Bridge._MLP_MAPPING, + "mlp.router.expert_bias": ["model.layers.{layer_number}.mlp.gate.bias"], + "shared_experts.linear_fc1.weight": [ + "model.layers.{layer_number}.mlp.shared_experts.gate_proj.weight", + "model.layers.{layer_number}.mlp.shared_experts.up_proj.weight", + ], + "shared_experts.linear_fc2.weight": [ + "model.layers.{layer_number}.mlp.shared_experts.down_proj.weight", + ], + } + + def _has_moe(self): + return (getattr(self.hf_config, "num_experts", 0) or 0) > 0 + + def _build_config(self): + config_kwargs = dict( + use_cpu_initialization=False, + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + qk_layernorm=False, + add_qkv_bias=False, + add_bias_linear=False, + ) + + if self._has_moe(): + config_kwargs.update( + moe_ffn_hidden_size=self.hf_config.moe_intermediate_size, + moe_router_bias_update_rate=0, + moe_router_topk=self.hf_config.num_experts_per_tok, + num_moe_experts=self.hf_config.num_experts, + moe_router_load_balancing_type="none", + moe_grouped_gemm=True, + moe_router_score_function="sigmoid", + moe_router_enable_expert_bias=True, + moe_router_pre_softmax=True, + ) + + return self._build_base_config(**config_kwargs) + + def _weight_name_mapping_mcore_to_hf(self, mcore_weights_name: str) -> list[str]: + assert "_extra_state" not in mcore_weights_name, "extra_state should not be loaded" + direct_name_mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + if mcore_weights_name in direct_name_mapping: + return [direct_name_mapping[mcore_weights_name]] + + layer_prefix = "decoder.layers." + if mcore_weights_name.startswith(layer_prefix): + layer_idx, _, rest = mcore_weights_name[len(layer_prefix) :].partition(".") + if rest == "input_layernorm.weight": + return [f"model.layers.{layer_idx}.input_layernorm.weight"] + if rest == "pre_mlp_layernorm.weight": + return [f"model.layers.{layer_idx}.post_attention_layernorm.weight"] + + if "self_attention" in mcore_weights_name: + return self._weight_name_mapping_attention(mcore_weights_name) + if "mlp" in mcore_weights_name: + return self._weight_name_mapping_mlp(mcore_weights_name) + raise NotImplementedError(f"Unsupported parameter name: {mcore_weights_name}") diff --git a/scripts/models/xllm-8B.sh b/scripts/models/xllm-8B.sh new file mode 100644 index 0000000000..ba59296724 --- /dev/null +++ b/scripts/models/xllm-8B.sh @@ -0,0 +1,20 @@ +# xLLM 8B dense GQA model arguments. +MODEL_ARGS=( + --swiglu + --num-layers 36 + --hidden-size 4096 + --ffn-hidden-size 12288 + --num-attention-heads 32 + --group-query-attention + --num-query-groups 8 + --kv-channels 128 + --disable-bias-linear + --normalization RMSNorm + --norm-epsilon 1e-6 + --layernorm-num-groups 4 + --position-embedding-type rope + --rotary-percent 1.0 + --rotary-base 10000000 + --untie-embeddings-and-output-weights + --vocab-size 250624 +)