From 4e2df75448d01d28ae6a58906b38371ddad98a9d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 17:42:35 +0800 Subject: [PATCH 01/12] update --- src/mcore_bridge/config/model_config.py | 5 +++-- src/mcore_bridge/model/constant.py | 1 + src/mcore_bridge/model/mm_gpts/__init__.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py index 943c100..5eec352 100644 --- a/src/mcore_bridge/config/model_config.py +++ b/src/mcore_bridge/config/model_config.py @@ -3,13 +3,13 @@ import os import re import torch.nn.functional as F -from dataclasses import dataclass +from dataclasses import dataclass, field from megatron.core import mpu from megatron.core.transformer import TransformerConfig from transformers import PretrainedConfig from transformers.utils import is_torch_npu_available from transformers.utils.versions import require_version -from typing import List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from mcore_bridge.utils import get_logger, json_parse_to_dict @@ -229,6 +229,7 @@ class ModelConfig(TransformerConfig): task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm' num_labels: Optional[int] = None mlp_padding_free: bool = False + model_kwargs: Dict[str, Any] = field(default_factory=dict) _mindspeed_defaults_cache = None diff --git a/src/mcore_bridge/model/constant.py b/src/mcore_bridge/model/constant.py index 6c61f09..36f9d81 100644 --- a/src/mcore_bridge/model/constant.py +++ b/src/mcore_bridge/model/constant.py @@ -27,6 +27,7 @@ class MLLMModelType: glm4v_moe = 'glm4v_moe' kimi_vl = 'kimi_vl' llama4 = 'llama4' + gemma4 = 'gemma4' kimi_k25 = 'kimi_k25' diff --git a/src/mcore_bridge/model/mm_gpts/__init__.py b/src/mcore_bridge/model/mm_gpts/__init__.py index d13e4e7..b8ea385 100644 --- a/src/mcore_bridge/model/mm_gpts/__init__.py +++ 
b/src/mcore_bridge/model/mm_gpts/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl +from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl From c388954d3a86fa9764226fed3b886df0af9497af Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 17:49:36 +0800 Subject: [PATCH 02/12] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 63 ++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/mcore_bridge/model/mm_gpts/gemma4.py diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py new file mode 100644 index 0000000..9b99428 --- /dev/null +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -0,0 +1,63 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from transformers import AutoModel, PretrainedConfig + +from mcore_bridge.bridge import GPTBridge + +from ..constant import ModelType +from ..register import ModelLoader, ModelMeta, register_model +from .utils import HuggingFaceVit + + +class Gemma4Vit(HuggingFaceVit): + module_mapping = { + 'model.vision_tower': 'vision_tower', + 'model.embed_vision': 'embed_vision', + 'model.audio_tower': 'audio_tower', + 'model.embed_audio': 'embed_audio', + } + _vision_tower = ['vision_tower', 'audio_tower'] + _aligner = ['embed_vision', 'embed_audio'] + support_multimodal = False + + def prepare_model(self, hf_config: PretrainedConfig): + from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder + self.vision_tower = AutoModel.from_config(hf_config.vision_config) + self.vocab_size = hf_config.text_config.vocab_size + + language_model = AutoModel.from_config(config=hf_config.text_config) + self.language_model = language_model + self.vocab_size_per_layer_input = hf_config.text_config.vocab_size_per_layer_input + self.audio_tower 
= AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None + self.embed_vision = ( + Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) + if hf_config.vision_config is not None else None) + self.embed_audio = ( + Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) + if hf_config.audio_config is not None else None) + + def get_inputs_embeds(self, inputs_embeds, **kwargs): + return inputs_embeds + + +class Gemma4Bridge(GPTBridge): + pass + + +class Gemma4Loader(ModelLoader): + pass + # def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + # layer_specs = get_gpt_decoder_block_spec( + # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) + # for layer_spec in layer_specs.layer_specs: + # pass + # return layer_specs + + +register_model( + ModelMeta( + ModelType.gemma4, + ['gemma4'], + bridge_cls=Gemma4Bridge, + visual_cls=Gemma4Vit, + loader=Gemma4Loader, + )) From 76af2bcebdf2c1011951a219a4bcda87103185ae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 19:19:39 +0800 Subject: [PATCH 03/12] update --- src/mcore_bridge/model/gpt_model.py | 65 ++++++++++++++++------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 5fb714c..e34c60a 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -110,9 +110,7 @@ def __init__( for i in range(len(self.decoder.layers)): if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'): del self.decoder.layers[i].self_attention.rotary_pos_emb - self.attention_scaling = 1. 
- new_inv_freq, self.attention_scaling = get_rope_inv_freq(config) - self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + self._set_inv_freq() if self.config.task_type == 'seq_cls' and self.post_process: self.output_layer = OutputLayerLinear( config.hidden_size, @@ -222,7 +220,36 @@ def _preprocess( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) + rotary_pos_emb, rotary_pos_cos, decoder_rotary_pos_emb, rotary_pos_sin = self._get_rotary_pos_emb( + decoder_input, position_ids, packed_seq_params=packed_seq_params) + + if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') + or self.config.flash_decode) and rotary_pos_cos is not None + and inference_context.is_static_batching()): + current_batch_size = input_ids.shape[0] + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * current_batch_size, + dtype=torch.int32, + device=rotary_pos_cos.device, # Co-locate this with the rotary tensors + ) + else: + sequence_len_offset = None + + # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the + # reference held by this caller function, enabling early garbage collection for + # inference. Skip wrapping if decoder_input is logged after decoder completion. + if in_inference_mode and not has_config_logger_enabled(self.config): + decoder_input = WrappedTensor(decoder_input) + return (decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, + sequence_len_offset) + + def _set_inv_freq(self): + self.attention_scaling = 1. 
+ new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config) + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + + def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None): # Rotary positional embeddings (embedding is None for PP intermediate devices) rotary_pos_emb = None rotary_pos_cos = None @@ -257,26 +284,13 @@ def _preprocess( rotary_seq_len, packed_seq=packed_seq, ) + decoder_rotary_pos_emb = rotary_pos_emb + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion: + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] - if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') - or self.config.flash_decode) and rotary_pos_cos is not None - and inference_context.is_static_batching()): - current_batch_size = input_ids.shape[0] - sequence_len_offset = torch.tensor( - [inference_context.sequence_len_offset] * current_batch_size, - dtype=torch.int32, - device=rotary_pos_cos.device, # Co-locate this with the rotary tensors - ) - else: - sequence_len_offset = None - - # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the - # reference held by this caller function, enabling early garbage collection for - # inference. Skip wrapping if decoder_input is logged after decoder completion. 
- if in_inference_mode and not has_config_logger_enabled(self.config): - decoder_input = WrappedTensor(decoder_input) - - return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset + return rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin # Code borrowed from NVIDIA/Megatron-LM def forward( @@ -308,7 +322,7 @@ def forward( inference_context = deprecate_inference_params(inference_context, inference_params) - decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( + decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( self._preprocess( input_ids=input_ids, position_ids=position_ids, @@ -316,11 +330,6 @@ def forward( inference_context=inference_context, packed_seq_params=packed_seq_params, )) - decoder_rotary_pos_emb = rotary_pos_emb - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion: - assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' - decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] mtp_decoder_input = decoder_input if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None: From 54e33435585c2bdd1f5045c162b238e98f0565ba Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 21:47:11 +0800 Subject: [PATCH 04/12] update --- src/mcore_bridge/tuners/patcher.py | 39 +++--------------------------- src/mcore_bridge/utils/__init__.py | 2 +- src/mcore_bridge/utils/utils.py | 36 +++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/src/mcore_bridge/tuners/patcher.py b/src/mcore_bridge/tuners/patcher.py index e715c35..f9cae8c 100644 --- a/src/mcore_bridge/tuners/patcher.py +++ b/src/mcore_bridge/tuners/patcher.py @@ -1,6 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
-import copy -from contextlib import contextmanager from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.router import TopKRouter @@ -11,6 +9,8 @@ from torch import nn from typing import Optional +from mcore_bridge.utils import patch_deepcopy + from .lora import LoraParallelLinear @@ -37,39 +37,6 @@ def dispatch_megatron( model.dispatch_megatron = dispatch_megatron -@contextmanager -def _patch_deepcopy(): - _origin_deepcopy = copy.deepcopy - copy_keys = ('tp_group', '_tp_group', 'config') - - def new_deepcopy(x, *args, **kwargs): - if not isinstance(x, nn.Module): - return _origin_deepcopy(x, *args, **kwargs) - - saved_values = {} - for key in copy_keys: - val = getattr(x, key, None) - if val is not None: - saved_values[key] = val - setattr(x, key, None) - - try: - res = _origin_deepcopy(x, *args, **kwargs) - finally: - for key, value in saved_values.items(): - setattr(x, key, value) - - for key, value in saved_values.items(): - setattr(res, key, value) - return res - - copy.deepcopy = new_deepcopy - try: - yield - finally: - copy.deepcopy = _origin_deepcopy - - def _patch_lora_model(): if hasattr(LoraModel, '_mcore_patched'): return @@ -77,7 +44,7 @@ def _patch_lora_model(): __origin_init__ = LoraModel.__init__ def __new_init__(self, *args, **kwargs): - with _patch_deepcopy(): + with patch_deepcopy(): __origin_init__(self, *args, **kwargs) if not isinstance(self.model, MegatronModule): return diff --git a/src/mcore_bridge/utils/__init__.py b/src/mcore_bridge/utils/__init__.py index 34cdbd1..d4285be 100644 --- a/src/mcore_bridge/utils/__init__.py +++ b/src/mcore_bridge/utils/__init__.py @@ -6,4 +6,4 @@ from .megatron_utils import get_local_layer_specs, set_random_seed, split_cp_inputs, unwrap_model from .safetensors import SafetensorLazyLoader, StreamingSafetensorSaver from .torch_utils import gc_collect, 
get_current_device, safe_ddp_context, to_device -from .utils import deep_getattr, get_env_args, json_parse_to_dict +from .utils import deep_getattr, get_env_args, json_parse_to_dict, patch_deepcopy diff --git a/src/mcore_bridge/utils/utils.py b/src/mcore_bridge/utils/utils.py index 6905c41..a7e525b 100644 --- a/src/mcore_bridge/utils/utils.py +++ b/src/mcore_bridge/utils/utils.py @@ -1,5 +1,8 @@ +import copy import json import os +from contextlib import contextmanager +from torch import nn from transformers.utils import strtobool from typing import Callable, Dict, Optional, TypeVar, Union @@ -58,3 +61,36 @@ def deep_getattr(obj, attr: str, default=None): else: obj = getattr(obj, a, default) return obj + + +@contextmanager +def patch_deepcopy(): + _origin_deepcopy = copy.deepcopy + copy_keys = ('tp_group', '_tp_group', 'config') + + def new_deepcopy(x, *args, **kwargs): + if not isinstance(x, nn.Module): + return _origin_deepcopy(x, *args, **kwargs) + + saved_values = {} + for key in copy_keys: + val = getattr(x, key, None) + if val is not None: + saved_values[key] = val + setattr(x, key, None) + + try: + res = _origin_deepcopy(x, *args, **kwargs) + finally: + for key, value in saved_values.items(): + setattr(x, key, value) + + for key, value in saved_values.items(): + setattr(res, key, value) + return res + + copy.deepcopy = new_deepcopy + try: + yield + finally: + copy.deepcopy = _origin_deepcopy From 25a45bda3aa7dde0b9125c33614f94bdc0c4c18e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 30 Apr 2026 10:13:18 +0800 Subject: [PATCH 05/12] update --- src/mcore_bridge/model/gpt_model.py | 2 +- src/mcore_bridge/model/mm_gpt_model.py | 4 ++- src/mcore_bridge/model/mm_gpts/gemma4.py | 41 ++++++++++++++++++------ src/mcore_bridge/model/rope.py | 4 +-- 4 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index e34c60a..ace31e4 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ 
b/src/mcore_bridge/model/gpt_model.py @@ -220,7 +220,7 @@ def _preprocess( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) - rotary_pos_emb, rotary_pos_cos, decoder_rotary_pos_emb, rotary_pos_sin = self._get_rotary_pos_emb( + rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb( decoder_input, position_ids, packed_seq_params=packed_seq_params) if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py index b68e82b..b3fc0d6 100644 --- a/src/mcore_bridge/model/mm_gpt_model.py +++ b/src/mcore_bridge/model/mm_gpt_model.py @@ -18,6 +18,7 @@ class MultimodalGPTModel(MegatronModule): + language_model_cls = GPTModel def __init__(self, config: ModelConfig, @@ -29,7 +30,8 @@ def __init__(self, super().__init__(config) self.pre_process = pre_process self.post_process = post_process - self.language_model = GPTModel(config, transformer_layer_spec, pre_process, post_process, *_args, **kwargs) + self.language_model = self.language_model_cls(config, transformer_layer_spec, pre_process, post_process, *_args, + **kwargs) self.vp_stage = self.language_model.vp_stage self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights self.model_meta = config.model_meta diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 9b99428..2d656dd 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,10 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import copy from transformers import AutoModel, PretrainedConfig from mcore_bridge.bridge import GPTBridge from ..constant import ModelType +from ..gpt_model import GPTModel +from ..mm_gpt_model import MultimodalGPTModel from ..register import ModelLoader, ModelMeta, register_model +from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit @@ -22,15 +26,8 @@ class Gemma4Vit(HuggingFaceVit): def prepare_model(self, hf_config: PretrainedConfig): from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder self.vision_tower = AutoModel.from_config(hf_config.vision_config) - self.vocab_size = hf_config.text_config.vocab_size - - language_model = AutoModel.from_config(config=hf_config.text_config) - self.language_model = language_model - self.vocab_size_per_layer_input = hf_config.text_config.vocab_size_per_layer_input self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None - self.embed_vision = ( - Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) - if hf_config.vision_config is not None else None) + self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) self.embed_audio = ( Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) if hf_config.audio_config is not None else None) @@ -43,8 +40,34 @@ class Gemma4Bridge(GPTBridge): pass +class Gemma4TextGPTModel(GPTModel): + + def _set_inv_freq(self): + rope_scaling = self.config.rope_scaling + self.config.rope_scaling = rope_scaling['sliding_attention'] + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config) + assert attention_scaling == 1, 'not support' + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + # full + self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb) + self.config.rope_scaling = rope_scaling['full_attention'] + kwargs = {} + if self.config.rope_scaling['rope_type'] == 'proportional': + 
kwargs['head_dim_key'] = 'global_head_dim' + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs) + assert attention_scaling == 1, 'not support' + self.full_rotary_pos_emb.inv_freq = new_inv_freq + self.attention_scaling = attention_scaling + + self.config.rope_scaling = rope_scaling + + +class Gemma4GPTModel(MultimodalGPTModel): + language_model_cls = Gemma4TextGPTModel + + class Gemma4Loader(ModelLoader): - pass + model_cls = Gemma4GPTModel # def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): # layer_specs = get_gpt_decoder_block_spec( # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) diff --git a/src/mcore_bridge/model/rope.py b/src/mcore_bridge/model/rope.py index 5cabe42..e7db3c3 100644 --- a/src/mcore_bridge/model/rope.py +++ b/src/mcore_bridge/model/rope.py @@ -106,12 +106,12 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]): return rope_type -def get_rope_inv_freq(config, seq_len=None): +def get_rope_inv_freq(config, seq_len=None, **kwargs): from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS) dummy_config = _get_dummy_config(config) rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)] - inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len) + inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len, **kwargs) if attention_scaling is None: attention_scaling = 1. 
return inv_freq, attention_scaling From 32106178dc2476f09e2589262b8c5b3a28b70dce Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 30 Apr 2026 11:30:56 +0800 Subject: [PATCH 06/12] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 39 +++++++++++++++++++----- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 2d656dd..b2a7fde 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,8 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import copy +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from transformers import AutoModel, PretrainedConfig +from typing import Optional -from mcore_bridge.bridge import GPTBridge +from mcore_bridge.bridge import MultimodalGPTBridge +from mcore_bridge.config import ModelConfig from ..constant import ModelType from ..gpt_model import GPTModel @@ -36,12 +40,30 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -class Gemma4Bridge(GPTBridge): +class Gemma4SelfAttention(SelfAttention): + + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + *args, + **kwargs, + ): + text_config = config.hf_config.text_config + super().__init__(config, submodules, layer_number, *args, **kwargs) + + +class Gemma4Bridge(MultimodalGPTBridge): pass class Gemma4TextGPTModel(GPTModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + print() + def _set_inv_freq(self): rope_scaling = self.config.rope_scaling self.config.rope_scaling = rope_scaling['sliding_attention'] @@ -68,12 +90,13 @@ class Gemma4GPTModel(MultimodalGPTModel): class Gemma4Loader(ModelLoader): model_cls = Gemma4GPTModel - # def get_transformer_layer_spec(self, vp_stage: Optional[int] = 
None): - # layer_specs = get_gpt_decoder_block_spec( - # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) - # for layer_spec in layer_specs.layer_specs: - # pass - # return layer_specs + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + layer_specs = get_gpt_decoder_block_spec( + self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) + for layer_spec in layer_specs.layer_specs: + layer_spec.submodules.self_attention.module = Gemma4SelfAttention + return layer_specs register_model( From 5b4e118bc804475f58a057a7805f7ff6f312ca4d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 4 May 2026 18:15:41 +0800 Subject: [PATCH 07/12] update --- src/mcore_bridge/model/modules/__init__.py | 1 + .../model/modules/transformer_layer.py | 30 +++++++++++++++++++ src/mcore_bridge/patcher.py | 30 ------------------- 3 files changed, 31 insertions(+), 30 deletions(-) create mode 100644 src/mcore_bridge/model/modules/transformer_layer.py diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py index 6fd1ac7..eff1bd6 100644 --- a/src/mcore_bridge/model/modules/__init__.py +++ b/src/mcore_bridge/model/modules/__init__.py @@ -2,3 +2,4 @@ from .gated_delta_net import GatedDeltaNet from .gated_self_attention import GatedSelfAttention from .mtp_layer import MultiTokenPredictionLayer +from .transformer_layer import CustomTransformerLayer diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py new file mode 100644 index 0000000..55aa952 --- /dev/null +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -0,0 +1,30 @@ +import megatron.core +from megatron.core.transformer import TransformerLayer +from packaging import version + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + + +class 
CustomTransformerLayer(TransformerLayer): + + def forward(self, *args, **kwargs): + """ + Perform a forward pass through the transformer layer. + + This method calls the core computation of a transformer layer, including + self-attention, cross-attention (if applicable), and feed-forward operations. + """ + if not mcore_013: + return super().forward(self, *args, **kwargs) + hidden_states, context = self._forward_attention(*args, **kwargs) + mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs + mask = None + if mlp_padding_free and hidden_states.shape[1] > 1: + mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() + hidden_states = hidden_states[mask][:, None] + output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None)) + if mask is not None: + new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1])) + new_output[mask] = output.squeeze(1) + output = new_output + return output, context diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py index 15280b5..527b0c5 100644 --- a/src/mcore_bridge/patcher.py +++ b/src/mcore_bridge/patcher.py @@ -13,7 +13,6 @@ from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region, scatter_to_sequence_parallel_region) -from megatron.core.transformer import TransformerLayer from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionBlock, get_mtp_layer_offset from megatron.core.utils import deprecate_inference_params @@ -413,34 +412,6 @@ def sharded_state_dict( peft_module.OriginModulesToSaveWrapper = OriginModulesToSaveWrapper -def _patch_TransformerLayer(): - _origin_forward = TransformerLayer.forward - - def forward(self, *_args, **kwargs): - """ - Perform a forward pass through the transformer layer. 
- - This method calls the core computation of a transformer layer, including - self-attention, cross-attention (if applicable), and feed-forward operations. - """ - if not mcore_013: - return _origin_forward(self, *_args, **kwargs) - hidden_states, context = self._forward_attention(*_args, **kwargs) - mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs - mask = None - if mlp_padding_free and hidden_states.shape[1] > 1: - mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() - hidden_states = hidden_states[mask][:, None] - output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None)) - if mask is not None: - new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1])) - new_output[mask] = output.squeeze(1) - output = new_output - return output, context - - TransformerLayer.forward = forward - - def _patch_TELinear(): def __repr__(self): @@ -769,7 +740,6 @@ def apply_patch(): # patch module _patch_mla_attention() _patch_TEGroupedLinear() - _patch_TransformerLayer() _patch_TELinear() _patch_mrope() _patch_mtp() From d1d22462c3e56e89ac34f06aafeb484243cfd78a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 00:07:45 +0800 Subject: [PATCH 08/12] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index b2a7fde..5b09745 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -4,7 +4,7 @@ from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from transformers import AutoModel, PretrainedConfig from typing import Optional - +from megatron.core.transformer.mlp import MLP from mcore_bridge.bridge import MultimodalGPTBridge from mcore_bridge.config import ModelConfig @@ -51,7 +51,27 @@ def __init__( **kwargs, ): text_config = 
config.hf_config.text_config + self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention' + self.sliding_window = text_config.sliding_window if self.is_sliding else None + kv_channels = config.kv_channels + config.kv_channels = text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim super().__init__(config, submodules, layer_number, *args, **kwargs) + config.kv_channels = kv_channels + +class Gemma4MLP(MLP): + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + *args, + **kwargs, + ): + text_config = config.hf_config.text_config + self.enable_moe_block = text_config.enable_moe_block + first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers + is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer + super().__init__(config, submodules, *args, **kwargs) class Gemma4Bridge(MultimodalGPTBridge): @@ -96,6 +116,7 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) for layer_spec in layer_specs.layer_specs: layer_spec.submodules.self_attention.module = Gemma4SelfAttention + layer_spec.submodules.mlp.module = Gemma4MLP return layer_specs From 14b164489b9228b73276b3f6b3560984e483baac Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 12:19:34 +0800 Subject: [PATCH 09/12] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 11 +- .../model/modules/transformer_layer.py | 208 +++++++++++++++++- src/mcore_bridge/model/register.py | 6 +- 3 files changed, 221 insertions(+), 4 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 5b09745..8cba526 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -2,9 +2,10 
@@ import copy from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.mlp import MLP from transformers import AutoModel, PretrainedConfig from typing import Optional -from megatron.core.transformer.mlp import MLP + from mcore_bridge.bridge import MultimodalGPTBridge from mcore_bridge.config import ModelConfig @@ -54,18 +55,24 @@ def __init__( self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention' self.sliding_window = text_config.sliding_window if self.is_sliding else None kv_channels = config.kv_channels - config.kv_channels = text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim + config.kv_channels = ( + text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim + ) super().__init__(config, submodules, layer_number, *args, **kwargs) config.kv_channels = kv_channels + class Gemma4MLP(MLP): + def __init__( self, config: ModelConfig, submodules: SelfAttentionSubmodules, + layer_number: int, *args, **kwargs, ): + self.layer_number = layer_number text_config = config.hf_config.text_config self.enable_moe_block = text_config.enable_moe_block first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 55aa952..83c42e3 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,12 +1,218 @@ import megatron.core -from megatron.core.transformer import TransformerLayer +import torch +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.enums import CudaGraphScope, LayerType +from megatron.core.transformer.identity_op import 
IdentityOp +from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules, + get_transformer_layer_offset) +from megatron.core.utils import get_pg_rank from packaging import version +from typing import Optional + +from mcore_bridge.utils import get_logger mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +logger = get_logger() + class CustomTransformerLayer(TransformerLayer): + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: Optional[float] = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + is_mtp_layer: bool = False, + add_layer_offset: bool = True, + pp_layer_offset: Optional[int] = None, + ): + self.submodules_config = submodules + super().__init__(config=config, vp_stage=vp_stage) + + if pg_collection is None: + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + self.pg_collection = pg_collection + self.tp_group = pg_collection.tp + + # MTP inner layers use their own layer numbering (starting from 1 within each MTP depth), + # so they should NOT add the decoder layer offset. The router.py handles MTP layer + # numbering separately by adding config.num_layers to distinguish MTP layers from decoder + # layers in the aux loss tracker. + # + # When add_layer_offset is False, the caller has already included the correct offset + # in layer_number (e.g. when using --hybrid-layer-pattern with fVPP). 
+ if is_mtp_layer or not add_layer_offset: + self.layer_number = layer_number + else: + self.layer_number = layer_number + get_transformer_layer_offset(self.config, vp_stage, + get_pg_rank(pg_collection.pp)) + self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout + self.is_mtp_layer = is_mtp_layer + + # [Module 1: Input Layernorm] Optional Layernorm on the input data + # TODO: add pytorch only layernorm + self.input_layernorm = submodules.input_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + attention_optional_kwargs = {} + if config.context_parallel_size > 1 and config.cp_comm_type is not None: + if isinstance(config.cp_comm_type, list): + # layer_number is 1-indexed, so we need to subtract 1 to get the correct index + attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type[self.layer_number - 1] + else: + attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type + + attention_optional_kwargs['pg_collection'] = pg_collection + if pp_layer_offset is not None: + attention_optional_kwargs['pp_layer_offset'] = pp_layer_offset + + # [Module 2: SelfAttention] + self.self_attention = build_module( + submodules.self_attention, + config=self.config, + layer_number=self.layer_number, + **attention_optional_kwargs, + ) + + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = submodules.pre_cross_attn_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + # [Module 5: CrossAttention] + self.cross_attention = build_module( + submodules.cross_attention, + config=self.config, + layer_number=self.layer_number, + **attention_optional_kwargs, + ) + + # [Module 6: BiasDropoutFusion] + self.cross_attn_bda = build_module(submodules.cross_attn_bda, 
config=self.config) + + # [Module 7: Pre MLP] Optional Layernorm before MLP + self.pre_mlp_layernorm = submodules.pre_mlp_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + # [Module 8: MLP block] + additional_mlp_kwargs = {} + # import here to avoid circular import + from megatron.core.extensions.transformer_engine import TEFusedMLP + from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP + from megatron.core.transformer.moe.moe_layer import MoELayer + + # MLP expects tp_group but MoELayer expects pg_collection to be passed in. + # We can change MLP to accept pg_collection but it makes the logic implicit + # The conditional below is to make the logic explicit + # if submodules.mlp is not a ModuleSpec, we don't have to handle passing additional kwargs + if isinstance(submodules.mlp, ModuleSpec): + if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): + additional_mlp_kwargs['pg_collection'] = pg_collection + # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. + if submodules.mlp.module == MoELayer: + additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer + elif submodules.mlp.module == MLP: + assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP: + assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + else: + logger.warning_once(f"Unknown MLP type: {type(submodules.mlp)}. 
Using default kwargs.") + self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs) + if hasattr(self.mlp, 'set_layer_number'): + self.mlp.set_layer_number(self.layer_number) + + # [Module 9: BiasDropoutFusion] + self.mlp_bda = build_module(submodules.mlp_bda) + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False + if self.config.recompute_granularity == 'selective': + assert self.config.recompute_modules is not None + if 'layernorm' in self.config.recompute_modules: + if not isinstance(self.input_layernorm, IdentityOp): + self.recompute_input_layernorm = True + if self.config.fp8 or self.config.fp4: + self.self_attention.set_for_recompute_input_layernorm() + + def can_recompute_pre_mlp_layernorm_for_cudagraph(): + if (not self.is_moe_layer or CudaGraphScope.moe_router not in self.config.cuda_graph_scope + or self.config.cuda_graph_impl == 'local'): + # Not a MoE layer, or not capturing the router part. + return True + if (self.config.moe_shared_expert_intermediate_size is not None + and self.config.moe_shared_expert_overlap): + # If shared expert overlap is used, we cannot make the pre-mlp layernorm + # recomputation, because the shared expert takes the layernorm output as + # input, and it is outside of the CUDA graph scope. + logger.warning( + 'pre_mlp_layernorm recompute is not supported with moe router ' + 'cudagraph + shared expert overlap. Disabling pre_mlp_layernorm ' + 'recompute.', ) + return False + if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope and ( + self.config.moe_token_dispatcher_type == 'alltoall' or self.config.moe_latent_size): + # Only when capturing the preprocess part and using alltoall token + # dispatcher or latent MoE can we make the pre-mlp layernorm recomputation. 
+ # Because in other cases the layernorm output returns directly as one of the + # outputs of the cudagraph, which will be allocated a static buffer, thus + # not able to be released. + return True + logger.warning( + 'pre_mlp_layernorm recompute is only supported with moe router + ' + 'preprocess cudagraph will alltoall token dispatcher or latent MoE. ' + 'Disabling pre_mlp_layernorm recompute.', ) + return False + + if (not isinstance(self.pre_mlp_layernorm, IdentityOp) + and can_recompute_pre_mlp_layernorm_for_cudagraph()): + self.recompute_pre_mlp_layernorm = True + if self.config.fp8 or self.config.fp4: + if isinstance(self.mlp, MoELayer): + self.mlp.set_for_recompute_pre_mlp_layernorm() + else: + from megatron.core.extensions.transformer_engine import set_save_original_input + + set_save_original_input(self.mlp.linear_fc1) + if 'mlp' in self.config.recompute_modules: + if not self.is_moe_layer: + self.recompute_mlp = True + self.offload_attn_norm = ( + self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp)) + self.offload_mlp_norm = ( + self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp)) + + # @jcasper how should we handle nvfuser? + # Set bias+dropout+add fusion grad_enable execution handler. + # TORCH_MAJOR = int(torch.__version__.split('.')[0]) + # TORCH_MINOR = int(torch.__version__.split('.')[1]) + # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad + self.bias_dropout_add_exec_handler = torch.enable_grad + def forward(self, *args, **kwargs): """ Perform a forward pass through the transformer layer. 
diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 0e67e90..9be5b6f 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -15,7 +15,7 @@ from mcore_bridge.config import ModelConfig from mcore_bridge.utils import get_logger -from .modules import MultiTokenPredictionLayer +from .modules import CustomTransformerLayer, MultiTokenPredictionLayer if TYPE_CHECKING: from .gpt_model import GPTModel @@ -138,6 +138,10 @@ def _set_shared_expert_gate(self, transformer_layer_spec): if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} + def _set_custom_layer(self, transformer_layer_spec): + pass + # CustomTransformerLayer + def build_model( self, pre_process=True, From 196a58fda42452eebe2eabee2f2545f0099e2122 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 13:19:36 +0800 Subject: [PATCH 10/12] update --- src/mcore_bridge/model/gpts/glm4.py | 10 ++-- src/mcore_bridge/model/gpts/minimax_m2.py | 7 +-- src/mcore_bridge/model/mm_gpts/gemma4.py | 4 +- .../model/modules/transformer_layer.py | 45 +++++++++++++----- src/mcore_bridge/model/register.py | 46 ++++++------------- 5 files changed, 58 insertions(+), 54 deletions(-) diff --git a/src/mcore_bridge/model/gpts/glm4.py b/src/mcore_bridge/model/gpts/glm4.py index 5c7a8ca..861f94d 100644 --- a/src/mcore_bridge/model/gpts/glm4.py +++ b/src/mcore_bridge/model/gpts/glm4.py @@ -91,11 +91,11 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo class Glm4Loader(ModelLoader): def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - layer_spec = self._get_transformer_layer_spec() - layer_spec.submodules.self_attention.module = Glm4SelfAttention - layer_spec.submodules.mlp.module = Glm4MLP - transformer_layer.MLP = Glm4MLP # patch - return layer_spec + transformer_layer_spec = 
super().get_transformer_layer_spec(vp_stage) + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = Glm4SelfAttention + layer_spec.submodules.mlp.module = Glm4MLP + return transformer_layer_spec register_model(ModelMeta( diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py index 81b11b6..c03f803 100644 --- a/src/mcore_bridge/model/gpts/minimax_m2.py +++ b/src/mcore_bridge/model/gpts/minimax_m2.py @@ -95,9 +95,10 @@ def _set_moe_state( class MinimaxM2Loader(ModelLoader): def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - layer_spec = self._get_transformer_layer_spec() - layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention - return layer_spec + transformer_layer_spec = super().get_transformer_layer_spec(vp_stage) + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention + return transformer_layer_spec register_model(ModelMeta( diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 8cba526..8bfc0c2 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -56,8 +56,8 @@ def __init__( self.sliding_window = text_config.sliding_window if self.is_sliding else None kv_channels = config.kv_channels config.kv_channels = ( - text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim - ) + text_config.global_head_dim + if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) super().__init__(config, submodules, layer_number, *args, **kwargs) config.kv_channels = kv_channels diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 83c42e3..4fe0a82 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ 
b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,7 +1,8 @@ +import enum +import inspect import megatron.core import torch from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.enums import CudaGraphScope, LayerType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -14,6 +15,22 @@ from mcore_bridge.utils import get_logger +try: + from megatron.core.transformer.enums import CudaGraphScope +except ImportError: + + class CudaGraphScope(enum.Enum): + """Cuda Graph Scope - defines which parts of the model to capture.""" + + full_iteration = 1 # Captures the entire training/inference iteration + attn = 2 # Captures attention layers + mlp = 3 # Captures MLP layers (dense layers only) + moe = 4 # Captures MoE layers (drop-and-pad MoE layers only) + moe_router = 5 # Captures MoE router part + moe_preprocess = 6 # Captures MoE preprocessing part (requires moe_router) + mamba = 7 # Captures Mamba layers + + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') logger = get_logger() @@ -34,7 +51,7 @@ def __init__( pp_layer_offset: Optional[int] = None, ): self.submodules_config = submodules - super().__init__(config=config, vp_stage=vp_stage) + super(TransformerLayer, self).__init__(config=config, vp_stage=vp_stage) if pg_collection is None: pg_collection = ProcessGroupCollection.use_mpu_process_groups() @@ -118,6 +135,9 @@ def __init__( from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP from megatron.core.transformer.moe.moe_layer import MoELayer + from mcore_bridge.model.gpts.glm4 import Glm4MLP + from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP + # MLP expects tp_group but MoELayer expects pg_collection to be passed in. 
# We can change MLP to accept pg_collection but it makes the logic implicit # The conditional below is to make the logic explicit @@ -126,16 +146,18 @@ def __init__( if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): additional_mlp_kwargs['pg_collection'] = pg_collection # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. - if submodules.mlp.module == MoELayer: + if submodules.mlp.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer - elif submodules.mlp.module == MLP: + elif submodules.mlp.module in (MLP, Glm4MLP): assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' additional_mlp_kwargs['tp_group'] = pg_collection.tp + elif submodules.mlp.module == Gemma4MLP: + additional_mlp_kwargs['layer_number'] = layer_number elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP: assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' additional_mlp_kwargs['tp_group'] = pg_collection.tp else: - logger.warning_once(f"Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.") + logger.warning_once(f'Unknown MLP type: {type(submodules.mlp)}. 
Using default kwargs.') self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs) if hasattr(self.mlp, 'set_layer_number'): self.mlp.set_layer_number(self.layer_number) @@ -198,12 +220,13 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): if 'mlp' in self.config.recompute_modules: if not self.is_moe_layer: self.recompute_mlp = True - self.offload_attn_norm = ( - self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules - and not isinstance(self.input_layernorm, IdentityOp)) - self.offload_mlp_norm = ( - self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules - and not isinstance(self.pre_mlp_layernorm, IdentityOp)) + if hasattr(self.config, 'fine_grained_activation_offloading'): + self.offload_attn_norm = ( + self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp)) + self.offload_mlp_norm = ( + self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp)) # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. 
diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 9be5b6f..e8eef7d 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -90,41 +90,20 @@ def _replace_spec_dsa(self, layer_spec): layer_spec.submodules.self_attention = dsa_spec def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - if self.config.num_moe_experts: - transformer_layer_spec = get_gpt_decoder_block_spec( - self.config, - use_transformer_engine=True, - normalization=self.config.normalization, - qk_l2_norm=self.config.qk_l2_norm, - vp_stage=vp_stage) - if self.config.experimental_attention_variant == 'dsa': - for layer_spec in transformer_layer_spec.layer_specs: - self._replace_spec_dsa(layer_spec) - else: - transformer_layer_spec = self._get_transformer_layer_spec() - return transformer_layer_spec - - def _get_transformer_layer_spec(self): - config = self.config - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - config.num_moe_experts, - config.moe_grouped_gemm, - config.qk_layernorm, - config.multi_latent_attention, - qk_l2_norm=config.qk_l2_norm, - ) + transformer_layer_spec = get_gpt_decoder_block_spec( + self.config, + use_transformer_engine=True, + normalization=self.config.normalization, + qk_l2_norm=self.config.qk_l2_norm, + vp_stage=vp_stage) + if self.config.experimental_attention_variant == 'dsa': + for layer_spec in transformer_layer_spec.layer_specs: + self._replace_spec_dsa(layer_spec) return transformer_layer_spec def get_mtp_block_spec(self, transformer_layer_spec, vp_stage: Optional[int] = None): - if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0: - # Get the decoder layer spec explicitly if no decoder layer in the last stage, - # Only happens with block spec (TransformerBlockSubmodules) when using MoE. 
- # TODO: remove - transformer_layer_spec_for_mtp = self._get_transformer_layer_spec() - else: - transformer_layer_spec_for_mtp = transformer_layer_spec mtp_block_spec = get_gpt_mtp_block_spec( - self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, vp_stage=vp_stage) + self.config, transformer_layer_spec, use_transformer_engine=True, vp_stage=vp_stage) if mtp_block_spec is not None: for layer_spec in mtp_block_spec.layer_specs: layer_spec.module = MultiTokenPredictionLayer @@ -139,8 +118,8 @@ def _set_shared_expert_gate(self, transformer_layer_spec): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} def _set_custom_layer(self, transformer_layer_spec): - pass - # CustomTransformerLayer + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.module = CustomTransformerLayer def build_model( self, @@ -150,6 +129,7 @@ def build_model( ) -> Union['GPTModel', 'MultimodalGPTModel']: transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) self._set_shared_expert_gate(transformer_layer_spec) + self._set_custom_layer(transformer_layer_spec) mtp_block_spec = None if self.config.mtp_num_layers is not None: mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) From 68e33a7dc04b3b8a746993802f98b5437ba705ae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 13:31:46 +0800 Subject: [PATCH 11/12] update --- src/mcore_bridge/model/modules/transformer_layer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 4fe0a82..83dcf5e 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,6 +1,5 @@ import enum import inspect -import megatron.core import torch from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.identity_op import 
IdentityOp @@ -10,7 +9,6 @@ from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules, get_transformer_layer_offset) from megatron.core.utils import get_pg_rank -from packaging import version from typing import Optional from mcore_bridge.utils import get_logger @@ -31,8 +29,6 @@ class CudaGraphScope(enum.Enum): mamba = 7 # Captures Mamba layers -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') - logger = get_logger() @@ -243,8 +239,6 @@ def forward(self, *args, **kwargs): This method calls the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations. """ - if not mcore_013: - return super().forward(self, *args, **kwargs) hidden_states, context = self._forward_attention(*args, **kwargs) mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs mask = None From 44ddaec8fec777dbda3b627c5718694d7d1bb9a8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 20:31:15 +0800 Subject: [PATCH 12/12] update --- .../model/modules/transformer_layer.py | 22 ++++++++++++++++++- src/mcore_bridge/model/register.py | 4 +--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 83dcf5e..11c940b 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -2,6 +2,8 @@ import inspect import torch from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, + scatter_to_sequence_parallel_region) from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -242,12 +244,30 @@ def forward(self, *args, 
**kwargs): hidden_states, context = self._forward_attention(*args, **kwargs) mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs mask = None + enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1 + pad_size = 0 if mlp_padding_free and hidden_states.shape[1] > 1: + if enable_sp: + hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False) mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() hidden_states = hidden_states[mask][:, None] + if enable_sp: + tp_size = self.config.tensor_model_parallel_size + num_tokens = hidden_states.shape[0] + remainder = num_tokens % tp_size + if remainder != 0: + pad_size = tp_size - remainder + hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, 0, 0, pad_size)) + hidden_states = scatter_to_sequence_parallel_region(hidden_states) output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None)) if mask is not None: - new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1])) + if enable_sp: + output = gather_from_sequence_parallel_region(output, tensor_parallel_output_grad=False) + if pad_size > 0: + output = output[:-pad_size] + new_output = output.new_zeros((*mask.shape, output.shape[-1])) new_output[mask] = output.squeeze(1) output = new_output + if enable_sp: + output = scatter_to_sequence_parallel_region(output) return output, context diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index e8eef7d..15b37fe 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -4,9 +4,7 @@ from megatron.core import mpu from megatron.core.enums import ModelType from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear -from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, - get_gpt_layer_with_transformer_engine_spec, - 
get_gpt_mtp_block_spec) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec from packaging import version from torch import nn from typing import TYPE_CHECKING, List, Optional, Type, Union