Skip to content

Add PyTorch SDPA fallback for FlashAttention#5

Open
beiyonder wants to merge 1 commit into
sapientinc:mainfrom
beiyonder:rocm-sdpa-fallback
Open

Add PyTorch SDPA fallback for FlashAttention#5
beiyonder wants to merge 1 commit into
sapientinc:mainfrom
beiyonder:rocm-sdpa-fallback

Conversation

@beiyonder

@beiyonder beiyonder commented May 21, 2026

Copy link
Copy Markdown

Hi. I hit this while trying HRM-Text on an AMD MI300X box.

The Hugging Face checkpoint path worked. sapientinc/HRM-Text-1B loads and generates on ROCm with torch 2.12.0+rocm7.2 and transformers 5.9.0.

The repo-native path stopped much earlier. It could not import models.layers, because flash_attn_interface is imported at module import time.

My read is that the current code is doing exactly what it was built to do. It assumes the training/inference repo is running in the reference CUDA/Hopper setup with FlashAttention 3. That is a fair choice. PrefixLM attention here is not just normal causal attention. The implementation does a bidirectional prefix pass, then a causal response pass, and it wraps that in custom ops that are friendly to torch.compile.

So I do not think the original code is wrong. It is strict. It draws a hard line around the expected stack.

This PR adds a lower floor for machines that do not have FA3. If flash_attn_interface is present, the old path is still used. If it is not present, the code falls back to PyTorch SDPA for:

  • varlen PrefixLM attention
  • KV-cache attention
  • import-time compatibility on ROCm and other non-FA3 boxes

A slow but correct fallback is better than a hard stop, as long as we are honest that it is slow.

Validation so far:

  • HF checkpoint generated on AMD Instinct MI300X VF with torch 2.12.0+rocm7.2 and transformers 5.9.0.
  • models.layers imports on MI300X without flash_attn_interface installed.
  • Fallback PrefixLM attention matches a dense reference in fp32 with max output diff around 2.4e-7 and max grad diff around 5e-6.
  • bf16 differences are at normal bf16 scale.
  • KV-cache fallback matches a dense reference in fp32 with max output diff around 3.6e-7.
  • A tiny HierarchicalReasoningModel forward/backward smoke test runs on MI300X.
  • python -m compileall -q models passes.

The caveat is: fullgraph torch.compile is not preserved for the fallback. The fallback uses Python loops and scalar reads from sequence metadata. Eager mode works. Non-fullgraph compile can run with graph breaks. But this is not the same thing as the original compiled FA3 path.

I am still playing more with this on the MI300X box and will keep posting results as I find the edges. My guess is the right long-term answer is not this fallback. It is either a ROCm-capable FA-style kernel for this PrefixLM pattern, or a separate AMD inference path. This PR is just the small step that makes the repo usable enough to test those ideas.

@beiyonder

beiyonder commented May 21, 2026

Copy link
Copy Markdown
Author

I did a second pass on this because I did not want to pretend the fallback is more than it is.

What I think is happening: the repo is built around FlashAttention 3 not just for speed, but for shape, masking, and compile behavior. The PrefixLM path is doing something fairly specific. Prefix tokens need bidirectional attention within the prefix block. Response tokens need causal attention over the prefix plus prior response tokens. The original code does that with two FA3 passes and custom ops.

That makes sense. It is the right thing to do if the target is Hopper with FA3.

The AMD problem is more basic. On the MI300X machine, the repo-native path fails before model logic gets a chance to run, because flash_attn_interface is a hard import. So the fallback here is mostly about getting past that wall. It makes the code import. It lets small eager checks run. It gives us a correctness baseline.

I compared the fallback against a dense reference instead of only doing smoke tests.

Results:

  • Mixed PrefixLM fp32: max output diff about 2.4e-7, max grad diff about 4.9e-6.
  • Pure causal fp32: max output diff about 2.4e-7, max grad diff about 4.5e-6.
  • Single sequence fp32: max output diff about 2.4e-7, max grad diff about 1.7e-6.
  • bf16 diffs are at bf16 scale, usually 0.0039 to 0.0078 for outputs.
  • KV-cache fallback matches dense reference with fp32 max output diff about 3.6e-7.
  • Cache writes land in the expected positions.
  • Tiny HRM forward/backward passes on AMD Instinct MI300X VF with torch 2.12.0+rocm7.2.
  • compileall passes for models.

The bad news: this does not keep fullgraph torch.compile semantics for the fallback. The fallback has Python loops and .item() style scalar extraction from sequence metadata. That means fullgraph compile with dynamic seq metadata fails. Non-fullgraph compile can run with graph breaks. Eager works.
So my current view is: this PR is a safe eager correctness fallback.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant