fix(cuda): add F16-K + TURBO-V dispatch cases in flash attention (fixes #83) #84
Open
dentity007 wants to merge 1 commit into TheTom:feature/turboquant-kv-cache
Conversation
The flash attention vector dispatcher was missing instance files, extern
declarations, and dispatch cases for F16-K + TURBO-V combinations. Any
of `-ctk f16 -ctv turbo{2,3,4}` aborted at fattn.cu:339 on sm_89 and
sm_121 because no dispatch case matched.
The kernel already treats TURBO types as unquantized alongside F16 and
BF16 (fattn-vec.cuh:79-80), so this is a dispatch-plumbing gap, not a
kernel limitation. Pure additive change.
CMakeLists.txt is updated so builds without GGML_CUDA_FA_ALL_QUANTS (the
default) also link the new template instances.
Validated on:
- RTX 4090 (sm_89) with Qwen3-30B-A3B Q4_K_M, FA_ALL_QUANTS=ON
- DGX Spark GB10 (sm_121) with Q4_K_M and UD-Q4_K_XL, both FA_ALL_QUANTS=ON and =OFF
Key results:
- Pre-fix crash reproduced at fattn.cu:339 (exit 134 SIGABRT, sm_89)
- Post-fix: all three new combos (f16/turbo{2,3,4}) run in llama-bench on sm_89
- sm_89 tg128: 215-218 t/s for new combos (vs 242 t/s for f16/f16 baseline)
- sm_121 tg128: 83.78-85.11 t/s for f16/turbo3 across model and build-config combinations
- Single-run PPL on sm_89 wikitext-2: 7.5281 vs 7.4950 f16/f16 baseline (delta 0.44%, within the within-run error estimate)
- KV context memory at 4K ctx on sm_89: 229 MiB vs 384 MiB f16/f16 baseline
Fixes TheTom#83
Author
Heads up for context. Ran additional sm_121 (DGX Spark GB10) validation overnight. Happy to paste the full tables as a follow-up comment or open a separate issue/PR.
Summary
Fixes #83. Adds missing F16-K + TURBO-V dispatch cases in the CUDA flash attention vector dispatcher so asymmetric `-ctk f16 -ctv turbo{2,3,4}` configurations run on GPU instead of aborting at `fattn.cu:339`.
Background
Issue #83 reports that running `llama-cli` or `llama-bench` with `-ctk f16 -ctv turbo3` (or turbo2, turbo4) aborts at `GGML_ABORT("fatal error")` on line 339 of `fattn.cu`. Reproduced on both sm_89 (RTX 4090) and sm_121 (DGX Spark GB10) during PR #82 validation on Discussion ggml-org#20969.
Root Cause
Same class as PR #82: missing dispatch entries for specific K/V type combinations.
The kernel itself already supports F16+TURBO structurally.
`fattn-vec.cuh` lines 79-80 treat TURBO types as unquantized alongside F16/BF16, so this is pure dispatch plumbing. The missing pieces were:
- `fattn-vec-instance-f16-turbo{2,3,4}_0.cu`
- `fattn-vec.cuh`
- `fattn.cu`
- `CMakeLists.txt` explicit source list (the `GGML_CUDA_FA_ALL_QUANTS=OFF` default path)
Changes
Pure additive, ~48 lines across 6 files. No modifications to existing logic.
- `template-instances/fattn-vec-instance-f16-turbo2_0.cu`
- `template-instances/fattn-vec-instance-f16-turbo3_0.cu`
- `template-instances/fattn-vec-instance-f16-turbo4_0.cu`
- `ggml/src/ggml-cuda/fattn-vec.cuh`
- `ggml/src/ggml-cuda/fattn.cu`
- `ggml/src/ggml-cuda/CMakeLists.txt`

Follows the exact pattern of existing Q8_0+TURBO cases.
Scope
This PR covers F16-K + TURBO-V (the exact case from #83).
Out of scope for this PR (observations from validation):
A `-ctk f16,q8_0,turbo3 -ctv f16,q8_0,turbo3` bench matrix aborts at `fattn.cu:348` on the `turbo3/f16` combination. Separate missing case with a similar fix pattern. Happy to file a separate issue or PR if useful.
Testing
Base commit `59798f10d` of `feature/turboquant-kv-cache` (post PR #82 merge). Same patch (sha `cca5e8f0...`) applied on both architectures.
Coverage matrix:
Both architectures and both models from the original #83 bug report covered. Both build configurations exercised (=ON uses the `file(GLOB)` path; =OFF uses the explicit source list this PR updates).
Pre-fix crash reproduced (sm_89)
Running the command from issue #83 against base commit `59798f10d`:
Same command on the patched build: model loads cleanly without abort. Throughput measurements from llama-bench shown below.
Build verified with both FA_ALL_QUANTS settings
The CMakeLists.txt update is required because:
- `GGML_CUDA_FA_ALL_QUANTS=ON` uses `file(GLOB template-instances/fattn-vec*.cu)`, which picks up new files after a cmake reconfigure.
- `GGML_CUDA_FA_ALL_QUANTS=OFF` (default) uses an explicit source list in `ggml-cuda/CMakeLists.txt` that must name each instance file.

=ON validated on sm_89 (Iceland pod) and sm_121 (Spark). =OFF validated on sm_121 (Spark) with both Q4_K_M and UD-Q4_K_XL.
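Schematically, the two build paths look something like the following. This is an illustrative sketch only; the variable name `GGML_SOURCES_CUDA` and the exact layout are assumptions, not copied from the repository's CMakeLists.txt:

```cmake
# Sketch, not the actual ggml-cuda/CMakeLists.txt contents.
if (GGML_CUDA_FA_ALL_QUANTS)
    # =ON: GLOB picks up new instance files after a cmake reconfigure.
    file(GLOB FATTN_VEC_SRCS "template-instances/fattn-vec*.cu")
    list(APPEND GGML_SOURCES_CUDA ${FATTN_VEC_SRCS})
else()
    # =OFF (default): explicit list must name each new instance file,
    # which is why this PR has to touch CMakeLists.txt at all.
    list(APPEND GGML_SOURCES_CUDA
        template-instances/fattn-vec-instance-f16-turbo2_0.cu
        template-instances/fattn-vec-instance-f16-turbo3_0.cu
        template-instances/fattn-vec-instance-f16-turbo4_0.cu)
endif()
```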
Without the CMakeLists.txt change, the =OFF path would fail at link because the explicit source list omits the three new .cu files. The same failure class (undefined symbol `ggml_cuda_flash_attn_ext_vec_case<...>`) was observed on sm_89 during the initial post-patch build, when an incremental `cmake --build` did not re-run the GLOB and therefore did not pick up the new files; forcing a cmake reconfigure resolved it. That observation confirms the failure mode the CMakeLists.txt update prevents is real.
sm_89 validation on RTX 4090 — the three newly-enabled combos
Targeted single-combo runs, `-ngl 99 -fa 1 -b 2048 -ub 2048 -p 0 -n 128 -pg 8192,128`:
All three run at GPU speeds (contrast: the PR #82 CPU-fallback symptom was single-digit t/s). Throughput numbers are in the same band, consistent with the three variants sharing the same unquantized-V kernel branch.
Note: the `f16/turbo3` tg128 value here (216.52 ± 1.67) and in the 3x3 matrix below (215.54 ± 0.91) come from two independent bench invocations on the same build. They agree within their error bars; the difference is run-to-run variance.
sm_89 validation on RTX 4090 (3x3 matrix at 8K and 32K context)
For context, the symmetric baseline combos (not touched by this PR) on the same build:
`f16/turbo3` throughput is within ~11% of the `f16/f16` baseline and within ~5% of the `f16/q8_0` baseline at both 8K and 32K contexts.
sm_121 validation on DGX Spark GB10
Same patch, same base commit as sm_89. Q4_K_M tested with both `FA_ALL_QUANTS` settings; UD-Q4_K_XL tested with =OFF. Only `f16/turbo3` exercised on sm_121. The fix pattern is symmetric across turbo2, turbo3, turbo4 (identical structural additions in each of the 6 files), and all three were exercised individually on sm_89 (see the three-combos table above). The three variants route through the same unquantized-V branch of the flash-attention kernel.
`-ngl 99 -fa 1 -b 2048 -ub 2048 -p 0 -n 128 -pg 8192,128`:
=ON vs =OFF on Q4_K_M: 0.2% difference on both metrics, consistent with the same code path plus run-to-run variance. No abort. No CPU fallback (inferred from throughput: the GB10 CPU path on similar workloads reported ~5 t/s in the original PR #82 bug; these runs are ~80 t/s, GPU-consistent).
Numerical equivalence (llama-perplexity on wikitext-2 test)
Context 4096, 75 chunks, same model, same seed conditions:
Delta: +0.0331 PPL (+0.44%). The ± values are within-run bootstrap estimates from a single perplexity run per config, not across-run variance. The delta is smaller than the within-run error estimate of either measurement, but confirming this with multiple runs of each config would strengthen the claim. One-run result: no meaningful quality shift attributable to the new dispatch path.
KV cache memory footprint
Reported by `llama_memory_breakdown_print` at 4K context: 40.4% reduction in KV context memory for the `f16/turbo3` asymmetric configuration. Expected behavior; confirms the turbo3 V path is actually being taken.
Reproduction
Same command as the issue #83 body, using llama-bench: