From c104a368881f05bb3325c6cacb1c89f673cf29bf Mon Sep 17 00:00:00 2001
From: dikanquit
Date: Fri, 29 May 2026 11:12:18 +0200
Subject: [PATCH] sdpa is_causal: offset tril by seq_k - seq_q

---
 tinygrad/tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index cf202bf873909..ca6d253b6f94a 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1362,7 +1362,7 @@ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tenso
     # handle attention mask
     if is_causal:
       if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
-      attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
+      attn_mask = qk.const_like(1).cast(dtypes.bool).tril(int(key.shape[-2])-int(q.shape[-2]))
     if attn_mask is not None:
       if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
       qk = qk + attn_mask