diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py
index ec87ddb..0481e88 100644
--- a/src/mcore_bridge/model/gpts/qwen3_next.py
+++ b/src/mcore_bridge/model/gpts/qwen3_next.py
@@ -17,6 +17,7 @@
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.utils import deprecate_inference_params, is_fa_min_version
 from packaging import version
+from transformers.utils import is_torch_npu_available
 from typing import Optional, Tuple, Union
 
 from mcore_bridge.bridge import GPTBridge
@@ -58,6 +59,17 @@
 logger = get_logger()
 
 
+def resolve_gdn_attention_mask(kwargs) -> Optional[torch.Tensor]:
+    if is_torch_npu_available():
+        attention_mask = kwargs.get('attention_mask_2d')
+        if attention_mask is not None:
+            return attention_mask.to(torch.bool)
+    attention_mask = kwargs.get('attention_mask')
+    if attention_mask is None:
+        return None
+    return (~attention_mask).sum(dim=(1, 2)) > 0
+
+
 class Qwen3NextRMSNorm(torch.nn.Module):
     """
     Zero-Centered RMSNorm for Qwen3-Next.
@@ -485,9 +497,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
             hidden_states = new_hidden_states
         else:
             hidden_states = hidden_states.transpose(0, 1)
-        attention_mask = kwargs.get('attention_mask')
-        if attention_mask is not None:
-            attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0
+        attention_mask = resolve_gdn_attention_mask(kwargs)
         res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
         if thd_format:
             res = res[attention_mask][:, None]
diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_5.py b/src/mcore_bridge/model/mm_gpts/qwen3_5.py
index 91472c3..801c55c 100644
--- a/src/mcore_bridge/model/mm_gpts/qwen3_5.py
+++ b/src/mcore_bridge/model/mm_gpts/qwen3_5.py
@@ -10,7 +10,7 @@
 from mcore_bridge.utils import get_env_args
 
 from ..constant import ModelType
-from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader
+from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader, resolve_gdn_attention_mask
 from ..register import ModelMeta, register_model
 from .utils import HuggingFaceVit
 
@@ -52,9 +52,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
             hidden_states = new_hidden_states
         else:
             hidden_states = hidden_states.transpose(0, 1)
-        attention_mask = kwargs.get('attention_mask')
-        if attention_mask is not None:
-            attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0
+        attention_mask = resolve_gdn_attention_mask(kwargs)
         res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
         if thd_format:
             res = res[attention_mask][:, None]
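Reviewer note (illustration only, not part of the patch): the snippet below is a minimal sketch of what the non-NPU branch of resolve_gdn_attention_mask computes, assuming the layer receives a Megatron-style [batch, 1, seq, seq] boolean mask where True marks a masked position; the sample sizes, padding layout, and mask construction are made-up demo values, not taken from the repository.

import torch

# Hypothetical inputs: batch of 2 sequences of length 4, the second one with
# its last two tokens padded.
batch, seq = 2, 4
valid_lens = torch.tensor([4, 2])
positions = torch.arange(seq)
padding = positions[None, :] >= valid_lens[:, None]    # [batch, seq], True = pad token

# Assumed 4D mask as the GDN layer would receive it: a (query, key) pair is
# masked if either token is padding or the pair violates causality.
causal = torch.triu(torch.ones(seq, seq), diagonal=1).bool()
mask_4d = causal[None, None] | padding[:, None, None, :] | padding[:, None, :, None]

# The helper's reduction: count allowed positions over the head and query dims,
# so a token stays True only if at least one query can attend to it, i.e. it is
# a real (non-padding) token. Result shape: [batch, seq].
token_valid = (~mask_4d).sum(dim=(1, 2)) > 0
print(token_valid)
# tensor([[ True,  True,  True,  True],
#         [ True,  True, False, False]])

On Ascend NPU the helper skips this reduction and instead returns the precomputed 2D mask passed in kwargs['attention_mask_2d'] (cast to bool), which is why both forward wrappers now delegate to the shared function instead of repeating the reduction inline.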