diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py
index 943c100..5eec352 100644
--- a/src/mcore_bridge/config/model_config.py
+++ b/src/mcore_bridge/config/model_config.py
@@ -3,13 +3,13 @@
 import os
 import re
 import torch.nn.functional as F
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from megatron.core import mpu
 from megatron.core.transformer import TransformerConfig
 from transformers import PretrainedConfig
 from transformers.utils import is_torch_npu_available
 from transformers.utils.versions import require_version
-from typing import List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

 from mcore_bridge.utils import get_logger, json_parse_to_dict

@@ -229,6 +229,7 @@ class ModelConfig(TransformerConfig):
     task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm'
     num_labels: Optional[int] = None
     mlp_padding_free: bool = False
+    model_kwargs: Dict[str, Any] = field(default_factory=dict)

     _mindspeed_defaults_cache = None
diff --git a/src/mcore_bridge/model/constant.py b/src/mcore_bridge/model/constant.py
index 6c61f09..36f9d81 100644
--- a/src/mcore_bridge/model/constant.py
+++ b/src/mcore_bridge/model/constant.py
@@ -27,6 +27,7 @@ class MLLMModelType:
     glm4v_moe = 'glm4v_moe'
     kimi_vl = 'kimi_vl'
     llama4 = 'llama4'
+    gemma4 = 'gemma4'
     kimi_k25 = 'kimi_k25'
diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 1903b7c..5299a26 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -100,9 +100,7 @@ def __init__(
         for i in range(len(self.decoder.layers)):
             if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'):
                 del self.decoder.layers[i].self_attention.rotary_pos_emb
-        self.attention_scaling = 1.
-        new_inv_freq, self.attention_scaling = get_rope_inv_freq(config)
-        self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+        self._set_inv_freq()
         if self.config.task_type == 'seq_cls' and self.post_process:
             self.output_layer = OutputLayerLinear(
                 config.hidden_size,
@@ -212,7 +210,36 @@ def _preprocess(
         if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
             # fix LoRA incompatibility with gradient checkpointing
             decoder_input = decoder_input.requires_grad_(True)

+        rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
+            decoder_input, position_ids, packed_seq_params=packed_seq_params, inference_context=inference_context)
+
+        if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
+                                   or self.config.flash_decode) and rotary_pos_cos is not None
+                and inference_context.is_static_batching()):
+            current_batch_size = input_ids.shape[0]
+            sequence_len_offset = torch.tensor(
+                [inference_context.sequence_len_offset] * current_batch_size,
+                dtype=torch.int32,
+                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
+            )
+        else:
+            sequence_len_offset = None
+
+        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+        # reference held by this caller function, enabling early garbage collection for
+        # inference. Skip wrapping if decoder_input is logged after decoder completion.
+        if in_inference_mode and not has_config_logger_enabled(self.config):
+            decoder_input = WrappedTensor(decoder_input)
+        return (decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin,
+                sequence_len_offset)
+
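+    # NOTE: factored out of __init__ (see the hunk above); model-specific subclasses
+    # such as Gemma4TextGPTModel in mm_gpts/gemma4.py override this to install their
+    # own rotary inverse frequencies.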
+    def _set_inv_freq(self):
+        self.attention_scaling = 1.
+        new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config)
+        self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+
+    def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
         # Rotary positional embeddings (embedding is None for PP intermediate devices)
         rotary_pos_emb = None
         rotary_pos_cos = None
@@ -247,26 +274,13 @@ def _preprocess(
                 rotary_seq_len,
                 packed_seq=packed_seq,
             )
+        decoder_rotary_pos_emb = rotary_pos_emb
+        packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
+        if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
+            assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
+            decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

-        if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
-                                   or self.config.flash_decode) and rotary_pos_cos is not None
-                and inference_context.is_static_batching()):
-            current_batch_size = input_ids.shape[0]
-            sequence_len_offset = torch.tensor(
-                [inference_context.sequence_len_offset] * current_batch_size,
-                dtype=torch.int32,
-                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
-            )
-        else:
-            sequence_len_offset = None
-
-        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
-        # reference held by this caller function, enabling early garbage collection for
-        # inference. Skip wrapping if decoder_input is logged after decoder completion.
-        if in_inference_mode and not has_config_logger_enabled(self.config):
-            decoder_input = WrappedTensor(decoder_input)
-
-        return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
+        return rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin

     # Code borrowed from NVIDIA/Megatron-LM
     def forward(
@@ -298,7 +312,7 @@ def forward(

         inference_context = deprecate_inference_params(inference_context, inference_params)

-        decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
+        decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
             self._preprocess(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -306,11 +320,6 @@ def forward(
                 inference_context=inference_context,
                 packed_seq_params=packed_seq_params,
             ))
-        decoder_rotary_pos_emb = rotary_pos_emb
-        packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
-        if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
-            assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
-            decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

         mtp_decoder_input = decoder_input
         if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
diff --git a/src/mcore_bridge/model/mm_gpts/__init__.py b/src/mcore_bridge/model/mm_gpts/__init__.py
index d13e4e7..b8ea385 100644
--- a/src/mcore_bridge/model/mm_gpts/__init__.py
+++ b/src/mcore_bridge/model/mm_gpts/__init__.py
@@ -1,2 +1,2 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
+from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py
new file mode 100644
index 0000000..8bfc0c2
--- /dev/null
+++ b/src/mcore_bridge/model/mm_gpts/gemma4.py
@@ -0,0 +1,137 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
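+"""Gemma4 support for mcore_bridge: bridge, ViT wrapper, Megatron modules and loader.
+
+The HuggingFace vision/audio towers are wrapped by Gemma4Vit; the Megatron text
+stack is adapted below because Gemma4 mixes sliding-window and full-attention
+layers, each with its own head dim and RoPE scaling.
+"""
+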
+import copy
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
+from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
+from megatron.core.transformer.mlp import MLP, MLPSubmodules
+from transformers import AutoModel, PretrainedConfig
+from typing import Optional
+
+from mcore_bridge.bridge import MultimodalGPTBridge
+from mcore_bridge.config import ModelConfig
+
+from ..constant import ModelType
+from ..gpt_model import GPTModel
+from ..mm_gpt_model import MultimodalGPTModel
+from ..register import ModelLoader, ModelMeta, register_model
+from ..rope import get_rope_inv_freq
+from .utils import HuggingFaceVit
+
+
+class Gemma4Vit(HuggingFaceVit):
+    module_mapping = {
+        'model.vision_tower': 'vision_tower',
+        'model.embed_vision': 'embed_vision',
+        'model.audio_tower': 'audio_tower',
+        'model.embed_audio': 'embed_audio',
+    }
+    _vision_tower = ['vision_tower', 'audio_tower']
+    _aligner = ['embed_vision', 'embed_audio']
+    support_multimodal = False
+
+    def prepare_model(self, hf_config: PretrainedConfig):
+        from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder
+        self.vision_tower = AutoModel.from_config(hf_config.vision_config)
+        self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None
+        self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config)
+        self.embed_audio = (
+            Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config)
+            if hf_config.audio_config is not None else None)
+
+    def get_inputs_embeds(self, inputs_embeds, **kwargs):
+        return inputs_embeds
+
+
+class Gemma4SelfAttention(SelfAttention):
+
+    def __init__(
+        self,
+        config: ModelConfig,
+        submodules: SelfAttentionSubmodules,
+        layer_number: int,
+        *args,
+        **kwargs,
+    ):
+        text_config = config.hf_config.text_config
+        self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention'
+        self.sliding_window = text_config.sliding_window if self.is_sliding else None
+        # Temporarily swap kv_channels so the parent builds projections with the
+        # per-layer head dim (global vs. sliding), then restore the original value.
+        kv_channels = config.kv_channels
+        config.kv_channels = (
+            text_config.global_head_dim
+            if not self.is_sliding and text_config.global_head_dim else text_config.head_dim)
+        super().__init__(config, submodules, layer_number, *args, **kwargs)
+        config.kv_channels = kv_channels
+
+
+class Gemma4MLP(MLP):
+
+    def __init__(
+        self,
+        config: ModelConfig,
+        submodules: MLPSubmodules,
+        layer_number: int,
+        *args,
+        **kwargs,
+    ):
+        self.layer_number = layer_number
+        text_config = config.hf_config.text_config
+        self.enable_moe_block = text_config.enable_moe_block
+        first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers
+        layer_idx = layer_number - 1  # layer_number is 1-indexed
+        is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
+        use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
+        super().__init__(config, submodules, *args, **kwargs)
+
+
+class Gemma4Bridge(MultimodalGPTBridge):
+    pass
+
+
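+# Gemma4's text config carries two rope_scaling entries, one for sliding-window
+# layers and one for full-attention layers. The text model therefore keeps two
+# rotary embeddings: self.rotary_pos_emb (sliding) and self.full_rotary_pos_emb
+# (full attention).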
+class Gemma4TextGPTModel(GPTModel):
+
+    def _set_inv_freq(self):
+        rope_scaling = self.config.rope_scaling
+        # sliding-window rotary embedding
+        self.config.rope_scaling = rope_scaling['sliding_attention']
+        new_inv_freq, attention_scaling = get_rope_inv_freq(self.config)
+        assert attention_scaling == 1, 'attention_scaling != 1 is not supported'
+        self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+        # full-attention rotary embedding
+        self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb)
+        self.config.rope_scaling = rope_scaling['full_attention']
+        kwargs = {}
+        if self.config.rope_scaling['rope_type'] == 'proportional':
+            kwargs['head_dim_key'] = 'global_head_dim'
+        new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs)
+        assert attention_scaling == 1, 'attention_scaling != 1 is not supported'
+        self.full_rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+        self.attention_scaling = attention_scaling
+
+        self.config.rope_scaling = rope_scaling
+
+
+class Gemma4GPTModel(MultimodalGPTModel):
+    language_model_cls = Gemma4TextGPTModel
+
+
+class Gemma4Loader(ModelLoader):
+    model_cls = Gemma4GPTModel
+
+    def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
+        layer_specs = get_gpt_decoder_block_spec(
+            self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage)
+        for layer_spec in layer_specs.layer_specs:
+            layer_spec.submodules.self_attention.module = Gemma4SelfAttention
+            layer_spec.submodules.mlp.module = Gemma4MLP
+        return layer_specs
+
+
+register_model(
+    ModelMeta(
+        ModelType.gemma4,
+        ['gemma4'],
+        bridge_cls=Gemma4Bridge,
+        visual_cls=Gemma4Vit,
+        loader=Gemma4Loader,
+    ))
diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py
index 024135e..be07631 100644
--- a/src/mcore_bridge/model/modules/transformer_layer.py
+++ b/src/mcore_bridge/model/modules/transformer_layer.py
@@ -134,6 +134,7 @@ def __init__(
         from megatron.core.transformer.moe.moe_layer import MoELayer

         from mcore_bridge.model.gpts.glm4 import Glm4MLP
+        from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP

         # MLP expects tp_group but MoELayer expects pg_collection to be passed in.
         # We can change MLP to accept pg_collection but it makes the logic implicit
@@ -148,6 +149,8 @@ def __init__(
         elif submodules.mlp.module in (MLP, Glm4MLP):
             assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer'
             additional_mlp_kwargs['tp_group'] = pg_collection.tp
+        elif submodules.mlp.module == Gemma4MLP:
+            additional_mlp_kwargs['layer_number'] = layer_number
         elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP:
             assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer'
             additional_mlp_kwargs['tp_group'] = pg_collection.tp