diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cf202bf873909..ca6d253b6f94a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1362,7 +1362,7 @@ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tenso # handle attention mask if is_causal: if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True") - attn_mask = qk.const_like(1).cast(dtypes.bool).tril() + attn_mask = qk.const_like(1).cast(dtypes.bool).tril(int(key.shape[-2])-int(q.shape[-2])) if attn_mask is not None: if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf")) qk = qk + attn_mask