From ae98b8a21ffeb9d42cae0b29dc52a9ad4028c0df Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Mon, 20 Apr 2026 22:19:04 +0800 Subject: [PATCH 1/5] qwen3_5_npu --- src/mcore_bridge/model/gpts/qwen3_next.py | 18 +++++++++++++++--- src/mcore_bridge/model/mm_gpts/qwen3_5.py | 6 ++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py index ec87ddb..2844b6b 100644 --- a/src/mcore_bridge/model/gpts/qwen3_next.py +++ b/src/mcore_bridge/model/gpts/qwen3_next.py @@ -18,6 +18,7 @@ from megatron.core.utils import deprecate_inference_params, is_fa_min_version from packaging import version from typing import Optional, Tuple, Union +from transformers.utils import is_torch_npu_available from mcore_bridge.bridge import GPTBridge from mcore_bridge.config import ModelConfig @@ -58,6 +59,19 @@ logger = get_logger() +def _get_hf_attention_mask(kwargs) -> Optional[torch.Tensor]: + if is_torch_npu_available(): + attention_mask = kwargs.get('attention_mask_2d') + if attention_mask is not None: + return attention_mask.to(torch.bool) + attention_mask = kwargs.get('attention_mask') + if attention_mask is None: + return None + if attention_mask.ndim == 4: + return (~attention_mask).sum(dim=(1, 2)) > 0 + return attention_mask.to(torch.bool) + + class Qwen3NextRMSNorm(torch.nn.Module): """ Zero-Centered RMSNorm for Qwen3-Next. @@ -485,9 +499,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): hidden_states = new_hidden_states else: hidden_states = hidden_states.transpose(0, 1) - attention_mask = kwargs.get('attention_mask') - if attention_mask is not None: - attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0 + attention_mask = _get_hf_attention_mask(kwargs) res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask) if thd_format: res = res[attention_mask][:, None] diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_5.py b/src/mcore_bridge/model/mm_gpts/qwen3_5.py index 91472c3..9fc212d 100644 --- a/src/mcore_bridge/model/mm_gpts/qwen3_5.py +++ b/src/mcore_bridge/model/mm_gpts/qwen3_5.py @@ -10,7 +10,7 @@ from mcore_bridge.utils import get_env_args from ..constant import ModelType -from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader +from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader, _get_hf_attention_mask from ..register import ModelMeta, register_model from .utils import HuggingFaceVit @@ -52,9 +52,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): hidden_states = new_hidden_states else: hidden_states = hidden_states.transpose(0, 1) - attention_mask = kwargs.get('attention_mask') - if attention_mask is not None: - attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0 + attention_mask = _get_hf_attention_mask(kwargs) res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask) if thd_format: res = res[attention_mask][:, None] From 7c034c3151f00d3cfebad288262a1de82de4d4fa Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 23 Apr 2026 15:12:26 +0800 Subject: [PATCH 2/5] update --- src/mcore_bridge/model/gpts/qwen3_next.py | 4 ++-- src/mcore_bridge/model/mm_gpts/qwen3_5.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py index 2844b6b..2205b03 100644 --- a/src/mcore_bridge/model/gpts/qwen3_next.py +++ b/src/mcore_bridge/model/gpts/qwen3_next.py @@ -59,7 +59,7 @@ logger = get_logger() -def _get_hf_attention_mask(kwargs) -> Optional[torch.Tensor]: +def resolve_hf_attention_mask(kwargs) -> Optional[torch.Tensor]: if is_torch_npu_available(): attention_mask = kwargs.get('attention_mask_2d') if attention_mask is not None: @@ -499,7 +499,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): hidden_states = new_hidden_states else: hidden_states = hidden_states.transpose(0, 1) - attention_mask = _get_hf_attention_mask(kwargs) + attention_mask = resolve_hf_attention_mask(kwargs) res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask) if thd_format: res = res[attention_mask][:, None] diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_5.py b/src/mcore_bridge/model/mm_gpts/qwen3_5.py index 9fc212d..3805714 100644 --- a/src/mcore_bridge/model/mm_gpts/qwen3_5.py +++ b/src/mcore_bridge/model/mm_gpts/qwen3_5.py @@ -10,7 +10,7 @@ from mcore_bridge.utils import get_env_args from ..constant import ModelType -from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader, _get_hf_attention_mask +from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader, resolve_hf_attention_mask from ..register import ModelMeta, register_model from .utils import HuggingFaceVit @@ -52,7 +52,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): hidden_states = new_hidden_states else: hidden_states = hidden_states.transpose(0, 1) - attention_mask = _get_hf_attention_mask(kwargs) + attention_mask = resolve_hf_attention_mask(kwargs) res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask) if thd_format: res = res[attention_mask][:, None] From f859378f11d3a958f2ac1eb2f709892c9c38feb1 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 23 Apr 2026 17:08:58 +0800 Subject: [PATCH 3/5] update --- src/mcore_bridge/model/gpts/qwen3_next.py | 2 +- src/mcore_bridge/model/mm_gpts/qwen3_5.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py index 2205b03..37b4060 100644 --- a/src/mcore_bridge/model/gpts/qwen3_next.py +++ b/src/mcore_bridge/model/gpts/qwen3_next.py @@ -59,7 +59,7 @@ logger = get_logger() -def resolve_hf_attention_mask(kwargs) -> Optional[torch.Tensor]: +def resolve_gdn_attention_mask(kwargs) -> Optional[torch.Tensor]: if is_torch_npu_available(): attention_mask = kwargs.get('attention_mask_2d') if attention_mask is not None: diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_5.py b/src/mcore_bridge/model/mm_gpts/qwen3_5.py index 3805714..801c55c 100644 --- a/src/mcore_bridge/model/mm_gpts/qwen3_5.py +++ b/src/mcore_bridge/model/mm_gpts/qwen3_5.py @@ -10,7 +10,7 @@ from mcore_bridge.utils import get_env_args from ..constant import ModelType -from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader, resolve_hf_attention_mask +from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader, resolve_gdn_attention_mask from ..register import ModelMeta, register_model from .utils import HuggingFaceVit @@ -52,7 +52,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): hidden_states = new_hidden_states else: hidden_states = hidden_states.transpose(0, 1) - attention_mask = resolve_hf_attention_mask(kwargs) + attention_mask = resolve_gdn_attention_mask(kwargs) res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask) if thd_format: res = res[attention_mask][:, None] From 4e36b4842b92808fcf490f944f8ebb4af39b23c9 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 23 Apr 2026 17:53:46 +0800 Subject: [PATCH 4/5] update --- src/mcore_bridge/model/gpts/qwen3_next.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py index 37b4060..cf61909 100644 --- a/src/mcore_bridge/model/gpts/qwen3_next.py +++ b/src/mcore_bridge/model/gpts/qwen3_next.py @@ -67,9 +67,7 @@ def resolve_gdn_attention_mask(kwargs) -> Optional[torch.Tensor]: attention_mask = kwargs.get('attention_mask') if attention_mask is None: return None - if attention_mask.ndim == 4: - return (~attention_mask).sum(dim=(1, 2)) > 0 - return attention_mask.to(torch.bool) + return (~attention_mask).sum(dim=(1, 2)) > 0 class Qwen3NextRMSNorm(torch.nn.Module): From 130f7a9dbaadf7eb571308dc9ed87aaa12c71449 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Fri, 24 Apr 2026 11:18:59 +0800 Subject: [PATCH 5/5] fix --- src/mcore_bridge/model/gpts/qwen3_next.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py index cf61909..0481e88 100644 --- a/src/mcore_bridge/model/gpts/qwen3_next.py +++ b/src/mcore_bridge/model/gpts/qwen3_next.py @@ -17,8 +17,8 @@ from megatron.core.transformer.transformer_block import TransformerBlockSubmodules from megatron.core.utils import deprecate_inference_params, is_fa_min_version from packaging import version -from typing import Optional, Tuple, Union from transformers.utils import is_torch_npu_available +from typing import Optional, Tuple, Union from mcore_bridge.bridge import GPTBridge from mcore_bridge.config import ModelConfig