diff --git a/models/flash_attention_prefixlm_v2.py b/models/flash_attention_prefixlm_v2.py index b7856ee..1ff7c98 100644 --- a/models/flash_attention_prefixlm_v2.py +++ b/models/flash_attention_prefixlm_v2.py @@ -120,12 +120,15 @@ def flash_attn_varlen_prefixlm_compileop( max_seqlen_q=max_seqlen_prefix_int, max_seqlen_k=max_seqlen_prefix_int, causal=is_causal) # Fwd pass 2 (causal) - _, softmax_lse_causal = _custom_flash_attn_forward( - out_=out, q=q, k=k, v=v, - cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens, - seqused_q=causal_lens, - max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int, - causal=True) + if max_seqlen_causal_int > 0: + _, softmax_lse_causal = _custom_flash_attn_forward( + out_=out, q=q, k=k, v=v, + cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens, + seqused_q=causal_lens, + max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int, + causal=True) + else: + softmax_lse_causal = torch.empty_like(softmax_lse_bidir) out[total_seqlen_int:] = 0 # Zero padding return out, softmax_lse_bidir, softmax_lse_causal @@ -181,7 +184,6 @@ def flash_attn_varlen_prefixlm_bwd_compileop( # Buffers dq = torch.empty_like(q) dk1, dv1 = torch.zeros_like(k), torch.zeros_like(v) # Zero-fill in advance - dk2, dv2 = torch.empty_like(k), torch.empty_like(v) # Bwd pass 1 (bidirectional) _flash_attn_backward( dout=dout, q=q, k=k, v=v, out=out, @@ -194,22 +196,26 @@ def flash_attn_varlen_prefixlm_bwd_compileop( dv=dv1, is_causal=is_causal) # Bwd pass 2 (causal) - _flash_attn_backward( - dout=dout, q=q, k=k, v=v, out=out, - softmax_lse=softmax_lse_causal, - cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int, - sequed_q=causal_lens, - dq=dq, - dk=dk2, - dv=dv2, - is_causal=True) + if max_seqlen_causal_int > 0: + dk2, dv2 = torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( + dout=dout, q=q, k=k, v=v, out=out, + softmax_lse=softmax_lse_causal, + cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int, + sequed_q=causal_lens, + dq=dq, + dk=dk2, + dv=dv2, + is_causal=True) + dk2[total_seqlen_int:] = 0 + dv2[total_seqlen_int:] = 0 + dk1.add_(dk2) + dv1.add_(dv2) # Zero padding grads dq[total_seqlen_int:] = 0 - dk2[total_seqlen_int:] = 0 - dv2[total_seqlen_int:] = 0 - return dq, dk1.add_(dk2), dv1.add_(dv2) + return dq, dk1, dv1 @torch.library.register_fake("flash_attn::flash_attn_varlen_prefixlm_bwd_compileop")