10 changes: 5 additions & 5 deletions src/mcore_bridge/model/gpts/glm4.py
@@ -91,11 +91,11 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
class Glm4Loader(ModelLoader):

    def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
-        layer_spec = self._get_transformer_layer_spec()
-        layer_spec.submodules.self_attention.module = Glm4SelfAttention
-        layer_spec.submodules.mlp.module = Glm4MLP
-        transformer_layer.MLP = Glm4MLP # patch
-        return layer_spec
+        transformer_layer_spec = super().get_transformer_layer_spec(vp_stage)
+        for layer_spec in transformer_layer_spec.layer_specs:
+            layer_spec.submodules.self_attention.module = Glm4SelfAttention
+            layer_spec.submodules.mlp.module = Glm4MLP
+        return transformer_layer_spec


register_model(ModelMeta(
7 changes: 4 additions & 3 deletions src/mcore_bridge/model/gpts/minimax_m2.py
@@ -95,9 +95,10 @@ def _set_moe_state(
class MinimaxM2Loader(ModelLoader):

    def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
-        layer_spec = self._get_transformer_layer_spec()
-        layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention
-        return layer_spec
+        transformer_layer_spec = super().get_transformer_layer_spec(vp_stage)
+        for layer_spec in transformer_layer_spec.layer_specs:
+            layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention
+        return transformer_layer_spec


register_model(ModelMeta(
5 changes: 3 additions & 2 deletions src/mcore_bridge/model/mm_gpt_model.py
@@ -1,5 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import megatron.core
import torch
from contextlib import contextmanager
from megatron.core import InferenceParams
@@ -15,6 +14,7 @@


class MultimodalGPTModel(MegatronModule):
+    language_model_cls = GPTModel

    def __init__(self,
                 config: ModelConfig,
@@ -26,7 +26,8 @@ def __init__(self,
        super().__init__(config)
        self.pre_process = pre_process
        self.post_process = post_process
-        self.language_model = GPTModel(config, transformer_layer_spec, pre_process, post_process, *_args, **kwargs)
+        self.language_model = self.language_model_cls(config, transformer_layer_spec, pre_process, post_process, *_args,
+                                                      **kwargs)
        self.vp_stage = self.language_model.vp_stage
        self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights
        self.model_meta = config.model_meta
1 change: 1 addition & 0 deletions src/mcore_bridge/model/modules/__init__.py
@@ -2,3 +2,4 @@
from .gated_delta_net import GatedDeltaNet
from .gated_self_attention import GatedSelfAttention
from .mtp_layer import MultiTokenPredictionLayer
+from .transformer_layer import CustomTransformerLayer
271 changes: 271 additions & 0 deletions src/mcore_bridge/model/modules/transformer_layer.py
@@ -0,0 +1,271 @@
import enum
import inspect
import torch
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
scatter_to_sequence_parallel_region)
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules,
get_transformer_layer_offset)
from megatron.core.utils import get_pg_rank
from typing import Optional

from mcore_bridge.utils import get_logger

try:
from megatron.core.transformer.enums import CudaGraphScope
except ImportError:

class CudaGraphScope(enum.Enum):
"""Cuda Graph Scope - defines which parts of the model to capture."""

full_iteration = 1 # Captures the entire training/inference iteration
attn = 2 # Captures attention layers
mlp = 3 # Captures MLP layers (dense layers only)
moe = 4 # Captures MoE layers (drop-and-pad MoE layers only)
moe_router = 5 # Captures MoE router part
moe_preprocess = 6 # Captures MoE preprocessing part (requires moe_router)
mamba = 7 # Captures Mamba layers


logger = get_logger()


class CustomTransformerLayer(TransformerLayer):

def __init__(
self,
config: TransformerConfig,
submodules: TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: Optional[float] = None,
pg_collection: Optional[ProcessGroupCollection] = None,
vp_stage: Optional[int] = None,
is_mtp_layer: bool = False,
add_layer_offset: bool = True,
pp_layer_offset: Optional[int] = None,
):
self.submodules_config = submodules
super(TransformerLayer, self).__init__(config=config, vp_stage=vp_stage)

if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
self.pg_collection = pg_collection
self.tp_group = pg_collection.tp

# MTP inner layers use their own layer numbering (starting from 1 within each MTP depth),
# so they should NOT add the decoder layer offset. The router.py handles MTP layer
# numbering separately by adding config.num_layers to distinguish MTP layers from decoder
# layers in the aux loss tracker.
#
# When add_layer_offset is False, the caller has already included the correct offset
# in layer_number (e.g. when using --hybrid-layer-pattern with fVPP).
if is_mtp_layer or not add_layer_offset:
self.layer_number = layer_number
else:
self.layer_number = layer_number + get_transformer_layer_offset(self.config, vp_stage,
get_pg_rank(pg_collection.pp))
self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
self.is_mtp_layer = is_mtp_layer

# [Module 1: Input Layernorm] Optional Layernorm on the input data
# TODO: add pytorch only layernorm
self.input_layernorm = submodules.input_layernorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)

attention_optional_kwargs = {}
if config.context_parallel_size > 1 and config.cp_comm_type is not None:
if isinstance(config.cp_comm_type, list):
# layer_number is 1-indexed, so we need to subtract 1 to get the correct index
attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type[self.layer_number - 1]
else:
attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type

attention_optional_kwargs['pg_collection'] = pg_collection
if pp_layer_offset is not None:
attention_optional_kwargs['pp_layer_offset'] = pp_layer_offset

# [Module 2: SelfAttention]
self.self_attention = build_module(
submodules.self_attention,
config=self.config,
layer_number=self.layer_number,
**attention_optional_kwargs,
)

# [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)

# [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = submodules.pre_cross_attn_layernorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)

# [Module 5: CrossAttention]
self.cross_attention = build_module(
submodules.cross_attention,
config=self.config,
layer_number=self.layer_number,
**attention_optional_kwargs,
)

# [Module 6: BiasDropoutFusion]
self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config)

# [Module 7: Pre MLP] Optional Layernorm before MLP
self.pre_mlp_layernorm = submodules.pre_mlp_layernorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
# [Module 8: MLP block]
additional_mlp_kwargs = {}
# import here to avoid circular import
from megatron.core.extensions.transformer_engine import TEFusedMLP
from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP
from megatron.core.transformer.moe.moe_layer import MoELayer

from mcore_bridge.model.gpts.glm4 import Glm4MLP

# MLP expects tp_group but MoELayer expects pg_collection to be passed in.
# We could change MLP to accept pg_collection, but that would make the logic implicit.
# The conditional below keeps the logic explicit.
# If submodules.mlp is not a ModuleSpec, we don't have to handle passing additional kwargs.
if isinstance(submodules.mlp, ModuleSpec):
if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP):
additional_mlp_kwargs['pg_collection'] = pg_collection
# Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers.
if submodules.mlp.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters:
additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer
elif submodules.mlp.module in (MLP, Glm4MLP):
assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer'
additional_mlp_kwargs['tp_group'] = pg_collection.tp
elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP:
assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer'
additional_mlp_kwargs['tp_group'] = pg_collection.tp
else:
logger.warning_once(f'Unknown MLP type: {submodules.mlp.module}. Using default kwargs.')
self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs)
if hasattr(self.mlp, 'set_layer_number'):
self.mlp.set_layer_number(self.layer_number)

# [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)

self.is_moe_layer = isinstance(self.mlp, MoELayer)

self.recompute_input_layernorm = False
self.recompute_pre_mlp_layernorm = False
self.recompute_mlp = False
if self.config.recompute_granularity == 'selective':
assert self.config.recompute_modules is not None
if 'layernorm' in self.config.recompute_modules:
if not isinstance(self.input_layernorm, IdentityOp):
self.recompute_input_layernorm = True
if self.config.fp8 or self.config.fp4:
self.self_attention.set_for_recompute_input_layernorm()

def can_recompute_pre_mlp_layernorm_for_cudagraph():
if (not self.is_moe_layer or CudaGraphScope.moe_router not in self.config.cuda_graph_scope
or self.config.cuda_graph_impl == 'local'):
# Not a MoE layer, or not capturing the router part.
return True
if (self.config.moe_shared_expert_intermediate_size is not None
and self.config.moe_shared_expert_overlap):
# If shared expert overlap is used, we cannot recompute the pre-mlp layernorm,
# because the shared expert takes the layernorm output as input, and it is
# outside of the CUDA graph scope.
logger.warning(
'pre_mlp_layernorm recompute is not supported with moe router '
'cudagraph + shared expert overlap. Disabling pre_mlp_layernorm '
'recompute.', )
return False
if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope and (
self.config.moe_token_dispatcher_type == 'alltoall' or self.config.moe_latent_size):
# Only when capturing the preprocess part and using the alltoall token
# dispatcher or latent MoE can we recompute the pre-mlp layernorm.
# In other cases the layernorm output is returned directly as one of the
# outputs of the cudagraph, which is allocated a static buffer and thus
# cannot be released.
return True
logger.warning(
'pre_mlp_layernorm recompute is only supported with moe router + '
'preprocess cudagraph with alltoall token dispatcher or latent MoE. '
'Disabling pre_mlp_layernorm recompute.', )
return False

if (not isinstance(self.pre_mlp_layernorm, IdentityOp)
and can_recompute_pre_mlp_layernorm_for_cudagraph()):
self.recompute_pre_mlp_layernorm = True
if self.config.fp8 or self.config.fp4:
if isinstance(self.mlp, MoELayer):
self.mlp.set_for_recompute_pre_mlp_layernorm()
else:
from megatron.core.extensions.transformer_engine import set_save_original_input

set_save_original_input(self.mlp.linear_fc1)
if 'mlp' in self.config.recompute_modules:
if not self.is_moe_layer:
self.recompute_mlp = True
if hasattr(self.config, 'fine_grained_activation_offloading'):
self.offload_attn_norm = (
self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules
and not isinstance(self.input_layernorm, IdentityOp))
self.offload_mlp_norm = (
self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules
and not isinstance(self.pre_mlp_layernorm, IdentityOp))

# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
# TORCH_MAJOR = int(torch.__version__.split('.')[0])
# TORCH_MINOR = int(torch.__version__.split('.')[1])
# use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
# self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
self.bias_dropout_add_exec_handler = torch.enable_grad
Review comment (medium): self.bias_dropout_add_exec_handler is hardcoded to torch.enable_grad. In the original Megatron-Core implementation, this is typically conditional on the availability of nvfuser (using nullcontext if available). Hardcoding it may bypass performance optimizations or lead to unnecessary gradient tracking in certain fusion scenarios.
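For reference, a minimal sketch of the version-gated selection described by the commented-out lines above; select_bias_dropout_add_exec_handler is a hypothetical helper and not necessarily how this PR intends to resolve the comment:

from contextlib import nullcontext

import torch


def select_bias_dropout_add_exec_handler():
    # Hypothetical helper mirroring the commented-out logic above: nvfuser-backed
    # bias+dropout+add fusion is usable on torch >= 1.10, in which case a plain
    # nullcontext suffices; otherwise fall back to torch.enable_grad.
    torch_major, torch_minor = (int(v) for v in torch.__version__.split('.')[:2])
    use_nvfuser = torch_major > 1 or (torch_major == 1 and torch_minor >= 10)
    return nullcontext if use_nvfuser else torch.enable_grad


# Inside CustomTransformerLayer.__init__ this would replace the hardcoded assignment:
# self.bias_dropout_add_exec_handler = select_bias_dropout_add_exec_handler()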


def forward(self, *args, **kwargs):
"""
Perform a forward pass through the transformer layer.

This method calls the core computation of a transformer layer, including
self-attention, cross-attention (if applicable), and feed-forward operations.
"""
hidden_states, context = self._forward_attention(*args, **kwargs)
# If padding_free is set, attention_mask does not exist.
mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs
mask = None
enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
pad_size = 0
if mlp_padding_free and hidden_states.shape[1] > 1:
if enable_sp:
hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t()
Comment on lines +241 to +250
Review comment (high): The forward method assumes attention_mask is always passed as a keyword argument. In Megatron-Core's TransformerBlock, layers are typically called with attention_mask as the second positional argument. This means kwargs.get('attention_mask') will be None, effectively disabling mlp_padding_free or causing a KeyError at line 252. Additionally, using the bitwise NOT operator ~ assumes a boolean mask; consider making this more robust for float masks.

Suggested change
-hidden_states, context = self._forward_attention(*args, **kwargs)
-mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs
-mask = None
-enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
-pad_size = 0
-if mlp_padding_free and hidden_states.shape[1] > 1:
-    if enable_sp:
-        hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
-    mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t()
+hidden_states, context = self._forward_attention(*args, **kwargs)
+attention_mask = kwargs.get('attention_mask', args[1] if len(args) > 1 else None)
+mlp_padding_free = self.config.mlp_padding_free and attention_mask is not None
+mask = None
+enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
+pad_size = 0
+if mlp_padding_free and hidden_states.shape[1] > 1:
+    if enable_sp:
+        hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
+    mask = ((~attention_mask).sum(dim=(1, 2)) > 0).t()

Review comment (high): The expression (~kwargs['attention_mask']) assumes that attention_mask is a boolean tensor. In many Megatron and HuggingFace configurations, attention_mask is provided as a float tensor (e.g., 0.0 for valid tokens and a large negative value for masked ones). Applying the bitwise NOT operator ~ to a float tensor will raise a TypeError. You should ensure the mask is boolean or use a comparison (e.g., kwargs['attention_mask'] == 0) to identify valid tokens.
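A minimal sketch of a dtype-robust variant, assuming the Megatron convention that the mask has shape [batch, 1, seq_q, seq_k] and that True (bool) or a nonzero entry (float/int, e.g. a large negative additive bias) marks positions to mask out; compute_valid_token_mask is a hypothetical helper, not part of this PR:

import torch


def compute_valid_token_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Normalize float/int masks to boolean: any nonzero entry means "masked out".
    if attention_mask.dtype != torch.bool:
        attention_mask = attention_mask != 0
    # A key position is a real (non-padding) token if at least one query position
    # is allowed to attend to it.
    valid = (~attention_mask).sum(dim=(1, 2)) > 0  # [batch, seq_k]
    # Transpose to [seq_k, batch] to match Megatron's sequence-first hidden_states.
    return valid.t()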

hidden_states = hidden_states[mask][:, None]
if enable_sp:
tp_size = self.config.tensor_model_parallel_size
num_tokens = hidden_states.shape[0]
remainder = num_tokens % tp_size
if remainder != 0:
pad_size = tp_size - remainder
hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, 0, 0, pad_size))
hidden_states = scatter_to_sequence_parallel_region(hidden_states)
output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None))
if mask is not None:
if enable_sp:
output = gather_from_sequence_parallel_region(output, tensor_parallel_output_grad=False)
if pad_size > 0:
output = output[:-pad_size]
new_output = output.new_zeros((*mask.shape, output.shape[-1]))
new_output[mask] = output.squeeze(1)
output = new_output
if enable_sp:
output = scatter_to_sequence_parallel_region(output)
return output, context