Add PyTorch SDPA fallback for FlashAttention#5
Conversation
|
I did a second pass on this because I did not want to pretend the fallback is more than it is. What I think is happening: the repo is built around FlashAttention 3 not just for speed, but for shape, masking, and compile behavior. The PrefixLM path is doing something fairly specific. Prefix tokens need bidirectional attention within the prefix block. Response tokens need causal attention over the prefix plus prior response tokens. The original code does that with two FA3 passes and custom ops. That makes sense. It is the right thing to do if the target is Hopper with FA3. The AMD problem is more basic. On the MI300X machine, the repo-native path fails before model logic gets a chance to run, because I compared the fallback against a dense reference instead of only doing smoke tests. Results:
The bad news: this does not keep fullgraph |
Hi. I hit this while trying HRM-Text on an AMD MI300X box.
The Hugging Face checkpoint path worked.
sapientinc/HRM-Text-1Bloads and generates on ROCm withtorch 2.12.0+rocm7.2andtransformers 5.9.0.The repo-native path stopped much earlier. It could not import
models.layers, becauseflash_attn_interfaceis imported at module import time.My read is that the current code is doing exactly what it was built to do. It assumes the training/inference repo is running in the reference CUDA/Hopper setup with FlashAttention 3. That is a fair choice. PrefixLM attention here is not just normal causal attention. The implementation does a bidirectional prefix pass, then a causal response pass, and it wraps that in custom ops that are friendly to
torch.compile.So I do not think the original code is wrong. It is strict. It draws a hard line around the expected stack.
This PR adds a lower floor for machines that do not have FA3. If
flash_attn_interfaceis present, the old path is still used. If it is not present, the code falls back to PyTorch SDPA for:A slow but correct fallback is better than a hard stop, as long as we are honest that it is slow.
Validation so far:
torch 2.12.0+rocm7.2andtransformers 5.9.0.models.layersimports on MI300X withoutflash_attn_interfaceinstalled.2.4e-7and max grad diff around5e-6.3.6e-7.HierarchicalReasoningModelforward/backward smoke test runs on MI300X.python -m compileall -q modelspasses.The caveat is: fullgraph
torch.compileis not preserved for the fallback. The fallback uses Python loops and scalar reads from sequence metadata. Eager mode works. Non-fullgraph compile can run with graph breaks. But this is not the same thing as the original compiled FA3 path.I am still playing more with this on the MI300X box and will keep posting results as I find the edges. My guess is the right long-term answer is not this fallback. It is either a ROCm-capable FA-style kernel for this PrefixLM pattern, or a separate AMD inference path. This PR is just the small step that makes the repo usable enough to test those ideas.