Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion miles/backends/megatron_utils/megatron_to_hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .qwen3_5 import convert_qwen3_5_to_hf
from .qwen3_next import convert_qwen3_next_to_hf
from .qwen3moe import convert_qwen3moe_to_hf
from .xllm import convert_xllm_to_hf


# TODO unify w/ `convert_to_hf`
Expand All @@ -32,7 +33,9 @@ def convert_to_hf(args, model_name, name, param, quantization_config=None):

# TODO optimize code details
def _convert_to_hf_core(args, model_name, name, param):
if "glm4moelite" in model_name or "deepseekv3" in model_name:
if "xllm" in model_name:
converted_named_tensors = convert_xllm_to_hf(args, name, param)
elif "glm4moelite" in model_name or "deepseekv3" in model_name:
converted_named_tensors = convert_deepseekv3_to_hf(args, name, param)
elif "glm4moe" in model_name:
converted_named_tensors = convert_glm4moe_to_hf(args, name, param)
Expand Down
87 changes: 87 additions & 0 deletions miles/backends/megatron_utils/megatron_to_hf/xllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import re

import torch


def convert_xllm_to_hf(args, name, param):
    """Convert one Megatron parameter into its HuggingFace xLLM counterpart(s).

    Args:
        args: Namespace with the model hyper-parameters. Only the fused QKV
            weight needs them: ``hidden_size``, ``num_attention_heads`` and
            ``num_query_groups`` are required there, while ``kv_channels`` is
            optional and, when set, replaces the derived per-head dimension.
        name: Fully qualified Megatron parameter name (``module.module.…``).
        param: Tensor holding the parameter values.

    Returns:
        A list of ``(hf_name, tensor)`` pairs. Fused tensors (QKV, gate/up)
        are split into several entries; everything else yields one entry.

    Raises:
        ValueError: If ``name`` is not a recognised xLLM parameter name.
    """
    top_level_names = {
        "module.module.embedding.word_embeddings.weight": "model.embed_tokens.weight",
        "module.module.output_layer.weight": "lm_head.weight",
        "module.module.decoder.final_layernorm.weight": "model.norm.weight",
    }
    if name in top_level_names:
        return [(top_level_names[name], param)]

    layer_match = re.match(r"module\.module\.decoder\.layers\.(\d+)\.(.+)", name)
    if layer_match is None:
        raise ValueError(f"Unknown parameter name: {name}")
    layer_idx, rest = layer_match.groups()

    # Routed experts: one weight per expert, the expert index is the numeric
    # suffix of "weight". Every dot is escaped so that only a literal
    # "mlp.experts." prefix is accepted.
    expert_match = re.match(r"mlp\.experts\.(.+)\.weight(\d+)", rest)
    if expert_match:
        linear_name, expert_idx = expert_match.groups()
        if linear_name == "linear_fc1":
            # linear_fc1 stacks gate and up projections along dim 0.
            gate_weight, up_weight = param.chunk(2, dim=0)
            return [
                (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight", gate_weight),
                (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", up_weight),
            ]
        if linear_name == "linear_fc2":
            return [(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", param)]
        raise ValueError(f"Unknown expert parameter name: {name}")

    shared_match = re.match(r"mlp\.shared_experts\.(.+)", rest)
    if shared_match:
        shared_rest = shared_match.group(1)
        if shared_rest == "linear_fc1.weight":
            gate_weight, up_weight = param.chunk(2, dim=0)
            return [
                (f"model.layers.{layer_idx}.mlp.shared_experts.gate_proj.weight", gate_weight),
                (f"model.layers.{layer_idx}.mlp.shared_experts.up_proj.weight", up_weight),
            ]
        if shared_rest == "linear_fc2.weight":
            return [(f"model.layers.{layer_idx}.mlp.shared_experts.down_proj.weight", param)]
        raise ValueError(f"Unknown shared expert parameter name: {name}")

    if rest == "self_attention.linear_proj.weight":
        return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)]
    if rest == "self_attention.linear_qkv.weight":
        # Attention geometry is only needed here, so it is derived lazily;
        # other parameters can be converted without these attributes.
        kv_channels = getattr(args, "kv_channels", None)
        head_dim = kv_channels if kv_channels is not None else args.hidden_size // args.num_attention_heads
        q_heads_per_group = args.num_attention_heads // args.num_query_groups

        # View as (query group, head slot, head_dim, hidden); each group holds
        # its query heads followed by one key head and one value head.
        grouped = param.view(args.num_query_groups, -1, head_dim, args.hidden_size)
        q_param, k_param, v_param = torch.split(grouped, [q_heads_per_group, 1, 1], dim=1)
        return [
            (f"model.layers.{layer_idx}.self_attn.q_proj.weight", q_param.reshape(-1, args.hidden_size)),
            (f"model.layers.{layer_idx}.self_attn.k_proj.weight", k_param.reshape(-1, args.hidden_size)),
            (f"model.layers.{layer_idx}.self_attn.v_proj.weight", v_param.reshape(-1, args.hidden_size)),
        ]

    if rest == "mlp.linear_fc1.weight":
        gate_weight, up_weight = param.chunk(2, dim=0)
        return [
            (f"model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight),
            (f"model.layers.{layer_idx}.mlp.up_proj.weight", up_weight),
        ]
    if rest == "mlp.linear_fc2.weight":
        return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)]

    # The layernorm weight appears under either name (attached to the
    # following linear layer, or as a standalone module); both map to the
    # same HF tensor.
    if rest in ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"):
        return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)]
    if rest in ("mlp.linear_fc1.layer_norm_weight", "pre_mlp_layernorm.weight"):
        return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)]

    if rest == "mlp.router.weight":
        return [(f"model.layers.{layer_idx}.mlp.gate.weight", param)]
    if rest == "mlp.router.expert_bias":
        return [(f"model.layers.{layer_idx}.mlp.gate.bias", param)]

    raise ValueError(f"Unknown parameter name: {name}")
