Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions models/flash_attention_prefixlm_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,15 @@ def flash_attn_varlen_prefixlm_compileop(
max_seqlen_q=max_seqlen_prefix_int, max_seqlen_k=max_seqlen_prefix_int,
causal=is_causal)
# Fwd pass 2 (causal)
_, softmax_lse_causal = _custom_flash_attn_forward(
out_=out, q=q, k=k, v=v,
cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens,
seqused_q=causal_lens,
max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int,
causal=True)
if max_seqlen_causal_int > 0:
_, softmax_lse_causal = _custom_flash_attn_forward(
out_=out, q=q, k=k, v=v,
cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens,
seqused_q=causal_lens,
max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int,
causal=True)
else:
softmax_lse_causal = torch.empty_like(softmax_lse_bidir)

out[total_seqlen_int:] = 0 # Zero padding
return out, softmax_lse_bidir, softmax_lse_causal
Expand Down Expand Up @@ -181,7 +184,6 @@ def flash_attn_varlen_prefixlm_bwd_compileop(
# Buffers
dq = torch.empty_like(q)
dk1, dv1 = torch.zeros_like(k), torch.zeros_like(v) # Zero-fill in advance
dk2, dv2 = torch.empty_like(k), torch.empty_like(v)
# Bwd pass 1 (bidirectional)
_flash_attn_backward(
dout=dout, q=q, k=k, v=v, out=out,
Expand All @@ -194,22 +196,26 @@ def flash_attn_varlen_prefixlm_bwd_compileop(
dv=dv1,
is_causal=is_causal)
# Bwd pass 2 (causal)
_flash_attn_backward(
dout=dout, q=q, k=k, v=v, out=out,
softmax_lse=softmax_lse_causal,
cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int,
sequed_q=causal_lens,
dq=dq,
dk=dk2,
dv=dv2,
is_causal=True)
if max_seqlen_causal_int > 0:
dk2, dv2 = torch.empty_like(k), torch.empty_like(v)
_flash_attn_backward(
dout=dout, q=q, k=k, v=v, out=out,
softmax_lse=softmax_lse_causal,
cu_seqlens_q=cu_seqlens_shifted, cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen_causal_int, max_seqlen_k=max_seqlen_all_int,
sequed_q=causal_lens,
dq=dq,
dk=dk2,
dv=dv2,
is_causal=True)
dk2[total_seqlen_int:] = 0
dv2[total_seqlen_int:] = 0
dk1.add_(dk2)
dv1.add_(dv2)

# Zero padding grads
dq[total_seqlen_int:] = 0
dk2[total_seqlen_int:] = 0
dv2[total_seqlen_int:] = 0
return dq, dk1.add_(dk2), dv1.add_(dv2)
return dq, dk1, dv1


@torch.library.register_fake("flash_attn::flash_attn_varlen_prefixlm_bwd_compileop")
Expand Down