Fix Eagle3 FA cached backward semantics and keep FA2 fallback backend#79

Merged
zhyncs merged 5 commits into lightseekorg:main from uygnef:fix/eagle3_fa
Apr 16, 2026

Conversation


@uygnef uygnef commented Apr 16, 2026

Summary

This PR fixes the Eagle3 fa cached-attention backward path.

The same gradient issue can also be avoided by using
LlamaFlashAttentionMasked, but that path depends on the masked / cute FA
stack.

This PR instead fixes the standard fa path, giving us a correct backend
that depends only on standard FA2 and can serve as a more compatible
fallback backend, especially on older devices or in environments where
the masked FA path is unavailable.

Background

For Eagle3 cached attention, the final output is a merge of:

  • block0 causal attention
  • suffix singleton cached blocks
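
Merging two attention pieces this way renormalizes their partial softmax
results using per-query log-sum-exp (LSE) statistics. A minimal sketch of
that combine step (illustrative only; `attn_block` / `merge_blocks` are
hypothetical names, not the repo's actual code):

```python
import torch

def attn_block(q, k, v):
    # Plain softmax attention over one key/value block, also returning
    # the per-query log-sum-exp of the scores for later merging.
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)        # [num_q]
    out = torch.softmax(scores, dim=-1) @ v      # [num_q, head_dim]
    return out, lse

def merge_blocks(out1, lse1, out2, lse2):
    # Renormalize two partial softmax results into one global softmax:
    # weight each block by exp(lse_block - lse_global).
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * out1 + w2 * out2, lse
```

Splitting the keys into two blocks, attending to each, and merging with
`merge_blocks` reproduces full attention over all keys exactly.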

The old fa path handled block0 and the external merge as separate steps,
so block0's backward was effectively computed against its local softmax
normalization rather than the final merged normalization.

As a result, cached-path q/k gradients could diverge from
flex_attention / masked-attention behavior and contribute to abnormal
gradient spikes.

LlamaFlashAttentionMasked does not have this issue because it expresses
the attention pattern inside a masked FA formulation, but it has stricter
backend/runtime requirements.

What this PR changes

In torchspec/models/draft/llama3_eagle.py:

  • fix the cached-merge backward semantics for fa
  • keep block0 on standard flash-attention
  • reuse standard flash-attn backward for block0 with merged
    combined_out / combined_lse
  • keep suffix block gradients analytic and explicit
  • support padded batches via varlen flash-attn for block0
  • simplify the main fa path around the corrected implementation
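
The core of the fix is that block0's backward must be taken against the
global (merged) normalization, not block0's local one. A small autograd
sketch of why a correctly merged two-block path yields the same q/k
gradients as full attention (hypothetical helpers, not the PR's actual
implementation, which reuses the flash-attn backward kernel with
combined_out / combined_lse):

```python
import torch

def attn(q, k, v):
    # Plain softmax attention plus per-query log-sum-exp.
    s = q @ k.T / q.shape[-1] ** 0.5
    return torch.softmax(s, -1) @ v, torch.logsumexp(s, -1)

def merged_attn(q, k, v, split):
    # block0 and the suffix block attended separately, then renormalized
    # via their lse statistics into one global softmax.
    o1, l1 = attn(q, k[:split], v[:split])
    o2, l2 = attn(q, k[split:], v[split:])
    lse = torch.logaddexp(l1, l2)
    return torch.exp(l1 - lse)[:, None] * o1 + torch.exp(l2 - lse)[:, None] * o2

torch.manual_seed(0)
q = torch.randn(4, 8, requires_grad=True)
k = torch.randn(6, 8, requires_grad=True)
v = torch.randn(6, 8)

# Gradients through the merged two-block path.
merged_attn(q, k, v, split=3).sum().backward()
gq_merge, gk_merge = q.grad.clone(), k.grad.clone()
q.grad = None; k.grad = None

# Reference: one full-softmax attention over all keys.
attn(q, k, v)[0].sum().backward()
assert torch.allclose(gq_merge, q.grad, atol=1e-5)
assert torch.allclose(gk_merge, k.grad, atol=1e-5)
```

Differentiating through the merge (as above) makes the block gradients
consistent with the global normalization; the old bug corresponds to
backpropagating block0 as if its local softmax were the final output.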

Why this is useful

This gives us a corrected fa backend that:

  • matches the intended Eagle3 cached-merge backward semantics
  • aligns with flex_attention
  • does not require the masked / cute FA path
  • only depends on standard FA2
  • can be used as a more compatible backup backend on older devices or less
    specialized environments

Validation

Historical dump replay

On a historical spike dump:

  • dump:
    • grad_norm = 1422.276855
    • total_loss = 14.712839
  • fixed fa replay:
    • grad_norm = 22.8646
    • weighted_loss = 14.633358
  • flex_attention replay:
    • grad_norm = 22.7233
    • weighted_loss = 14.633938

So after the fix:

  • fa and flex_attention are essentially aligned
  • the previous large-gradient behavior is no longer reproduced

Padded batch benchmark

Right-padded batch, batch=4, mixed valid lengths.

max seq   valid lengths              fa time   flex time   fa peak     flex peak
4096      [4096, 3584, 2560, 1536]   0.155s    0.122s      6.33 GiB    5.22 GiB
8192      [8192, 7168, 5120, 3072]   0.248s    0.317s      14.02 GiB   13.35 GiB

Numerical alignment stayed good:

  • losses matched
  • outputs stayed close
  • parameter gradient relative L2 error stayed small
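
The "relative L2 error" here is presumably the usual norm ratio; a
one-function sketch of how such a per-parameter gradient comparison is
typically computed (not the repo's actual helper):

```python
import torch

def rel_l2_error(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12) -> float:
    """Relative L2 error ||a - b|| / (||b|| + eps), e.g. between the same
    parameter's gradient under two attention backends."""
    return ((a - b).norm() / (b.norm() + eps)).item()
```

A value well below 1e-2 per parameter is what "stayed small" usually
means in this kind of backend-alignment check.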

Tests

Updated / added checks in tests/test_flex_attention.py:

  • cached-path gradients match flex_attention
  • forward behavior matches expected outputs
  • padded batch cases are numerically aligned with flex_attention
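
For the padded-batch cases, varlen flash-attn kernels consume a packed
(unpadded) token tensor plus cumulative sequence offsets (`cu_seqlens`)
rather than a right-padded batch. A hypothetical packing helper of the
kind such tests rely on (names are illustrative):

```python
import torch

def pack_right_padded(x, valid_lens):
    # x: [batch, max_seq, ...], right-padded. Returns the valid tokens
    # concatenated along dim 0, plus int32 cumulative offsets
    # [0, n0, n0+n1, ...] as expected by varlen flash-attn kernels.
    packed = torch.cat([x[i, :n] for i, n in enumerate(valid_lens)], dim=0)
    cu_seqlens = torch.zeros(len(valid_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(torch.tensor(valid_lens, dtype=torch.int32), 0)
    return packed, cu_seqlens
```

With this shape convention, a right-padded batch like the benchmark's
[4096, 3584, 2560, 1536] case packs into one flat sequence with offsets
[0, 4096, 7680, 10240, 11776].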

@uygnef uygnef changed the title Fix/eagle3 fa Fix Eagle3 FA cached backward semantics and keep FA2 fallback backend Apr 16, 2026
uygnef added 5 commits April 16, 2026 16:31
Signed-off-by: Yu Feng <admin@fengyu.org>
@zhyncs zhyncs merged commit 6dda2bf into lightseekorg:main Apr 16, 2026
2 checks passed