18 changes: 18 additions & 0 deletions miles/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,12 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
if batch["multimodal_train_inputs"] is not None:
forward_kwargs.update(batch["multimodal_train_inputs"])

# Keep last-stage logits in model precision. Float16Module defaults
# to fp32_output=True, which upcasts the full vocab-sharded logits
# tensor before the PPO loss. For 375B runs with 250k vocab and
# packed long rollouts this can require several extra GiB and OOM.
forward_kwargs["fp32_output"] = not (args.fp16 or args.bf16)

output_tensor = model(**forward_kwargs)

for m, old_stage in zip(all_replay_managers, old_stages, strict=True):
Expand Down Expand Up @@ -471,6 +477,11 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p

if valid_step:
# Update parameters.
# Long packed RL batches leave large inactive activation/logprob blocks
# in the CUDA caching allocator. TransformerEngine's fused Adam lazily
# creates optimizer state on first step, so release inactive blocks here
# before tiny state allocations fail with reserved-but-free memory.
clear_memory()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()

# Update learning rate.
Expand Down Expand Up @@ -695,6 +706,11 @@ def save(
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)

# Checkpointing optimizer state can transiently allocate fp32 copies inside
# TransformerEngine FusedAdam.state_dict(). Release inactive train/logprob
# blocks first so final saves do not OOM after otherwise healthy steps.
clear_memory()

if is_lora_model(model):
save_checkpoint_with_lora(iteration, model, optimizer, opt_param_scheduler)
else:
Expand All @@ -709,6 +725,8 @@ def save(
preprocess_common_state_dict_fn=None,
)

clear_memory()

