Add TurboQuant KV cache + DeepSeek V4 support #1067
Conversation
Implements TurboQuant (arXiv 2504.19874) KV cache compression:
- PolarQuant: randomized Hadamard rotation + Lloyd-Max codebook
- Bit-packed uint32 storage (3-bit: 10 values per word)
- Fused Metal kernels for quantize and dequantize
- Incremental decode buffer for O(1) per-step cost
- Layer-adaptive mode: FP16 for first/last N layers

Usage: `generate_step(prompt, model, turbo_kv_bits=3)`

Results (Qwen2.5-32B, M4 Pro 48GB):
- 4.6x compression, 0.98x FP16 speed, identical quality
- 16K context: 4.2 GB → 897 MB KV cache
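For reference, a minimal NumPy sketch of the bit-packed storage layout described above (10 three-bit codes per uint32 word). The helper names are illustrative only; in the PR this packing happens inside the fused Metal kernels.

```python
import numpy as np

def pack_3bit(codes: np.ndarray) -> np.ndarray:
    """Pack 3-bit codes (values 0..7) into uint32 words, 10 codes per word."""
    codes = codes.astype(np.uint32)
    pad = (-len(codes)) % 10
    codes = np.pad(codes, (0, pad))                 # pad to a multiple of 10
    groups = codes.reshape(-1, 10)
    shifts = np.arange(10, dtype=np.uint32) * 3     # bit offsets 0, 3, ..., 27
    return np.bitwise_or.reduce(groups << shifts, axis=1)

def unpack_3bit(words: np.ndarray, n: int) -> np.ndarray:
    """Inverse of pack_3bit: recover the first n codes."""
    shifts = np.arange(10, dtype=np.uint32) * 3
    return ((words[:, None] >> shifts) & 0x7).reshape(-1)[:n].astype(np.uint8)
```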
I tried this branch on GLM-4.7-Flash-REAP-23B-A3B-mlx-nvfp4 - it outputs garbage; on the main branch it works fine.
That's unexpected - this branch shouldn't change default behavior; it only adds new files and optional parameters. Are you using the default generate(), or did you pass turbo_kv_bits? If default, there might be a formatting issue from pre-commit that touched generate.py - I'll check.
This is the script I used:
Ah, got it - you are using turbo_kv_bits=4. That's expected to have quality issues on the K tensor - I've seen the same thing. Try turbo_kv_bits=3, which actually works better (counterintuitively, the 3-bit codebook fits the post-rotation Gaussian distribution better than 4-bit for K). Also, for a 23B model, try turbo_fp16_layers=2 or turbo_fp16_layers=4 to keep more layers in full precision.
Found it. Your config turbo_kv_bits=4, turbo_fp16_layers=1 should work on most models, but MoE architectures like GLM-4.7-Flash might need more FP16 layers. Try turbo_fp16_layers=4 or turbo_fp16_layers=6. On my 7B tests, both 3-bit and 4-bit with fp16_layers=1 produce clean output. |
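For reference, a sketch of how those settings would be passed, assuming this branch's generate() forwards the new keyword arguments to generate_step (the model path is only an example):

```python
from mlx_lm import load, generate

# Illustrative model; any standard multi-head-attention checkpoint works.
model, tokenizer = load("mlx-community/Qwen2.5-32B-Instruct-4bit")

text = generate(
    model,
    tokenizer,
    prompt="Summarize the TurboQuant paper in two sentences.",
    turbo_kv_bits=3,      # 3-bit codebook: fits post-rotation K better than 4-bit
    turbo_fp16_layers=4,  # keep more first/last layers in FP16 for MoE models
)
print(text)
```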
With the params you suggested, same garbage issue:
Yeah MLA is a different beast, makes sense it breaks. Good catch. Qwen3.5 should be fine since it's standard attention. Let me know how it goes. |
Did more testing:
Thanks for testing across architectures. MLA and SSM are expected - TurboQuant only works with standard multi-head attention KV cache. I should add a check that raises a clear error instead of silently producing garbage. Will fix. |
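Something along these lines - names and detection heuristics here are assumptions, not the final code:

```python
def _assert_turbo_compatible(model, turbo_kv_bits):
    """Hypothetical guard: fail fast instead of producing garbage output."""
    if turbo_kv_bits is None:
        return
    for layer in model.layers:
        attn = getattr(layer, "self_attn", None)
        # MLA caches compressed latents and SSM blocks have no KV cache at all,
        # so TurboQuant's per-head rotation + codebook applies to neither.
        is_mla = attn is not None and hasattr(attn, "kv_a_proj_with_mqa")
        is_ssm = attn is None and hasattr(layer, "mixer")
        if is_mla or is_ssm:
            raise ValueError(
                "turbo_kv_bits requires standard multi-head attention; "
                "MLA and SSM architectures are not supported."
            )
```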
some thoughts for DX:
Good points, agree on all of them. Specifically:
Adds --turbo-kv-bits flag (1-4) to compress stored prefix cache entries using TurboQuant (arXiv 2504.19874). 3-bit gives 4.6x compression vs FP16, compared to ~2x from the existing 8-bit quantization.

Integration points:
- memory_cache.py: _turbo_quantize_cache/_dequantize_cache, memory estimation, trim support, needs_dequantize property, config validation
- scheduler.py: turbo_kv_bits in SchedulerConfig, propagation to MemoryCacheConfig
- cli.py: --turbo-kv-bits for serve and bench commands

Requires mlx-lm with TurboQuant support (ml-explore/mlx-lm#1067).
Findings from independent testing + potential contributions

I've been working on TurboQuant for MLX-LM independently and have some findings that might be useful for this PR.

1. Bug in MLX core PR #3328
| Context | Keys only compressed | Keys + Values compressed |
|---|---|---|
| 50K (Llama-3-8B) | 4.3 GB KV | 2.7 GB KV |
| 200K (Llama-3-8B) | 14.9 GB KV | 10.7 GB KV |
Speed impact is negligible (0.94x vs 0.95x). Quality is identical — values tolerate 4-bit well.
3. 200K context proof (32GB Mac)
On a 32GB machine with Llama-3-8B-Instruct-4bit:
- FP16 at 200K context: system enters swap death — unusable for 1+ hour, had to kill the process
- TurboQuant at 200K: completed fill and generated tokens successfully. KV cache = 10.7 GB.
This is the concrete proof that TurboQuant enables context lengths that are physically impossible with FP16 on memory-constrained machines.
4. Quality on real prompts
Tested coding, explanation, and analysis prompts. Outputs are near-identical to standard:
Standard: "quantum computers are special computers that can do some things that regular computers can't.
Imagine you have a big box of different colored balls..."
TurboQuant: "quantum computers are special computers that can do some really cool things that regular computers can't.
Imagine you have a big box of different colored balls..."
5. Separate infrastructure PRs
We have two related PRs that complement TurboQuant:
- feat: add KV cache quantization args to server #1073: server `--kv-bits` support (enables KV quantization via API)
- feat: QuantizedRotatingKVCache + KVSplit (K/V different bits) #1074: `QuantizedRotatingKVCache` + KVSplit `bits=(key, value)` tuple support
Happy to help with testing, benchmarks, or integration work on this PR. Great implementation.
- Fail-fast when --turbo-kv-bits is requested without mlx-lm TurboQuant support: config and CLI now error out with an actionable message pointing to ml-explore/mlx-lm#1067 instead of silently no-oping.
- Reject the --turbo-kv-bits + --kv-cache-quantization combination in both argparse-fed callers and in MemoryCacheConfig.__post_init__ so programmatic users get the same guard.
- Log a one-time warning when _TurboQuantCacheWrapper.is_trimmable() degrades to False because the upstream TurboQuantKVCache lacks copy(); prefix-cache trimming (supersequence / LCP reuse) falls back to full-prefix matching in that case (correct, just less efficient).
- Document the dual estimate_kv_cache_memory paths (wrapper vs bare) and the copy() contract we depend on in _trim_cache_offset.
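A minimal sketch of the fail-fast validation described in the first two items above; field names and the import location of TurboQuantKVCache are assumptions, not the actual memory_cache.py code:

```python
from dataclasses import dataclass

@dataclass
class MemoryCacheConfig:
    turbo_kv_bits: int | None = None
    kv_cache_quantization: bool = False

    def __post_init__(self):
        if self.turbo_kv_bits is None:
            return
        if not 1 <= self.turbo_kv_bits <= 4:
            raise ValueError("turbo_kv_bits must be between 1 and 4")
        if self.kv_cache_quantization:
            raise ValueError(
                "--turbo-kv-bits and --kv-cache-quantization are mutually exclusive"
            )
        try:
            # Assumed import location; adjust to wherever the class lands upstream.
            from mlx_lm.models.cache import TurboQuantKVCache  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "turbo_kv_bits requires an mlx-lm build with TurboQuant support "
                "(ml-explore/mlx-lm#1067)"
            ) from e
```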
Saves KV cache at prefill completion and every 32K tokens during long prefills. Enables prefix matching on follow-up messages in multi-turn conversations.

Previously the cache was only saved after generation, with key = prompt + generated tokens. The next turn re-tokenizes the assistant response with template wrappers, producing different tokens and breaking the prefix match. Now the cache is also saved with a prompt-only key at prefill end, so the next turn matches the prompt prefix and skips it.

Tested: 46-token prompt cached, follow-up processed only 16 new tokens instead of the full 62.
When all prompt tokens match a cached entry, rest=[] causes stream_generate to crash with ValueError (empty prompt). Fix: trim cache by 1 token and re-process the last token. Tested: exact same prompt sent twice, second request processes only 1 token with 40 cached. No crash.
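A sketch of that fix; variable names are illustrative, and trim_prompt_cache is mlx-lm's existing cache-trimming helper:

```python
from mlx_lm.models.cache import trim_prompt_cache

# `rest` is the suffix of the prompt not covered by the cached prefix.
if len(rest) == 0:
    # Whole prompt was cached: roll the cache back one token so
    # stream_generate always receives a non-empty prompt.
    trim_prompt_cache(prompt_cache, 1)
    rest = prompt_tokens[-1:]
```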
Disk entries were capped at 2x prompt-cache-size, causing old caches to be evicted too aggressively. Disk is cheap, RAM is not. Added --prompt-cache-disk-size (default 100) to control disk entries independently from --prompt-cache-size (RAM entries). Tested: 5 RAM entries + 10 disk entries, capped correctly.
Full implementation of DeepSeek-V4-Flash architecture:
- Compressed Sparse Attention (CSA ratio=4, HCA ratio=128)
- Lightning Indexer for top-k compressed position selection
- Learned compressor with overlap transform
- Hyper-Connections with Sinkhorn normalization
- Hash routing MoE (256 experts, 6 active)
- Grouped low-rank output projection
- MQA with inverse RoPE

Sparse attention with window + compressed KV:
- Chunked sparse prefill (256 queries/chunk)
- Sparse decode with circular window buffer
- Step-based buffer growth for compressed entries
- RotatingKVCache for pure sliding window layers
- SparseKVCache with full state serialization

Tested on DeepSeek-V4-Flash-8bit (303GB, 6.4 tok/s).
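For the circular window buffer used by sparse decode above, a rough sketch of the idea; class and field names are hypothetical, not the PR's actual SparseKVCache internals:

```python
import mlx.core as mx

class WindowBuffer:
    """Fixed-size sliding window: new decode steps overwrite the oldest slot."""

    def __init__(self, window, n_heads, head_dim, dtype=mx.float16):
        self.buf = mx.zeros((1, n_heads, window, head_dim), dtype=dtype)
        self.window = window
        self.count = 0

    def push(self, kv_step):  # kv_step: (1, n_heads, 1, head_dim)
        slot = self.count % self.window
        self.buf[:, :, slot : slot + 1, :] = kv_step
        self.count += 1

    def view(self):  # valid entries in chronological order
        if self.count < self.window:
            return self.buf[:, :, : self.count, :]
        start = self.count % self.window
        return mx.concatenate(
            [self.buf[:, :, start:, :], self.buf[:, :, :start, :]], axis=2
        )
```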
AutoTokenizer.from_pretrained crashes on model types that transformers doesn't know (e.g. deepseek_v4) due to rope_scaling standardization. Fall back to PreTrainedTokenizerFast when AutoTokenizer fails. Also register deepseek_v4 config with transformers AutoConfig.
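A sketch of that fallback, assuming tokenizer.json exists in the checkpoint (true for DeepSeek-style repos); the exact exception raised by AutoTokenizer can vary:

```python
from transformers import AutoTokenizer, PreTrainedTokenizerFast

def load_tokenizer_with_fallback(model_path):
    try:
        return AutoTokenizer.from_pretrained(model_path)
    except Exception:
        # transformers can reject unknown model types (e.g. deepseek_v4) while
        # standardizing rope_scaling, even though the tokenizer file is fine.
        return PreTrainedTokenizerFast.from_pretrained(model_path)
```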
SparseKVCache.update_and_fetch didn't preserve existing data on reallocation, causing shape mismatch on continuation prefill chunks. Also handle continuation prefill correctly: when the server splits a long prompt into chunks, subsequent chunks now extend the existing sparse buffers instead of reinitializing them.
Serialization materializes all KV cache arrays, temporarily doubling memory usage. For large models (400GB+) on machines with limited headroom this triggers OOM. Check available memory before saving and skip if less than 8GB free.
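A sketch of that guard; the 8 GB threshold comes from the commit above, while the use of psutil for the free-memory check is an assumption:

```python
import psutil

MIN_FREE_BYTES = 8 * 1024**3  # skip saving below 8 GB of headroom

def maybe_save_cache(save_fn):
    if psutil.virtual_memory().available < MIN_FREE_BYTES:
        # Serialization materializes the full KV cache, roughly doubling its
        # footprint; skipping is safer than risking an OOM on a 400GB+ model.
        return False
    save_fn()
    return True
```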
Compressor decode mode expects single-token input (L=1) but continuation prefill chunks pass L>1. Process chunk tokens one-by-one through the compressor to maintain correct state.
- Fix Metal resource limit crash: token-by-token compressor loop in continuation prefill creates too many buffers. Add periodic mx.eval to flush every 32 tokens.
- Fix missing _win_buf on second generate call: use cache.offset==0 to detect first prefill instead of checking _win_buf existence.
- Add safety init in sparse decode for single-token prompt edge case.
- Cap Sinkhorn iterations at 10 (from 20) for ~12% speed improvement with negligible quality difference.
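An illustrative sketch of the periodic-eval fix from the first item above; `compressor`, `chunk_tokens`, and `cache` are hypothetical names for this branch's internals:

```python
import mlx.core as mx

def run_compressor_chunk(compressor, chunk_tokens, cache, flush_every=32):
    """Feed a continuation-prefill chunk through the decode-mode compressor one
    token at a time, evaluating periodically so the lazy graph does not
    accumulate more Metal buffers than the device allows."""
    outputs = []
    for i, tok in enumerate(chunk_tokens):
        outputs.append(compressor(tok[None], cache))  # compressor expects L == 1
        if (i + 1) % flush_every == 0:
            mx.eval(outputs[-1])  # flush pending kernels and free buffers
    return mx.concatenate(outputs, axis=1) if outputs else None
```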
Custom Metal kernels for MoE decode (fused_moe_kernel.py):
- Fused gate+up+SwiGLU: all experts in one dispatch (1.8x MoE speedup)
- Fused down projection: all experts in one dispatch
- Fused grouped output projection: 8 groups in one dispatch
- 4-bit inline dequantization matching MLX qmv pattern

Model optimizations (deepseek_v4.py):
- MoE layer skip on 8/43 layers during decode (quality-validated)
- mx.compile for HC pre/post and MoE modules
- Inverse RoPE simplified to rope(-offset) for decode
- Fused Q+KV projection via weight concatenation
- Step-based buffer growth for compressed KV

Bugfixes from Opus audit:
- Fix inverse RoPE for prefill (L>1) with per-position offsets
- Fix SparseKVCache state serialization alignment
- Fix continuation prefill window buffer (incremental update)
- Fix SparseKVCache.trim to invalidate sparse state
- Fix output dtype mismatch (float32 -> input dtype)
- Add K%512 and N%8 assertions for Metal kernels
- Fix id()-based cache leak in fused_grouped_wo
- Add scale handling in _apply_rope_at_positions
- Remove dead code (fused_qkv_proj)

SwitchGLU decode path (switch_layers.py):
- Sequential per-expert processing with fused Metal kernels
- Automatic fallback for non-4-bit quantization

Result: 6.5 -> 21 tok/s through server (3.2x), 161GB peak memory.
- Add test_deepseek_v4.py (28 tests: model creation, prefill/decode, continuation, multi-turn, cache serialization, compressor, fused kernels)
- Add --turbo-kv-bits and --turbo-fp16-layers to server CLI
- Add MLA/SSM guard in make_prompt_cache (ValueError instead of garbage)
- Add SparseKVCache to prompt cache save/load allowlist
- Fix stale sparse state across conversations (reset on offset==0)
- Gate MoE layer skip on num_hidden_layers==43
- Assert matching quant params in fused QKV projection
- Accept upstream server.py refactor, re-apply turbo args on top
- Fallback for mx.new_thread_local_stream (older MLX versions)
- FP8 e4m3 block dequant from original HF checkpoints (auto-detect format)
- FP4 packed expert dequant with ue8m0 block scaling
- MTP weight filtering and HF key remapping in sanitize()
- 4-bit affine value compression (--turbo-v-bits) for 2x memory savings
- BatchSparseKVCache with merge/unmerge for server batch mode
- Per-entry sparse decode for compressor state machine compatibility
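The 4-bit affine value path above maps naturally onto MLX's built-in affine quantization; a rough sketch (the actual --turbo-v-bits storage layout in the PR may differ):

```python
import mlx.core as mx

def compress_values(v, bits=4, group_size=64):
    # v: (B, n_kv_heads, S, head_dim); head_dim must be divisible by group_size.
    # Returns (v_q, scales, biases) in the packed affine format MLX uses.
    return mx.quantize(v, group_size=group_size, bits=bits)

def decompress_values(v_q, scales, biases, bits=4, group_size=64):
    return mx.dequantize(v_q, scales, biases, group_size=group_size, bits=bits)
```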
Merged upstream, added tests, fixed the reported issues. Tests: 69 total (test_deepseek_v4.py + test_turboquant.py), all pass. New stuff:
Verified 22.3 tok/s on 4-bit V4 Flash, M3 Ultra 512GB.

@babhishek21 - all DX points done (to_turbo_quantized, make_prompt_cache, CLI args in generate + server)
@kipanshi - make_prompt_cache raises ValueError for MLA and SSM models now
@deceptech-packet-ninja - value compression is in as --turbo-v-bits
Summary
Adds TurboQuant KV cache compression and full DeepSeek-V4-Flash inference support.
TurboQuant KV Cache
Implementation of arXiv 2504.19874 (ICLR 2026).
`generate_step(prompt, model, turbo_kv_bits=3)`

DeepSeek V4 Flash (284B MoE)
First MLX implementation of DeepSeek-V4-Flash with full architecture support:
Performance optimizations:
Results
Quick Start
Other fixes
Test plan