fix qwen3_next #58
Conversation
Code Review
This pull request updates the `forward` method in `qwen3_next.py` to call `resolve_gdn_attention_mask` instead of `resolve_hf_attention_mask`. A review comment identifies a potential logic error and runtime risk in the `resolve_gdn_attention_mask` implementation: the bitwise NOT operator is applied to what may be an integer tensor, and the function assumes a 4D mask, which can raise an `IndexError` on the standard 2D HuggingFace format.
      else:
          hidden_states = hidden_states.transpose(0, 1)
-         attention_mask = resolve_hf_attention_mask(kwargs)
+         attention_mask = resolve_gdn_attention_mask(kwargs)
The function `resolve_gdn_attention_mask` (called here and defined at line 62) has a potential logic error and runtime risk:

- Type Safety: At line 70, it uses the bitwise NOT operator `~` on `attention_mask`. If the mask is an integer tensor (common for HuggingFace masks), `~1` results in `-2`, which will cause the `sum(...) > 0` check to behave unexpectedly. It should be explicitly cast to boolean: `(~attention_mask.to(torch.bool))`. See the minimal repro after this list.
- Dimensionality: The function assumes a 4D input by using `dim=(1, 2)`. If `kwargs.get('attention_mask')` returns a 2D tensor (the standard HuggingFace format), this will raise an `IndexError`.
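As a hedged illustration of the first point (not code from this repository), a minimal PyTorch repro of how `~` misbehaves on an integer padding mask and how the bool cast restores the intended check:

```python
import torch

# HuggingFace-style 2D padding mask: 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# On an integer tensor, ~1 == -2 and ~0 == -1, so the sum is negative
# and an "is anything padded?" check silently reports False.
print((~attention_mask).sum() > 0)       # tensor(False), despite two padded positions

# Casting to bool first gives the intended semantics.
bool_mask = attention_mask.to(torch.bool)
print((~bool_mask).sum() > 0)            # tensor(True)
```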
Since this function is now being used to resolve the attention mask, these issues should be addressed in its implementation to ensure robustness.
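A minimal sketch of what a hardened `resolve_gdn_attention_mask` could look like, assuming it should return a boolean padding mask or `None` for a fully unpadded batch; that return contract, the 2D-only guard, and the exact signature are assumptions for illustration, not the repository's actual code:

```python
import torch
from typing import Optional


def resolve_gdn_attention_mask(kwargs: dict) -> Optional[torch.Tensor]:
    """Sketch only: resolve a 2D HuggingFace-style padding mask
    (1 = token, 0 = padding) to a bool mask, or None when unpadded."""
    attention_mask = kwargs.get("attention_mask")
    if attention_mask is None:
        return None

    # Guard the dimensionality assumption explicitly instead of letting
    # a dim=(1, 2) reduction fail with an IndexError on 2D inputs.
    if attention_mask.dim() != 2:
        raise ValueError(
            f"expected a 2D [batch, seq_len] attention mask, got shape "
            f"{tuple(attention_mask.shape)}"
        )

    # Cast to bool before inverting: on an integer mask ~1 == -2 and the
    # padding check below would never fire.
    attention_mask = attention_mask.to(torch.bool)

    # Fully unpadded batch: signal "no mask needed" to the caller.
    if (~attention_mask).sum() == 0:
        return None
    return attention_mask
```

Returning `None` for the unpadded case lets the caller skip the masked code path entirely, which appears to be what the existing `sum(...) > 0` check is aiming for.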