Add TurboQuantKVCache: 3-bit/4-bit KV cache compression for generation #1202
dedalien wants to merge 1 commit into ml-explore:main
Conversation
Builds on ml-explore/mlx#3026 (Dan Yeh) — the generic `quantized_scaled_dot_product_attention` API with pluggable modes. Companion PR: ml-explore/mlx#XXXX (adds turbo3/turbo4 to mlx core).

New files:
- `mlx_lm/models/turbo_cache.py`: `TurboQuantKVCache(_BaseCache)`. Two-phase: prefill stores float16; generation returns `(packed_uint32, float16_scales)` for the fused kernel. On the first generation step, prefill tokens are batch-compressed. K is WHT-rotated before quantization; V is not.
- `mlx_lm/models/turbo_metal.py`: two fused Metal kernels via `mx.fast.metal_kernel`, one thread per token, float32 registers: `turbo_encode_metal` (WHT + norm + codebook + pack) and `wht_rotate_metal` (used to pre-rotate Q before SDPA).

Modified files:
- `mlx_lm/models/base.py`: detects `TurboQuantKVCache` via `isinstance`, routes to `_turbo_scaled_dot_product_attention`. Q rotation is done in float32 (a bfloat16 butterfly shifts softmax peaks on models with large key scales).
- `mlx_lm/models/cache.py`: `make_turbo_cache()` replaces `KVCache` with `TurboQuantKVCache`; leaves `ArraysCache`/DeltaNet layers untouched. `fp16_layers=` keeps the first/last N attention layers in float16.
- `mlx_lm/generate.py`: `generate_step` gains `kv_cache_type=` and `turbo_fp16_layers=`; `--kv-cache-type` and `--turbo-fp16-layers` in the CLI.
- `tests/test_turbo_cache.py`: 22 sub-cases covering WHT isometry, encode shapes, two-phase cache, D in {64, 128, 256}, gqa_factor=6, B=2, bfloat16, the fp16_layers boundary.

Supported head dims: 64, 128, 256. Requires a Metal GPU. Tested on Qwen3.6-27B (head_dim=256, 24Q/4KV, GQA=6) on a 24 GB unified memory Mac. Enables longer-context generation on memory-constrained hardware by compressing the KV cache ~5x.
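For intuition, the rotation both kernels implement is a normalized fast Walsh-Hadamard transform. Here is a Python sketch of the butterfly; the PR runs this inside the fused Metal kernels, and the sketch assumes head_dim is a power of two and uses plain MLX ops:

```python
import math
import mlx.core as mx

def wht_rotate(x: mx.array) -> mx.array:
    """Normalized fast Walsh-Hadamard transform over the last axis.

    Sketch of the rotation applied to K at store time and to Q at
    compute time (the real implementation is the fused Metal kernels).
    Assumes the last dim (head_dim) is a power of two.
    """
    *prefix, d = x.shape
    y = x
    h = 1
    while h < d:
        # One butterfly stage: split each 2h-block into halves a, b and
        # replace them with (a + b, a - b).
        y = y.reshape(*prefix, d // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = mx.concatenate([a + b, a - b], axis=-1)
        h *= 2
    # The 1/sqrt(d) factor makes the transform orthonormal, so dot
    # products are preserved: rotating Q and K by the same matrix leaves
    # QK^T unchanged, and only the quantizer itself loses information.
    return y.reshape(*prefix, d) / math.sqrt(d)
```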
Useful for pipeline-parallel setups where the KV cache gets shipped between nodes — bandwidth becomes the bottleneck at that point, not compute. A 3-bit cache would cut transfer size significantly. Question: does the WHT rotation happen at cache store time or at attention compute time? If store-time, the compression is free during generation; if compute-time, the rotation cost offsets some of the savings.
Yo — WHT happens at compute time for Q; K is rotated and compressed at store time. It costs a bit, so maybe implement a context-length threshold to enable turbo3 compression (e.g. `turbo_kv_start=1024` in `generate_step`).
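A minimal sketch of that gating idea (`turbo_kv_start` is a hypothetical name, not an option this PR adds):

```python
def choose_kv_cache_type(context_len: int, turbo_kv_start: int = 1024) -> str:
    """Hypothetical policy: keep short contexts in float16, where the
    rotation/encode overhead outweighs the savings, and switch to the
    3-bit cache once the context is long enough to matter. The "fp16"
    value is also illustrative."""
    return "turbo3" if context_len >= turbo_kv_start else "fp16"
```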
Builds on ml-explore/mlx#3026 (Dan Yeh) — the generic `quantized_scaled_dot_product_attention` API with pluggable modes. Needs PR CC-yeh/mlx#3 to be approved and merged before #3026 can merge.
What this does
Adds `TurboQuantKVCache`, a drop-in KV cache that compresses keys and values to 3 or 4 bits during generation using TurboQuant (arXiv 2504.19874):

- Compressed codes are packed into uint32, same layout as the existing affine SDPA kernel
- Two-phase cache: prefill stores float16; on the first generation step all prefill tokens are batch-compressed. Subsequent steps compress token-by-token (see the sketch after this list).
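A sketch of the two-phase control flow. Names, signatures, and the toy quantizer are illustrative stand-ins, not the PR's actual API; the layout is assumed to be (B, n_kv_heads, T, head_dim) with axis 2 as the sequence:

```python
import mlx.core as mx

def toy_encode(x: mx.array):
    # Toy stand-in for the fused encode kernel: per-vector max-abs scale
    # plus rounded signed 3-bit codes. The real kernel also applies the
    # WHT and packs the codes into uint32.
    scale = mx.max(mx.abs(x), axis=-1, keepdims=True).astype(mx.float16) + 1e-6
    codes = mx.clip(mx.round(x / scale * 4), -4, 3).astype(mx.int8)
    return codes, scale

class TwoPhaseCacheSketch:
    def __init__(self):
        self.fp16 = None                  # phase 1: raw float16 prefill
        self.codes = self.scales = None   # phase 2: compressed buffers

    def update(self, tokens: mx.array, generating: bool):
        if not generating:
            # Prefill: store float16, concatenated along the sequence axis.
            self.fp16 = tokens if self.fp16 is None else mx.concatenate(
                [self.fp16, tokens], axis=2)
            return
        if self.fp16 is not None:
            # First generation step: batch-compress every prefill token.
            self.codes, self.scales = toy_encode(self.fp16)
            self.fp16 = None
        # Subsequent steps: compress the new token(s) and append.
        c, s = toy_encode(tokens)
        if self.codes is None:
            self.codes, self.scales = c, s
        else:
            self.codes = mx.concatenate([self.codes, c], axis=2)
            self.scales = mx.concatenate([self.scales, s], axis=2)
```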
New files
- `mlx_lm/models/turbo_cache.py`: `TurboQuantKVCache(_BaseCache)`, `make_turbo_cache()`, WHT/encode helpers
- `mlx_lm/models/turbo_metal.py`: two fused Metal kernels via `mx.fast.metal_kernel` (WHT + norm + codebook + pack, and WHT-only for Q rotation), one thread per token, float32 registers

Modified files
- `mlx_lm/models/base.py`: detects `TurboQuantKVCache` via `isinstance`, routes to `_turbo_scaled_dot_product_attention`
- `mlx_lm/models/cache.py`: `make_turbo_cache()` replaces `KVCache` with `TurboQuantKVCache`; leaves `ArraysCache`/DeltaNet layers untouched; `fp16_layers=` keeps first/last N attention layers in float16
- `mlx_lm/generate.py`: `generate_step` gains `kv_cache_type=` and `turbo_fp16_layers=`; `--kv-cache-type` and `--turbo-fp16-layers` in the CLI
- `tests/test_turbo_cache.py`: 22 sub-cases covering WHT isometry, encode shapes, two-phase cache, D in {64, 128, 256}, GQA=6, B=2, bfloat16, the fp16_layers boundary

Usage
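A Python sketch of the new options. This assumes `generate` forwards extra keyword arguments through to `generate_step`; treat the exact call as illustrative and check the PR's final signature:

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Qwen3.6-27B-4bit")

# kv_cache_type= and turbo_fp16_layers= are the generate_step options
# this PR adds; passing them via generate() assumes kwargs forwarding.
text = generate(
    model,
    tokenizer,
    prompt="Hello",
    kv_cache_type="turbo3",   # or "turbo4" for 4-bit
    turbo_fp16_layers=2,      # keep first/last 2 attention layers in float16
)
print(text)
```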
Or via CLI:
```
mlx_lm.generate --model mlx-community/Qwen3.6-27B-4bit --kv-cache-type turbo3 --prompt "Hello"
```

Supported configurations
- Bits: 3 (`turbo3`) or 4 (`turbo4`)
- Head dims: 64, 128, 256
- Requires a Metal GPU

Tested on
Qwen3.6-27B (head_dim=256, 24Q/4KV, GQA=6) on a 24 GB unified memory Mac. The model fits entirely in RAM with turbo3 vs swapping to disk with fp16, enabling practical long-context generation on memory-constrained hardware (~5x KV cache compression).
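Back-of-envelope on the ~5x number, using the head_dim=256 shape above. The one-fp16-scale-per-vector overhead is my assumption from the `(packed_uint32, float16_scales)` layout:

```python
head_dim = 256
fp16_bits = 16 * head_dim        # 4096 bits per K or V vector in float16
turbo3_bits = 3 * head_dim + 16  # 784 bits: 3-bit codes + one fp16 scale (assumed)
print(f"{fp16_bits / turbo3_bits:.1f}x")  # 5.2x, consistent with the ~5x claim
```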