
fix(cuda): add F16-K + TURBO-V dispatch cases in flash attention (fixes #83)#84

Open
dentity007 wants to merge 1 commit into TheTom:feature/turboquant-kv-cache from dentity007:fix/cuda-fattn-f16-turbo-dispatch

Conversation

@dentity007

Summary

Fixes #83. Adds missing F16-K + TURBO-V dispatch cases in the CUDA flash attention vector dispatcher so asymmetric -ctk f16 -ctv turbo{2,3,4} configurations run on GPU instead of aborting at fattn.cu:339.

Background

Issue #83 reports that running llama-cli or llama-bench with -ctk f16 -ctv turbo3 (or turbo2, turbo4) aborts at GGML_ABORT("fatal error") on line 339 of fattn.cu. Reproduced on both sm_89 (RTX 4090) and sm_121 (DGX Spark GB10) during PR #82 validation on Discussion ggml-org#20969.

Root Cause

Same class as PR #82: missing dispatch entries for specific K/V type combinations.

The kernel itself already supports F16+TURBO structurally. fattn-vec.cuh lines 79-80 treat TURBO types as unquantized alongside F16/BF16:

```cpp
// line 79 (K side)
constexpr bool K_is_unquantized =
    (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16 ||
     type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0);
// line 80 (V side)
constexpr bool V_is_unquantized =
    (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16 ||
     type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0);
```

So this is pure dispatch plumbing:

  1. No template instance files for fattn-vec-instance-f16-turbo{2,3,4}_0.cu
  2. No extern declarations in fattn-vec.cuh
  3. No dispatch cases in fattn.cu
  4. No entries in CMakeLists.txt explicit source list (the GGML_CUDA_FA_ALL_QUANTS=OFF default path)

Changes

Purely additive, ~48 lines across 6 files. No modifications to existing logic.

| File | Change |
| --- | --- |
| `template-instances/fattn-vec-instance-f16-turbo2_0.cu` | NEW (5 lines) |
| `template-instances/fattn-vec-instance-f16-turbo3_0.cu` | NEW (5 lines) |
| `template-instances/fattn-vec-instance-f16-turbo4_0.cu` | NEW (5 lines) |
| `ggml/src/ggml-cuda/fattn-vec.cuh` | +9 extern declarations |
| `ggml/src/ggml-cuda/fattn.cu` | +3 dispatch cases |
| `ggml/src/ggml-cuda/CMakeLists.txt` | +3 explicit-list entries |

Follows the exact pattern of existing Q8_0+TURBO cases.

Scope

This PR covers F16-K + TURBO-V (the exact case from #83).

Out of scope for this PR (observations from validation):

  • TURBO-K + F16-V (reverse direction). Flagging as a finding from validation, not part of this PR. The full -ctk f16,q8_0,turbo3 -ctv f16,q8_0,turbo3 bench matrix aborts at fattn.cu:348 on the turbo3/f16 combination. Separate missing case with a similar fix pattern. Happy to file a separate issue or PR if useful.
  • BF16-K + TURBO-V. Likely same root cause class. Did not verify during this validation.

Testing

Base commit 59798f10d of feature/turboquant-kv-cache (post PR #82 merge). Same patch (sha cca5e8f0...) applied on both architectures.

Coverage matrix:

| Architecture | Hardware | Model | FA_ALL_QUANTS |
| --- | --- | --- | --- |
| sm_89 | RTX 4090 (RunPod, Iceland) | Qwen3-30B-A3B Q4_K_M | =ON |
| sm_121 | DGX Spark GB10 | Qwen3-30B-A3B Q4_K_M | =ON |
| sm_121 | DGX Spark GB10 | Qwen3-30B-A3B Q4_K_M | =OFF |
| sm_121 | DGX Spark GB10 | Qwen3-30B-A3B UD-Q4_K_XL | =OFF |

Both architectures and both models from the original #83 bug report are covered, and both build configurations are exercised (=ON uses the file(GLOB) path; =OFF uses the explicit source list this PR updates).

Pre-fix crash reproduced (sm_89)

Running the command from issue #83 against base commit 59798f10d:

```shell
$ ./build/bin/llama-cli --model Qwen3-30B-A3B-Q4_K_M.gguf -fa 1 -ngl 99 \
    -ctk f16 -ctv turbo3 \
    -p "Explain KV cache quantization in two sentences." -n 64
...
/workspace/llama-cpp-turboquant/ggml/src/ggml-cuda/fattn.cu:339: fatal error
[stack trace through ggml_cuda_flash_attn_ext -> ggml_abort]
Exit code: 134 (SIGABRT)
```

Same command on the patched build: model loads cleanly without abort. Throughput measurements from llama-bench shown below.

Build verified with both FA_ALL_QUANTS settings

The CMakeLists.txt update is required because:

  • GGML_CUDA_FA_ALL_QUANTS=ON uses file(GLOB template-instances/fattn-vec*.cu) which picks up new files after cmake reconfigure.
  • GGML_CUDA_FA_ALL_QUANTS=OFF (default) uses an explicit source list in ggml-cuda/CMakeLists.txt that must name each instance file.

=ON validated on sm_89 (Iceland pod) and sm_121 (Spark). =OFF validated on sm_121 (Spark) with both Q4_K_M and UD-Q4_K_XL.

Without the CMakeLists.txt change, the =OFF path would fail at link because the explicit source list omits the three new .cu files. The same failure class (undefined symbol ggml_cuda_flash_attn_ext_vec_case<...>) was observed on sm_89 during the initial post-patch build, when an incremental cmake --build did not re-run the GLOB and therefore did not pick up the new files; forcing a cmake reconfigure resolved it. That observation confirms the failure mode the CMakeLists.txt update prevents is real.

sm_89 validation on RTX 4090 — the three newly-enabled combos

Targeted single-combo runs, -ngl 99 -fa 1 -b 2048 -ub 2048 -p 0 -n 128 -pg 8192,128:

| ctk | ctv | tg128 (t/s) | pp8192+tg128 (t/s) |
| --- | --- | --- | --- |
| f16 | turbo2 | 217.96 ± 2.18 | 4365.23 ± 13.04 |
| f16 | turbo3 | 216.52 ± 1.67 | 4199.34 ± 20.85 |
| f16 | turbo4 | 214.49 ± 1.80 | 3805.94 ± 10.56 |

All three run at GPU speeds (contrast: PR #82 CPU-fallback symptom was single-digit t/s). Throughput numbers are in the same band, consistent with the three variants sharing the same unquantized-V kernel branch.

Note: the f16/turbo3 tg128 value here (216.52 ± 1.67) and in the 3x3 matrix below (215.54 ± 0.91) come from two independent bench invocations on the same build. They agree within their error bars; difference is run-to-run variance.

sm_89 validation on RTX 4090 (3x3 matrix at 8K and 32K context)

For context, the newly-enabled f16/turbo3 combo alongside baseline combos not touched by this PR, on the same build:

| ctk | ctv | tg128 (t/s) | pp8192+tg128 (t/s) | pp32768+tg128 (t/s) |
| --- | --- | --- | --- | --- |
| f16 | f16 | 242.57 | 5014.89 | 4894.58 |
| f16 | q8_0 | 227.38 | 4486.28 | 4464.29 |
| f16 | turbo3 | 215.54 | 4195.63 | 4270.61 |
| q8_0 | f16 | 237.59 | 4777.91 | 4701.90 |
| q8_0 | q8_0 | 221.81 | 4265.86 | 4244.73 |
| q8_0 | turbo3 | 212.33 | 4040.47 | 4091.07 |

f16/turbo3 throughput is within ~11% of the f16/f16 baseline and within ~5% of the f16/q8_0 baseline at both 8K and 32K contexts.

sm_121 validation on DGX Spark GB10

Same patch, same base commit as sm_89. Q4_K_M tested with both FA_ALL_QUANTS settings; UD-Q4_K_XL tested with =OFF. Only f16/turbo3 exercised on sm_121. The fix pattern is symmetric across turbo2, turbo3, turbo4 (identical structural additions in each of the 6 files), and all three were exercised individually on sm_89 (see the three-combos table above). The three variants route through the same unquantized-V branch of the flash-attention kernel.

-ngl 99 -fa 1 -b 2048 -ub 2048 -p 0 -n 128 -pg 8192,128:

| Model | FA_ALL_QUANTS | ctk | ctv | tg128 (t/s) | pp8192+tg128 (t/s) |
| --- | --- | --- | --- | --- | --- |
| Q4_K_M | =ON | f16 | turbo3 | 83.97 ± 0.36 | 1245.23 ± 1.39 |
| Q4_K_M | =OFF | f16 | turbo3 | 83.78 ± 0.41 | 1242.16 ± 1.26 |
| UD-Q4_K_XL | =OFF | f16 | turbo3 | 85.11 ± 0.33 | 1256.04 ± 1.60 |

=ON vs =OFF on Q4_K_M: 0.2% difference on both metrics, consistent with same code path + run-to-run variance. No abort. No CPU fallback (inferred from throughput: GB10 CPU-path on similar workloads reported ~5 t/s in the original PR #82 bug; these are ~80 t/s, GPU-consistent).

Numerical equivalence (llama-perplexity on wikitext-2 test)

Context 4096, 75 chunks, same model, same seed conditions:

| ctk | ctv | PPL |
| --- | --- | --- |
| f16 | f16 | 7.4950 ± 0.05346 |
| f16 | turbo3 | 7.5281 ± 0.05363 |

Delta: +0.0331 PPL (+0.44%). The ± values are within-run bootstrap estimates from a single perplexity run per config, not across-run variance. The delta is smaller than the within-run error estimate of either measurement, but confirming this with multiple runs of each config would strengthen the claim. One-run result: no meaningful quality shift attributable to the new dispatch path.

KV cache memory footprint

Reported by llama_memory_breakdown_print at 4K context:

| ctk | ctv | KV context memory |
| --- | --- | --- |
| f16 | f16 | 384 MiB |
| f16 | turbo3 | 229 MiB |

40.4% reduction in KV context memory for the f16/turbo3 asymmetric configuration. Expected behavior, confirms the turbo3 V path is actually being taken.

Reproduction

Same command as the issue #83 body, using llama-bench:

```shell
# Pre-fix (expected: abort at fattn.cu:339 when the dispatcher hits f16/turbo3)
./build/bin/llama-bench \
    -m Qwen3-30B-A3B-Q4_K_M.gguf \
    -fa 1 -t 1 -ngl 99 \
    -p 0 -n 128 -pg 8192,128 \
    -ctk f16 -ctv turbo3

# Post-fix with this PR (expected: bench completes with GPU throughput)
# Same command. The 3x3 matrix results above show the full numbers.
```

The flash attention vector dispatcher was missing instance files, extern
declarations, and dispatch cases for F16-K + TURBO-V combinations. Any
of `-ctk f16 -ctv turbo{2,3,4}` aborted at fattn.cu:339 on sm_89 and
sm_121 because no dispatch case matched.

The kernel already treats TURBO types as unquantized alongside F16 and
BF16 (fattn-vec.cuh:79-80), so this is a dispatch-plumbing gap, not a
kernel limitation. Pure additive change.

CMakeLists.txt is updated so builds without GGML_CUDA_FA_ALL_QUANTS (the
default) also link the new template instances.

Validated on:
- RTX 4090 (sm_89) with Qwen3-30B-A3B Q4_K_M, FA_ALL_QUANTS=ON
- DGX Spark GB10 (sm_121) with Q4_K_M and UD-Q4_K_XL, both FA_ALL_QUANTS=ON and =OFF

Key results:
- Pre-fix crash reproduced at fattn.cu:339 (exit 134 SIGABRT, sm_89)
- Post-fix: all three new combos (f16/turbo{2,3,4}) run in llama-bench on sm_89
- sm_89 tg128: 215-218 t/s for new combos (vs 242 t/s for f16/f16 baseline)
- sm_121 tg128: 83.78-85.11 t/s for f16/turbo3 across model and build-config combinations
- Single-run PPL on sm_89 wikitext-2: 7.5281 vs 7.4950 f16/f16 baseline (delta +0.44%, smaller than the within-run error estimate)
- KV context memory at 4K ctx on sm_89: 229 MiB vs 384 MiB f16/f16 baseline

Fixes TheTom#83
@dentity007
Author

Heads up for context. Ran additional sm_121 (DGX Spark GB10) validation overnight:

  • Full 3x3 bench matrix at 8K and 32K contexts, 3 repetitions. Across-run variance stays under 1% on every valid cell. The newly-enabled f16/turbo3 combo is 83.06 ± 0.17 t/s tg128, 1240.24 ± 1.54 t/s pp8192+tg128.
  • llama-perplexity context-length sweep (2K, 4K, 8K, 16K, 32K) with f16/f16 vs f16/turbo3. ΔPPL stays under 0.61% across the full range, 0.13% at 32K. No compounding with context.
  • Cross-architecture PPL agrees within 0.05% between sm_89 and sm_121 at the 4K config from the PR body (7.4950 vs 7.4989 for f16/f16, 7.5281 vs 7.5339 for f16/turbo3).

The fattn.cu:348 abort on turbo3/f16 (the reverse-direction case flagged in Scope) reproduced 3/3 reps on sm_121 with the identical stack trace from the Iceland validation.

Happy to paste the full tables as a follow-up comment or open a separate issue/PR for the turbo3/f16 reverse-direction gap, whichever is more useful.



Successfully merging this pull request may close these issues.

Eval bug: F16-K + TURBO-V missing from fattn dispatch, crashes at fattn.cu:339 on sm_89 and sm_121
