diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py index 08cce14..5380434 100644 --- a/src/mcore_bridge/model/gpts/qwen3_next.py +++ b/src/mcore_bridge/model/gpts/qwen3_next.py @@ -497,7 +497,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): hidden_states = new_hidden_states else: hidden_states = hidden_states.transpose(0, 1) - attention_mask = resolve_hf_attention_mask(kwargs) + attention_mask = resolve_gdn_attention_mask(kwargs) res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask) if thd_format: res = res[attention_mask][:, None]