5 changes: 3 additions & 2 deletions src/mcore_bridge/config/model_config.py
@@ -3,13 +3,13 @@
import os
import re
import torch.nn.functional as F
from dataclasses import dataclass
from dataclasses import dataclass, field
from megatron.core import mpu
from megatron.core.transformer import TransformerConfig
from transformers import PretrainedConfig
from transformers.utils import is_torch_npu_available
from transformers.utils.versions import require_version
from typing import List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from mcore_bridge.utils import get_logger, json_parse_to_dict

@@ -229,6 +229,7 @@ class ModelConfig(TransformerConfig):
task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm'
num_labels: Optional[int] = None
mlp_padding_free: bool = False
model_kwargs: Dict[str, Any] = field(default_factory=dict)

_mindspeed_defaults_cache = None

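Note: model_kwargs uses field(default_factory=dict) rather than a literal {} because dataclasses reject mutable defaults, and a factory gives every config instance its own dict. A minimal standalone illustration (the Example class and the key used are hypothetical, not part of this diff):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class Example:
    # model_kwargs: Dict[str, Any] = {}  # would raise ValueError at class definition
    model_kwargs: Dict[str, Any] = field(default_factory=dict)

a, b = Example(), Example()
a.model_kwargs['attn_implementation'] = 'eager'
assert b.model_kwargs == {}  # instances do not share state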
1 change: 1 addition & 0 deletions src/mcore_bridge/model/constant.py
@@ -27,6 +27,7 @@ class MLLMModelType:
glm4v_moe = 'glm4v_moe'
kimi_vl = 'kimi_vl'
llama4 = 'llama4'
gemma4 = 'gemma4'

kimi_k25 = 'kimi_k25'

65 changes: 37 additions & 28 deletions src/mcore_bridge/model/gpt_model.py
@@ -100,9 +100,7 @@ def __init__(
for i in range(len(self.decoder.layers)):
if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'):
del self.decoder.layers[i].self_attention.rotary_pos_emb
self.attention_scaling = 1.
new_inv_freq, self.attention_scaling = get_rope_inv_freq(config)
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
self._set_inv_freq()
if self.config.task_type == 'seq_cls' and self.post_process:
self.output_layer = OutputLayerLinear(
config.hidden_size,
@@ -212,7 +210,36 @@ def _preprocess(
if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
# fix LoRA incompatibility with gradient checkpointing
decoder_input = decoder_input.requires_grad_(True)
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params)
Comment on lines +213 to +214
Severity: high

The inference_context is not passed to the _get_rotary_pos_emb method. This will cause the method to skip critical inference-specific logic, such as utilizing the RoPE cache or correctly calculating the rotary sequence length for flash decoding, which can lead to performance degradation or incorrect results during inference.

Suggested change
-        rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
-            decoder_input, position_ids, packed_seq_params=packed_seq_params)
+        rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
+            decoder_input, position_ids, packed_seq_params=packed_seq_params, inference_context=inference_context)
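For intuition, a standalone sketch of what the missing argument changes (rotary_seq_len_needed is hypothetical; only sequence_len_offset mirrors the surrounding code):

def rotary_seq_len_needed(num_new_tokens, inference_context=None):
    # Training path / no KV cache: the cos/sin tables only need to cover
    # the tokens currently in decoder_input.
    if inference_context is None:
        return num_new_tokens
    # Incremental decoding: positions start at the KV-cache offset, so the
    # tables must span offset + new tokens. Omitting inference_context
    # silently falls back to the training path above.
    return inference_context.sequence_len_offset + num_new_tokens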


if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return (decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin,
sequence_len_offset)

def _set_inv_freq(self):
self.attention_scaling = 1.
new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config)
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)

def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
@@ -247,26 +274,13 @@ def _preprocess(
rotary_seq_len,
packed_seq=packed_seq,
)
decoder_rotary_pos_emb = rotary_pos_emb
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
return rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin

# Code borrowed from NVIDIA/Megatron-LM
def forward(
@@ -298,19 +312,14 @@ def forward(

inference_context = deprecate_inference_params(inference_context, inference_params)

decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
self._preprocess(
input_ids=input_ids,
position_ids=position_ids,
decoder_input=decoder_input,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
))
decoder_rotary_pos_emb = rotary_pos_emb
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

mtp_decoder_input = decoder_input
if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
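Aside on the early-garbage-collection comment retained in _preprocess above: a minimal sketch of the wrapping pattern it describes (illustrative; the real WrappedTensor is imported from Megatron-LM and may differ):

class WrappedTensor:
    """Single-reference holder so a callee can release the caller's tensor."""

    def __init__(self, tensor):
        self._tensor = tensor

    def unwrap(self):
        # Hand the tensor to the consumer and drop this wrapper's reference,
        # so the activation can be freed as soon as the consumer is done.
        tensor = self._tensor
        self._tensor = None
        return tensor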
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/mm_gpts/__init__.py
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
137 changes: 137 additions & 0 deletions src/mcore_bridge/model/mm_gpts/gemma4.py
@@ -0,0 +1,137 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import copy
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from transformers import AutoModel, PretrainedConfig
from typing import Optional

from mcore_bridge.bridge import MultimodalGPTBridge
from mcore_bridge.config import ModelConfig

from ..constant import ModelType
from ..gpt_model import GPTModel
from ..mm_gpt_model import MultimodalGPTModel
from ..register import ModelLoader, ModelMeta, register_model
from ..rope import get_rope_inv_freq
from .utils import HuggingFaceVit


class Gemma4Vit(HuggingFaceVit):
module_mapping = {
'model.vision_tower': 'vision_tower',
'model.embed_vision': 'embed_vision',
'model.audio_tower': 'audio_tower',
'model.embed_audio': 'embed_audio',
}
_vision_tower = ['vision_tower', 'audio_tower']
_aligner = ['embed_vision', 'embed_audio']
support_multimodal = False

def prepare_model(self, hf_config: PretrainedConfig):
from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder
self.vision_tower = AutoModel.from_config(hf_config.vision_config)
self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None
self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config)
self.embed_audio = (
Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config)
if hf_config.audio_config is not None else None)

def get_inputs_embeds(self, inputs_embeds, **kwargs):
return inputs_embeds


class Gemma4SelfAttention(SelfAttention):

def __init__(
self,
config: ModelConfig,
submodules: SelfAttentionSubmodules,
layer_number: int,
*args,
**kwargs,
):
text_config = config.hf_config.text_config
self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention'
self.sliding_window = text_config.sliding_window if self.is_sliding else None
kv_channels = config.kv_channels
config.kv_channels = (
text_config.global_head_dim
if not self.is_sliding and text_config.global_head_dim else text_config.head_dim)
super().__init__(config, submodules, layer_number, *args, **kwargs)
config.kv_channels = kv_channels


class Gemma4MLP(MLP):

def __init__(
self,
config: ModelConfig,
submodules: MLPSubmodules,
layer_number: int,
*args,
**kwargs,
):
self.layer_number = layer_number
text_config = config.hf_config.text_config
self.enable_moe_block = text_config.enable_moe_block
first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers
is_kv_shared_layer = layer_number - 1 >= first_kv_shared_layer_idx > 0  # layer_number is 1-indexed
use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
super().__init__(config, submodules, *args, **kwargs)


class Gemma4Bridge(MultimodalGPTBridge):
pass


class Gemma4TextGPTModel(GPTModel):

def _set_inv_freq(self):
rope_scaling = self.config.rope_scaling
self.config.rope_scaling = rope_scaling['sliding_attention']
new_inv_freq, attention_scaling = get_rope_inv_freq(self.config)
assert attention_scaling == 1, 'not support'
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
# full
self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb)
self.config.rope_scaling = rope_scaling['full_attention']
kwargs = {}
if self.config.rope_scaling['rope_type'] == 'proportional':
kwargs['head_dim_key'] = 'global_head_dim'
new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs)
assert attention_scaling == 1, 'not support'
self.full_rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
self.attention_scaling = attention_scaling

self.config.rope_scaling = rope_scaling
Comment on lines +94 to +111
Severity: high

The implementation of _set_inv_freq for Gemma4TextGPTModel has several issues:

  1. Potential Runtime Crash: Restoring self.config.rope_scaling to the original nested dictionary at line 62 will cause a KeyError in _get_rope_type (called via dynamic_rope_update during every forward pass) because that function expects a dictionary with a rope_type key at the top level, which the Gemma4 configuration lacks (it uses sliding_attention and full_attention as top-level keys).
  2. Dead Code: self.full_rotary_pos_emb is initialized but never utilized by the base GPTModel forward pass or RoPE application logic.
  3. Poor Error Messages: The assertion messages 'not support' at lines 49 and 58 are not descriptive. They should clearly state that attention scaling other than 1.0 is not supported for this model.
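For illustration, a standalone sketch of the KeyError described in point 1 (the dictionary shapes are inferred from _set_inv_freq above, not taken from the real config objects):

# Gemma4 keeps per-attention-type settings at the top level:
rope_scaling = {
    'sliding_attention': {'rope_type': 'default'},
    'full_attention': {'rope_type': 'proportional'},
}
# A helper expecting a flat dict (as _get_rope_type reportedly does) breaks
# once the nested dict is restored onto self.config:
try:
    rope_type = rope_scaling['rope_type']
except KeyError:
    print("KeyError: top-level keys are attention types, not 'rope_type'")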



class Gemma4GPTModel(MultimodalGPTModel):
language_model_cls = Gemma4TextGPTModel


class Gemma4Loader(ModelLoader):
model_cls = Gemma4GPTModel

def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
layer_specs = get_gpt_decoder_block_spec(
self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage)
for layer_spec in layer_specs.layer_specs:
layer_spec.submodules.self_attention.module = Gemma4SelfAttention
layer_spec.submodules.mlp.module = Gemma4MLP
return layer_specs


register_model(
ModelMeta(
ModelType.gemma4,
['gemma4'],
bridge_cls=Gemma4Bridge,
visual_cls=Gemma4Vit,
loader=Gemma4Loader,
))
3 changes: 3 additions & 0 deletions src/mcore_bridge/model/modules/transformer_layer.py
@@ -134,6 +134,7 @@ def __init__(
from megatron.core.transformer.moe.moe_layer import MoELayer

from mcore_bridge.model.gpts.glm4 import Glm4MLP
from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP

# MLP expects tp_group but MoELayer expects pg_collection to be passed in.
# We can change MLP to accept pg_collection but it makes the logic implicit
@@ -148,6 +149,8 @@
elif submodules.mlp.module in (MLP, Glm4MLP):
assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer'
additional_mlp_kwargs['tp_group'] = pg_collection.tp
elif submodules.mlp.module == Gemma4MLP:
additional_mlp_kwargs['layer_number'] = layer_number
elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP:
assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer'
additional_mlp_kwargs['tp_group'] = pg_collection.tp