[megatron] support mlp_padding_free & sp; refactor TransformerLayer #62

Merged
Jintao-Huang merged 3 commits into modelscope:main from Jintao-Huang:mlp_padding_free_support_sp
May 5, 2026
Conversation

@Jintao-Huang (Collaborator)

No description provided.

@Jintao-Huang changed the title from "[megatron] support mlp_padding_free & sp" to "[megatron] support mlp_padding_free & sp; refactor TransformerLayer" on May 5, 2026

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a CustomTransformerLayer to centralize transformer logic and replaces previous monkey-patching, while also refactoring model loaders and registration for increased flexibility. Review feedback identifies a potential TypeError in the CustomTransformerLayer constructor and suggests more robust attention_mask handling in the forward method to account for positional arguments. Additionally, improvements were recommended for layer numbering consistency in specific MLP modules and for the accuracy of warning logs.
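To make the registration change concrete, here is a minimal sketch of how a custom layer class can be injected through Megatron-Core's layer-spec mechanism instead of monkey-patching. The import path for CustomTransformerLayer is inferred from the file paths in this PR and the spec function is the stock Megatron-Core one, so treat this as an illustration rather than the PR's actual registration code.

```python
# Minimal sketch (not the PR's code): swap a custom TransformerLayer subclass
# into a Megatron-Core layer spec instead of monkey-patching the upstream class.
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec

# Hypothetical import path, inferred from src/mcore_bridge/model/modules/transformer_layer.py.
from mcore_bridge.model.modules.transformer_layer import CustomTransformerLayer

layer_spec = get_gpt_layer_local_spec()       # ModuleSpec(module=TransformerLayer, submodules=...)
layer_spec.module = CustomTransformerLayer    # replace the layer class; submodules stay unchanged
# The modified spec is then passed to the GPT model / TransformerBlock constructor as usual.
```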

Comment thread src/mcore_bridge/model/modules/transformer_layer.py
Comment on lines +244 to +252
hidden_states, context = self._forward_attention(*args, **kwargs)
mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs
mask = None
enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
pad_size = 0
if mlp_padding_free and hidden_states.shape[1] > 1:
    if enable_sp:
        hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
    mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t()

Severity: high

The forward method assumes attention_mask is always passed as a keyword argument. In Megatron-Core's TransformerBlock, layers are typically called with attention_mask as the second positional argument. This means kwargs.get('attention_mask') will be None, effectively disabling mlp_padding_free or causing a KeyError at line 252. Additionally, using the bitwise NOT operator ~ assumes a boolean mask; consider making this more robust for float masks.

Suggested change
- hidden_states, context = self._forward_attention(*args, **kwargs)
- mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs
- mask = None
- enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
- pad_size = 0
- if mlp_padding_free and hidden_states.shape[1] > 1:
-     if enable_sp:
-         hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
-     mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t()
+ hidden_states, context = self._forward_attention(*args, **kwargs)
+ attention_mask = kwargs.get('attention_mask', args[1] if len(args) > 1 else None)
+ mlp_padding_free = self.config.mlp_padding_free and attention_mask is not None
+ mask = None
+ enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
+ pad_size = 0
+ if mlp_padding_free and hidden_states.shape[1] > 1:
+     if enable_sp:
+         hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
+     mask = ((~attention_mask).sum(dim=(1, 2)) > 0).t()

Two further comment threads on src/mcore_bridge/model/modules/transformer_layer.py are marked as outdated.
@Jintao-Huang (Collaborator, Author)

/gemini review


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a CustomTransformerLayer to replace the previous monkey-patching of Megatron-Core's TransformerLayer, providing a more structured way to handle custom forward logic such as padding-free MLPs. It updates model loaders for GLM4 and MinimaxM2, refactors the model registration process to inject this custom layer, and moves utility functions like patch_deepcopy to a central location. However, the review identified several high-severity issues in the new CustomTransformerLayer: the constructor fragilely bypasses TransformerLayer's initialization, the sequence parallel implementation for padding-free logic risks OOM by gathering full activations on all ranks, and the attention mask handling assumes a boolean type which may cause runtime errors with float masks. Additionally, hardcoding the gradient execution handler may bypass performance optimizations.
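To illustrate the padding-free MLP idea being reviewed here, the sketch below shows only the pack/unpack step in plain PyTorch; the helper name, shapes, and mask convention are assumptions, not the PR's implementation. In the sequence-parallel path the PR additionally gathers the full sequence with gather_from_sequence_parallel_region before packing (as in the snippets above) and scatters it back afterwards, which is the source of the OOM concern raised in this review.

```python
import torch

def mlp_without_padding(hidden_states: torch.Tensor, mlp, keep_mask: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch (hypothetical helper, not the PR's code).

    hidden_states: [seq, batch, hidden] activations (Megatron layout)
    keep_mask:     [seq, batch] bool, True for real (non-padding) tokens
    mlp:           any callable mapping [n, hidden] -> [n, hidden]
    """
    s, b, h = hidden_states.shape
    flat = hidden_states.reshape(s * b, h)
    idx = keep_mask.reshape(-1).nonzero(as_tuple=True)[0]
    packed = flat.index_select(0, idx)   # only real tokens go through the MLP
    out = torch.zeros_like(flat)
    out[idx] = mlp(packed)               # padded positions are restored as zeros
    return out.reshape(s, b, h)

# Example: 2 sequences of length 6, the second one padded after 3 tokens.
hidden = torch.randn(6, 2, 8)
keep = torch.ones(6, 2, dtype=torch.bool)
keep[3:, 1] = False
out = mlp_without_padding(hidden, torch.nn.Linear(8, 8), keep)
```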

Comment thread src/mcore_bridge/model/modules/transformer_layer.py
if mlp_padding_free and hidden_states.shape[1] > 1:
    if enable_sp:
        hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
    mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t()

Severity: high

The expression (~kwargs['attention_mask']) assumes that attention_mask is a boolean tensor. In many Megatron and HuggingFace configurations, attention_mask is provided as a float tensor (e.g., 0.0 for valid tokens and a large negative value for masked ones). Applying the bitwise NOT operator ~ to a float tensor will raise a TypeError. You should ensure the mask is boolean or use a comparison (e.g., kwargs['attention_mask'] == 0) to identify valid tokens.
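A small helper along the lines suggested above could look like the following sketch; the helper name and the convention that True marks masked positions are assumptions for illustration, not part of the PR.

```python
import torch

def as_bool_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: normalize an attention mask to bool, True = masked position.

    Boolean Megatron-style masks already use True for masked positions; additive
    float masks use 0.0 for valid tokens and a large negative value for masked ones.
    """
    if attention_mask.dtype == torch.bool:
        return attention_mask
    return attention_mask != 0  # non-zero entries of an additive mask are masked positions

# The `~` expression from the snippet would then become, e.g.:
# mask = ((~as_bool_mask(attention_mask)).sum(dim=(1, 2)) > 0).t()
```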

# TORCH_MINOR = int(torch.__version__.split('.')[1])
# use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
# self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
self.bias_dropout_add_exec_handler = torch.enable_grad

Severity: medium

self.bias_dropout_add_exec_handler is hardcoded to torch.enable_grad. In the original Megatron-Core implementation, this is typically conditional on the availability of nvfuser (using nullcontext if available). Hardcoding it may bypass performance optimizations or lead to unnecessary gradient tracking in certain fusion scenarios.
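For reference, a version-gated selection along the lines of the commented-out upstream code would look roughly like the sketch below; TORCH_MAJOR/TORCH_MINOR are derived here since the commented-out lines omit the first of them, and this is an illustration of the reviewer's point rather than the PR's code.

```python
from contextlib import nullcontext

import torch

# Sketch of the upstream-style, version-gated choice of the exec handler:
# nvfuser is available on PyTorch >= 1.10, in which case a nullcontext suffices.
TORCH_MAJOR, TORCH_MINOR = (int(v) for v in torch.__version__.split('.')[:2])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
```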

@Jintao-Huang merged commit 042439c into modelscope:main on May 5, 2026
1 check passed