if hashes is not None:
save_model_hashes(args, model, iteration, hashes)
if should_disable_forward_pre_hook(args):
Expand Down
22 changes: 15 additions & 7 deletions miles/backends/megatron_utils/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,21 @@ def model_provider(
else:
# Define the decoder layer spec
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
num_experts=args.num_experts,
moe_grouped_gemm=args.moe_grouped_gemm,
qk_layernorm=args.qk_layernorm,
multi_latent_attention=args.multi_latent_attention,
moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
)
te_spec_kwargs = {
"num_experts": args.num_experts,
"moe_grouped_gemm": args.moe_grouped_gemm,
"qk_layernorm": args.qk_layernorm,
"multi_latent_attention": args.multi_latent_attention,
"moe_use_legacy_grouped_gemm": args.moe_use_legacy_grouped_gemm,
}
te_spec_params = inspect.signature(get_gpt_layer_with_transformer_engine_spec).parameters
if "fuse_layernorm_and_linear" in te_spec_params:
te_spec_kwargs["fuse_layernorm_and_linear"] = getattr(args, "layernorm_num_groups", 1) == 1
if "remap_unfused_layernorm_checkpoint_keys" in te_spec_params:
te_spec_kwargs["remap_unfused_layernorm_checkpoint_keys"] = (
getattr(args, "layernorm_num_groups", 1) == 1
)
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(**te_spec_kwargs)
else:
transformer_layer_spec = get_gpt_layer_local_spec(
num_experts=args.num_experts,
Expand Down
5 changes: 3 additions & 2 deletions miles/backends/training_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def get_responses(

Args:
logits: Model outputs with shape `[1, T, V]` (policy) or `[1, T, 1]`
(value). Must be float32.
(value). Policy logits may stay in model precision; logprob code
casts response chunks to fp32 after slicing to avoid full-logit
fp32 materialization.
args: Configuration containing `rollout_temperature` for scaling.
unconcat_tokens: List of token tensors (prompt+response) per sample.
total_lengths: Total sequence lengths (prompt+response) per sample.
Expand All @@ -62,7 +64,6 @@ def get_responses(
parallel_state = get_parallel_state()
qkv_format = args.qkv_format

assert logits.dtype == torch.float32, f"{logits.dtype}"
assert len(logits.shape) == 3, f"{logits.shape}"

if qkv_format == "thd":
Expand Down
12 changes: 11 additions & 1 deletion miles/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,11 @@ def calculate_log_probs_and_entropy(
logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1, true_on_policy: bool = False
):
if true_on_policy:
return _calculate_log_probs_and_entropy_true_on_policy(logits, tokens, with_entropy=with_entropy)
return _calculate_log_probs_and_entropy_true_on_policy(
logits.float() if logits.dtype != torch.float32 else logits,
tokens,
with_entropy=with_entropy,
)

logits = logits.contiguous()
# TODO: not sure why we need to clone the logits here.
Expand All @@ -663,16 +667,22 @@ def calculate_log_probs_and_entropy(
logits_chunks = logits.chunk(num_chunks, dim=0)
log_probs = []
for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
if logits_chunk.dtype != torch.float32:
logits_chunk = logits_chunk.float()
log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group)
log_probs.append(log_prob)
log_prob = torch.cat(log_probs, dim=0)
if with_entropy:
entropys = []
for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
if logits_chunk.dtype != torch.float32:
logits_chunk = logits_chunk.float()
entropy = compute_entropy_from_logits(logits_chunk.clone(), tp_group)
entropys.append(entropy)
entropy = torch.cat(entropys, dim=0)
else:
if logits.dtype != torch.float32:
logits = logits.float()
log_prob = compute_log_probs(logits.clone(), tokens, tp_group)
if with_entropy:
entropy = compute_entropy_from_logits(logits.clone(), tp_group)
Expand Down
2 changes: 2 additions & 0 deletions miles_plugins/mbridge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .mimo import MimoBridge
from .qwen3_5 import Qwen3_5Bridge
from .qwen3_next import Qwen3NextBridge
from .xllm import XllmBridge

__all__ = [
"GLM4Bridge",
Expand All @@ -14,4 +15,5 @@
"Qwen3_5Bridge",
"MimoBridge",
"DeepseekV32Bridge",
"XllmBridge",
]
73 changes: 73 additions & 0 deletions miles_plugins/mbridge/xllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from mbridge.core import register_model
from mbridge.models import Qwen2Bridge, Qwen2MoEBridge


@register_model("xllm")
class XllmBridge(Qwen2MoEBridge):
    """mbridge bridge covering both the dense and the MoE flavours of xLLM."""

    # Start from the MoE table, let the dense Qwen2 entries take precedence on
    # shared keys, then add the xLLM-specific router bias and shared experts.
    _MLP_MAPPING = {
        **Qwen2MoEBridge._MLP_MAPPING,
        **Qwen2Bridge._MLP_MAPPING,
        "mlp.router.expert_bias": ["model.layers.{layer_number}.mlp.gate.bias"],
        "shared_experts.linear_fc1.weight": [
            "model.layers.{layer_number}.mlp.shared_experts.gate_proj.weight",
            "model.layers.{layer_number}.mlp.shared_experts.up_proj.weight",
        ],
        "shared_experts.linear_fc2.weight": [
            "model.layers.{layer_number}.mlp.shared_experts.down_proj.weight",
        ],
    }

    def _has_moe(self):
        """Return True when the HF config declares a positive expert count."""
        expert_count = getattr(self.hf_config, "num_experts", 0)
        if not expert_count:
            # Missing, None or zero all mean "dense model".
            return False
        return expert_count > 0

    def _build_config(self):
        """Assemble the config overrides and delegate to ``_build_base_config``."""
        overrides = {
            "use_cpu_initialization": False,
            "persist_layer_norm": True,
            "bias_activation_fusion": True,
            "bias_dropout_fusion": True,
            "qk_layernorm": False,
            "add_qkv_bias": False,
            "add_bias_linear": False,
        }

        if self._has_moe():
            # MoE-only settings, read from the HF config.
            overrides["moe_ffn_hidden_size"] = self.hf_config.moe_intermediate_size
            overrides["moe_router_bias_update_rate"] = 0
            overrides["moe_router_topk"] = self.hf_config.num_experts_per_tok
            overrides["num_moe_experts"] = self.hf_config.num_experts
            overrides["moe_router_load_balancing_type"] = "none"
            overrides["moe_grouped_gemm"] = True
            overrides["moe_router_score_function"] = "sigmoid"
            overrides["moe_router_enable_expert_bias"] = True
            overrides["moe_router_pre_softmax"] = True

        return self._build_base_config(**overrides)

    def _weight_name_mapping_mcore_to_hf(self, mcore_weights_name: str) -> list[str]:
        """Translate an mcore weight name into the list of HF weight names.

        Raises:
            NotImplementedError: If the name belongs to no known family.
        """
        assert "_extra_state" not in mcore_weights_name, "extra_state should not be loaded"

        # Weights outside the decoder layers map one-to-one.
        top_level = {
            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
            "decoder.final_layernorm.weight": "model.norm.weight",
            "output_layer.weight": "lm_head.weight",
        }
        hf_name = top_level.get(mcore_weights_name)
        if hf_name is not None:
            return [hf_name]

        layer_prefix = "decoder.layers."
        if mcore_weights_name.startswith(layer_prefix):
            # Standalone layernorm modules are resolved here; everything
            # else falls through to the attention / MLP handlers.
            remainder = mcore_weights_name[len(layer_prefix) :]
            layer_idx, _, rest = remainder.partition(".")
            if rest == "input_layernorm.weight":
                return [f"model.layers.{layer_idx}.input_layernorm.weight"]
            if rest == "pre_mlp_layernorm.weight":
                return [f"model.layers.{layer_idx}.post_attention_layernorm.weight"]

        # Attention is checked before MLP, matching the original precedence.
        if "self_attention" in mcore_weights_name:
            return self._weight_name_mapping_attention(mcore_weights_name)
        if "mlp" in mcore_weights_name:
            return self._weight_name_mapping_mlp(mcore_weights_name)
        raise NotImplementedError(f"Unsupported parameter name: {mcore_weights_name}")
20 changes: 20 additions & 0 deletions scripts/models/xllm-8B.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# xLLM 8B dense GQA model arguments.
# Defines the MODEL_ARGS bash array of command-line flags describing the
# model architecture; this file only declares the array and runs nothing.
MODEL_ARGS=(
--swiglu
--num-layers 36
--hidden-size 4096
--ffn-hidden-size 12288
# Grouped-query attention: 32 query heads share 8 KV groups (4 per group).
--num-attention-heads 32
--group-query-attention
--num-query-groups 8
# Per-head dimension; 32 heads * 128 = 4096, equal to --hidden-size.
--kv-channels 128
--disable-bias-linear
--normalization RMSNorm
--norm-epsilon 1e-6
# NOTE(review): model_provider only requests the fused layernorm+linear
# spec when this value is 1, so 4 appears to select the unfused path - confirm.
--layernorm-num-groups 4
--position-embedding-type rope
--rotary-percent 1.0
--rotary-base 10000000
--untie-embeddings-and-output-weights
--vocab-size 250624
)
Loading