
Add TurboQuantKVCache: 3-bit/4-bit KV cache compression for generation#1202

Open
dedalien wants to merge 1 commit into ml-explore:main from dedalien:turboq/integrate-generic-quant-sdpa

Conversation


@dedalien dedalien commented Apr 26, 2026

Builds on ml-explore/mlx#3026 (Dan Yeh), the generic quantized_scaled_dot_product_attention API with pluggable modes. Companion PR: ml-explore/mlx#XXXX (adds turbo3/turbo4 to mlx core).

CC-yeh/mlx#3 needs to be approved and merged before #3026 can merge.

What this does

Adds TurboQuantKVCache, a drop-in KV cache that compresses keys and values to 3 or 4 bits during generation using TurboQuant (arXiv 2504.19874):

  1. WHT rotation: K vectors are rotated via Walsh-Hadamard transform, spreading energy uniformly so the distribution approximates N(0,1)
  2. Lloyd-Max quantization: optimal scalar quantizer for N(0,1); codebooks live in the Metal kernel as compile-time constants
  3. Bit packing: indices packed into uint32, same layout as the existing affine SDPA kernel
  4. V is not rotated (kernel output lands directly in the original V space)
  5. Q rotation runs in float32 via a fused Metal kernel before the SDPA call; the result is cast to bfloat16 for dispatch. (Doing the butterfly accumulation in bfloat16 across the 8 WHT stages shifts softmax peaks on models with large key scales, e.g. Qwen3.6 with key_scale ≈ 6–10.)
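
For orientation, steps 1–3 amount to something like the following (a simplified NumPy sketch rather than the fused Metal kernel; the per-token RMS scale is my reading of the "norm" step, and the uint32 packing is elided):

import numpy as np

# Standard 8-level (3-bit) Lloyd-Max reproduction points for a unit Gaussian;
# the Metal kernel bakes its codebook in as compile-time constants, which may
# differ slightly from these textbook values.
LLOYD_MAX_3BIT = np.array([-2.152, -1.344, -0.756, -0.245,
                            0.245,  0.756,  1.344,  2.152], dtype=np.float32)

def wht_rotate(x):
    # Orthonormal fast Walsh-Hadamard transform along the last axis.
    d = x.shape[-1]                        # head_dim, must be a power of two
    y = x.reshape(-1, d).astype(np.float32)
    h = 1
    while h < d:                           # log2(d) butterfly stages
        for i in range(0, d, 2 * h):
            a = y[:, i:i + h].copy()
            b = y[:, i + h:i + 2 * h].copy()
            y[:, i:i + h] = a + b
            y[:, i + h:i + 2 * h] = a - b
        h *= 2
    return (y / np.sqrt(d)).reshape(x.shape)   # 1/sqrt(d) keeps norms unchanged

def encode_keys(k):
    # k: (..., head_dim) keys -> (codebook indices, per-token scales).
    rot = wht_rotate(k)                                   # step 1: rotate
    scale = np.sqrt((rot ** 2).mean(-1, keepdims=True))   # per-token RMS (assumed)
    idx = np.abs(rot[..., None] / scale[..., None] - LLOYD_MAX_3BIT).argmin(-1)
    # Step 3 (elided here): pack the 3-bit indices into uint32 words in the same
    # layout as the affine SDPA kernel. Dequantization for reference:
    #   k_hat = LLOYD_MAX_3BIT[idx] * scale
    return idx.astype(np.uint8), scale.astype(np.float16)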

Two-phase cache: prefill stores float16; on the first generation step all prefill tokens are batch-compressed. Subsequent steps compress token-by-token.
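
The two-phase flow, very roughly (a simplified sketch reusing encode_keys from above, with plain index/scale arrays instead of packed uint32 buffers; the real TurboQuantKVCache also handles values, GQA layouts, and offsets):

class TwoPhaseCacheSketch:
    # Illustrates the two-phase flow only; not the PR's actual class.
    def __init__(self):
        self.prefill_k = None          # phase 1: raw float16 prefill keys
        self.idx = None                # phase 2: quantized codebook indices
        self.scale = None              # phase 2: per-token scales

    def update(self, k):
        # k: (batch, n_kv_heads, n_new_tokens, head_dim)
        if self.idx is None and k.shape[2] > 1:
            # Prefill: keep float16, defer compression entirely.
            self.prefill_k = k if self.prefill_k is None else np.concatenate(
                [self.prefill_k, k], axis=2)
            return
        if self.prefill_k is not None:
            # First generation step: batch-compress all prefill tokens at once.
            k = np.concatenate([self.prefill_k, k], axis=2)
            self.prefill_k = None
        new_idx, new_scale = encode_keys(k)   # later steps: one token at a time
        self.idx = new_idx if self.idx is None else np.concatenate(
            [self.idx, new_idx], axis=2)
        self.scale = new_scale if self.scale is None else np.concatenate(
            [self.scale, new_scale], axis=2)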

New files

  • mlx_lm/models/turbo_cache.py: TurboQuantKVCache(_BaseCache), make_turbo_cache(), WHT/encode helpers; during generation the cache returns (packed_uint32, float16_scales) for the fused kernel
  • mlx_lm/models/turbo_metal.py: two fused Metal kernels via mx.fast.metal_kernel, one thread per token, float32 registers: turbo_encode_metal (WHT + norm + codebook + pack) and wht_rotate_metal (WHT-only, used to pre-rotate Q before SDPA)
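
Both follow the standard mx.fast.metal_kernel pattern. The snippet below is not one of the PR's kernels, just a minimal per-token example of that pattern (one thread per token, accumulation in float32 registers, head_dim baked into the source as a compile-time constant), computing a per-token RMS as a stand-in:

import mlx.core as mx

D = 128  # head_dim, baked into the kernel source at build time

source = f"""
    uint tok = thread_position_in_grid.x;
    float acc = 0.0f;                          // accumulate in float32 registers
    for (uint i = 0; i < {D}; ++i) {{
        float v = (float)inp[tok * {D} + i];
        acc += v * v;
    }}
    out[tok] = metal::sqrt(acc / {D}.0f);      // per-token RMS as a stand-in
"""

per_token_rms = mx.fast.metal_kernel(
    name="per_token_rms",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

x = mx.random.normal((1024, D)).astype(mx.float16)   # 1024 tokens
(rms,) = per_token_rms(
    inputs=[x],
    grid=(1024, 1, 1),                               # one thread per token
    threadgroup=(256, 1, 1),
    output_shapes=[(1024,)],
    output_dtypes=[mx.float32],
)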

Modified files

  • mlx_lm/models/base.py: detects TurboQuantKVCache via isinstance, routes to _turbo_scaled_dot_product_attention
  • mlx_lm/models/cache.py: make_turbo_cache() replaces KVCache with TurboQuantKVCache; leaves ArraysCache/DeltaNet layers untouched; fp16_layers= keeps first/last N attention layers in float16
  • mlx_lm/generate.py: generate_step gains kv_cache_type= and turbo_fp16_layers=; --kv-cache-type and --turbo-fp16-layers in CLI
  • tests/test_turbo_cache.py: 22 sub-cases covering WHT isometry, encode shapes, two-phase cache, D in {64,128,256}, GQA=6, B=2, bfloat16, fp16_layers boundary
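
One detail worth spelling out is the fp16_layers= boundary. A tiny sketch of the intended semantics, assuming fp16_layers=N means the first N and the last N attention layers keep a float16 cache (is_fp16_layer is a hypothetical helper, not code from the PR):

def is_fp16_layer(i, n_layers, fp16_layers):
    # First fp16_layers and last fp16_layers attention layers stay float16;
    # everything in between gets TurboQuantKVCache.
    return i < fp16_layers or i >= n_layers - fp16_layers

# e.g. with 8 attention layers and fp16_layers=2, layers 0, 1, 6, 7 stay float16
assert [i for i in range(8) if is_fp16_layer(i, 8, 2)] == [0, 1, 6, 7]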

Usage

from mlx_lm import load
from mlx_lm.generate import generate_step
import mlx.core as mx

model, tokenizer = load("mlx-community/Qwen3.6-27B-4bit")
prompt = mx.array(tokenizer.encode("Hello"))

for tok, _ in generate_step(prompt, model, kv_cache_type="turbo3"):
    ...

Or via CLI:

mlx_lm.generate --model mlx-community/Qwen3.6-27B-4bit --kv-cache-type turbo3 --prompt "Hello"

Supported configurations

  • Head dims: 64, 128, 256
  • Bits: 3 (turbo3) or 4 (turbo4)
  • Requires Metal GPU

Tested on

Qwen3.6-27B (head_dim=256, 24Q/4KV, GQA=6) on a 24 GB unified memory Mac. With turbo3 the model plus KV cache fits entirely in RAM, whereas with fp16 it swaps to disk, enabling practical long-context generation on memory-constrained hardware (~5x KV cache compression).
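
Back-of-the-envelope for the ~5x figure, assuming one float16 scale per head_dim=256 key/value vector: float16 storage is 256 × 16 = 4096 bits per vector, turbo3 is 256 × 3 + 16 = 784 bits, i.e. roughly 5.2x smaller; turbo4 comes out near 4x.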

Note: Depends on #3026 being merged first.

@lawcontinue

Useful for pipeline parallel setups where KV cache gets shipped between nodes — bandwidth becomes the bottleneck at that point, not compute. A 3-bit cache would cut transfer size significantly. Question: does the WHT rotation happen at cache store time or at attention compute time? If store-time, the compression is free during generation; if compute-time, the rotation cost offsets some of the savings.

@dedalien (Author)

> Useful for pipeline parallel setups where KV cache gets shipped between nodes — bandwidth becomes the bottleneck at that point, not compute. A 3-bit cache would cut transfer size significantly. Question: does the WHT rotation happen at cache store time or at attention compute time? If store-time, the compression is free during generation; if compute-time, the rotation cost offsets some of the savings.

Yo, WHT happens at compute time for Q; K is rotated and compressed at store time. The Q rotation costs a bit, so maybe implement a context-length threshold before enabling turbo3 compression (e.g. turbo_kv_start=1024 in generate_step).
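
Something along these lines, as a sketch of that threshold idea under one possible reading (turbo_kv_start is the hypothetical knob from this thread, not a parameter this PR adds):

def pick_cache_type(prompt_len, kv_cache_type="turbo3", turbo_kv_start=1024):
    # Keep the plain float16 cache for short contexts, where the per-step
    # Q-rotation overhead outweighs the memory savings; switch to the turbo
    # cache only once the context is long enough to matter.
    if kv_cache_type.startswith("turbo") and prompt_len < turbo_kv_start:
        return None            # None = use the standard float16 KVCache
    return kv_cache_type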

