Merged
16 changes: 13 additions & 3 deletions src/mcore_bridge/model/gpts/qwen3_next.py
@@ -17,6 +17,7 @@
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.utils import deprecate_inference_params, is_fa_min_version
 from packaging import version
+from transformers.utils import is_torch_npu_available
 from typing import Optional, Tuple, Union
 
 from mcore_bridge.bridge import GPTBridge
@@ -58,6 +59,17 @@
 logger = get_logger()
 
 
+def resolve_gdn_attention_mask(kwargs) -> Optional[torch.Tensor]:
+    if is_torch_npu_available():
+        attention_mask = kwargs.get('attention_mask_2d')
+        if attention_mask is not None:
+            return attention_mask.to(torch.bool)
+    attention_mask = kwargs.get('attention_mask')
+    if attention_mask is None:
+        return None
+    return (~attention_mask).sum(dim=(1, 2)) > 0
+
+
 class Qwen3NextRMSNorm(torch.nn.Module):
     """
     Zero-Centered RMSNorm for Qwen3-Next.
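
For context on what the new helper computes (this note and the sketch below are editorial, not part of the diff): on NPU it casts an explicit 2D mask to bool, while elsewhere it collapses a 4D boolean attention mask into a per-token validity mask. Assuming the Megatron-style convention where True marks positions that must not be attended, the reduction flags every key position that at least one query can attend, which recovers the padding mask. The shapes and mask construction below are illustrative assumptions:

```python
import torch

seq = 4
lengths = torch.tensor([4, 2])                         # second sample ends in 2 padding tokens
valid = torch.arange(seq)[None, :] < lengths[:, None]  # [batch, seq], True = real token

# Hypothetical [batch, 1, seq, seq] mask: True marks pairs that must NOT attend
# (key is padding, or key lies in the query's causal future).
causal = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
mask_4d = (~valid)[:, None, None, :] | causal[None, None, :, :]

# The reduction from the diff: keep key positions attended by at least one query.
per_token = (~mask_4d).sum(dim=(1, 2)) > 0
assert torch.equal(per_token, valid)
print(per_token)  # tensor([[True, True, True, True], [True, True, False, False]])
```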
@@ -485,9 +497,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
             hidden_states = new_hidden_states
         else:
             hidden_states = hidden_states.transpose(0, 1)
-        attention_mask = kwargs.get('attention_mask')
-        if attention_mask is not None:
-            attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0
+        attention_mask = resolve_hf_attention_mask(kwargs)
Review comment (severity: high)

The function call resolve_hf_attention_mask(kwargs) appears to be a typo. The function defined earlier in this file (at line 62) is named resolve_gdn_attention_mask. Using the incorrect name will result in a NameError at runtime.

Suggested change:
-        attention_mask = resolve_hf_attention_mask(kwargs)
+        attention_mask = resolve_gdn_attention_mask(kwargs)

Copilot AI (Apr 23, 2026):

resolve_hf_attention_mask is referenced here but is not defined anywhere in this module (or elsewhere in the repo). This will raise a NameError at runtime. Use the newly introduced resolve_gdn_attention_mask(kwargs) here (or add/rename the helper consistently if a different resolver is intended).

Suggested change:
-        attention_mask = resolve_hf_attention_mask(kwargs)
+        attention_mask = resolve_gdn_attention_mask(kwargs)

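Both reviewers flag the same failure mode. A minimal standalone illustration (editorial, not part of the PR; `forward_stub` is a hypothetical stand-in for the buggy method) of why the typo is only caught at runtime:

```python
# The typo compiles fine because Python resolves names at call time,
# so the NameError only appears once this branch actually executes.
def forward_stub(kwargs):
    return resolve_hf_attention_mask(kwargs)  # undefined anywhere in the module

try:
    forward_stub({})
except NameError as exc:
    print(exc)  # name 'resolve_hf_attention_mask' is not defined
```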
         res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
         if thd_format:
             res = res[attention_mask][:, None]
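One more editorial note on the context lines above: when `thd_format` is set, the per-token mask returned by the helper is reused to repack the padded output into a packed (THD-style) layout. A minimal sketch, assuming a [batch, seq, hidden] layout for illustration (the bridge's real tensors may be sequence-first):

```python
import torch

batch, seq, hidden = 2, 4, 8
res = torch.randn(batch, seq, hidden)  # padded output (assumed layout)
attention_mask = torch.tensor([[True, True, True, True],
                               [True, True, False, False]])  # per-token validity

# Boolean indexing drops padding tokens; [:, None] restores a singleton dim.
packed = res[attention_mask][:, None]
print(packed.shape)  # torch.Size([6, 1, 8]): 6 real tokens in total
```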
6 changes: 2 additions & 4 deletions src/mcore_bridge/model/mm_gpts/qwen3_5.py
@@ -10,7 +10,7 @@
 from mcore_bridge.utils import get_env_args
 
 from ..constant import ModelType
-from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader
+from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader, resolve_gdn_attention_mask
 from ..register import ModelMeta, register_model
 from .utils import HuggingFaceVit
@@ -52,9 +52,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
             hidden_states = new_hidden_states
         else:
             hidden_states = hidden_states.transpose(0, 1)
-        attention_mask = kwargs.get('attention_mask')
-        if attention_mask is not None:
-            attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0
+        attention_mask = resolve_gdn_attention_mask(kwargs)
         res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
         if thd_format:
             res = res[attention_mask][:, None]