From 35a92efdc1b2c383616a4a6a05d5f46b9f63bb1a Mon Sep 17 00:00:00 2001 From: Zane12518 <182461761+Zane12518@users.noreply.github.com> Date: Thu, 28 May 2026 02:55:28 +0800 Subject: [PATCH 1/2] Handle zero-causal PrefixLM FA3 pass as no-op --- models/flash_attention_prefixlm_v2.py | 40 +++++++++++++++------------ 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/models/flash_attention_prefixlm_v2.py b/models/flash_attention_prefixlm_v2.py index b7856ee..689ff04 100644 --- a/models/flash_attention_prefixlm_v2.py +++ b/models/flash_attention_prefixlm_v2.py @@ -120,12 +120,15 @@ def flash_attn_varlen_prefixlm_compileop( max_seqlen_q=max_seqlen_prefix_int, max_seqlen_k=max_seqlen_prefix_int, causal=is_causal) # Fwd pass 2 (causal) - _, softmax_lse_causal = _custom_flash_attn_forward( - out_=out, q=q, k=k, v=v, - cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens, - seqused_q=causal_lens, - max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int, - causal=True) + if max_seqlen_causal_int > 0: + _, softmax_lse_causal = _custom_flash_attn_forward( + out_=out, q=q, k=k, v=v, + cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens, + seqused_q=causal_lens, + max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int, + causal=True) + else: + softmax_lse_causal = torch.empty_like(softmax_lse_bidir) out[total_seqlen_int:] = 0 # Zero padding return out, softmax_lse_bidir, softmax_lse_causal @@ -181,7 +184,6 @@ def flash_attn_varlen_prefixlm_bwd_compileop( # Buffers dq = torch.empty_like(q) dk1, dv1 = torch.zeros_like(k), torch.zeros_like(v) # Zero-fill in advance - dk2, dv2 = torch.empty_like(k), torch.empty_like(v) # Bwd pass 1 (bidirectional) _flash_attn_backward( dout=dout, q=q, k=k, v=v, out=out, @@ -194,16 +196,20 @@ def flash_attn_varlen_prefixlm_bwd_compileop( dv=dv1, is_causal=is_causal) # Bwd pass 2 (causal) - _flash_attn_backward( - dout=dout, q=q, k=k, v=v, out=out, - softmax_lse=softmax_lse_causal, - cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int, - sequed_q=causal_lens, - dq=dq, - dk=dk2, - dv=dv2, - is_causal=True) + if max_seqlen_causal_int > 0: + dk2, dv2 = torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( + dout=dout, q=q, k=k, v=v, out=out, + softmax_lse=softmax_lse_causal, + cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int, + sequed_q=causal_lens, + dq=dq, + dk=dk2, + dv=dv2, + is_causal=True) + else: + dk2, dv2 = torch.zeros_like(k), torch.zeros_like(v) # Zero padding grads dq[total_seqlen_int:] = 0 From 428dfc5171d1751c35c02e172069cd0e5b3d4679 Mon Sep 17 00:00:00 2001 From: Zane12518 <182461761+Zane12518@users.noreply.github.com> Date: Fri, 29 May 2026 22:24:44 +0800 Subject: [PATCH 2/2] Avoid zero tensors for zero-causal FA3 no-op --- models/flash_attention_prefixlm_v2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/models/flash_attention_prefixlm_v2.py b/models/flash_attention_prefixlm_v2.py index 689ff04..1ff7c98 100644 --- a/models/flash_attention_prefixlm_v2.py +++ b/models/flash_attention_prefixlm_v2.py @@ -208,14 +208,14 @@ def flash_attn_varlen_prefixlm_bwd_compileop( dk=dk2, dv=dv2, is_causal=True) - else: - dk2, dv2 = torch.zeros_like(k), torch.zeros_like(v) + dk2[total_seqlen_int:] = 0 + dv2[total_seqlen_int:] = 0 + dk1.add_(dk2) + dv1.add_(dv2) # Zero padding grads dq[total_seqlen_int:] = 0 - dk2[total_seqlen_int:] = 0 - dv2[total_seqlen_int:] = 0 - return dq, dk1.add_(dk2), dv1.add_(dv2) + return dq, dk1, dv1 @torch.library.register_fake("flash_attn::flash_attn_varlen_prefixlm_bwd_compileop")