Add TurboQuant KV cache + DeepSeek V4 support #1067

Open

arozanov wants to merge 52 commits into ml-explore:main from arozanov:feature/turboquant-kv-cache

Conversation

@arozanov commented Mar 28, 2026

Summary

Adds TurboQuant KV cache compression and full DeepSeek-V4-Flash inference support.

TurboQuant KV Cache

Implementation of arXiv 2504.19874 (ICLR 2026).

  • 4.6x compression at 3-bit (10 values packed per uint32)
  • 0.98x FP16 speed on Qwen2.5-32B (M4 Pro 48GB)
  • Fused Metal kernels for quantize/dequantize
  • Drop-in: generate_step(prompt, model, turbo_kv_bits=3) (usage sketch below)
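
A minimal usage sketch of the drop-in call, assuming generate_step yields (token, logprobs) pairs as in recent mlx-lm; the model name and stopping logic are illustrative, and only turbo_kv_bits is new in this PR:

```python
import mlx.core as mx
from mlx_lm import load
from mlx_lm.generate import generate_step

model, tokenizer = load("mlx-community/Qwen2.5-32B-Instruct-4bit")
prompt = mx.array(tokenizer.encode("Summarize KV cache quantization."))

# turbo_kv_bits=3 enables TurboQuant compression of the KV cache
for n, (token, _logprobs) in enumerate(generate_step(prompt, model, turbo_kv_bits=3)):
    tok = int(token)
    if tok == tokenizer.eos_token_id or n >= 256:
        break
    print(tokenizer.decode([tok]), end="", flush=True)
print()
```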

DeepSeek V4 Flash (284B MoE)

First MLX implementation of DeepSeek-V4-Flash with full architecture support:

  • Compressed Sparse Attention (CSA ratio=4) with Lightning Indexer
  • Heavily Compressed Attention (HCA ratio=128) with learned compressor
  • Hyper-Connections with Sinkhorn normalization
  • Hash routing MoE (256 experts, 6 active)
  • Grouped low-rank output projection
  • MQA with inverse RoPE (see the sketch below)
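
Since "inverse RoPE" is unusual, a hedged sketch: a later commit in this PR notes it reduces to rope(-offset) during decode. Under that assumption, with mx.fast.rope (real MLX API) and illustrative parameters:

```python
import mlx.core as mx

def inverse_rope(x: mx.array, pos: int, dims: int, base: float = 10000.0) -> mx.array:
    # For a single decode token whose cached key was rotated at absolute
    # position `pos`, applying RoPE with offset=-pos composes to the
    # inverse rotation (per the commit note "rope(-offset) for decode").
    return mx.fast.rope(x, dims, traditional=True, base=base, scale=1.0, offset=-pos)
```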

Performance optimizations:

  • Custom fused Metal kernels for MoE decode (gate+up+SwiGLU, down proj, grouped wo_a)
  • MoE layer skip (8/43 layers, quality-validated)
  • mx.compile for HC and MoE modules (illustrated after this list)
  • Sparse prefill with chunked processing
  • SparseKVCache with full state serialization
  • Disk-backed prompt cache with memory-aware saving
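
A tiny illustration of the mx.compile point above (mx.compile is standard MLX; the stand-in function below is not this PR's HC/MoE code):

```python
import mlx.core as mx

@mx.compile
def mix_experts(gates: mx.array, expert_out: mx.array) -> mx.array:
    # gates: (tokens, k); expert_out: (tokens, k, d_model).
    # Compiling fuses the weighted sum over active experts into one graph.
    return (gates[..., None] * expert_out).sum(axis=-2)
```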

Results

| Model | Quant | tok/s | RAM | Hardware |
| --- | --- | --- | --- | --- |
| DeepSeek-V4-Flash (284B) | 4-bit | 21 | 161 GB | Mac Studio M3 Ultra 512GB |
| DeepSeek-V4-Flash (284B) | 8-bit | 8.5 | 303 GB | Mac Studio M3 Ultra 512GB |
| Qwen2.5-32B | FP16 + TQ 3-bit KV | 26 | 22 GB | MacBook Pro M4 Pro 48GB |

Quick Start

```bash
pip install git+https://github.com/arozanov/mlx-lm.git@feature/turboquant-kv-cache
huggingface-cli download mlx-community/deepseek-ai-DeepSeek-V4-Flash-4bit --local-dir models/v4-4bit
mlx_lm.server --model models/v4-4bit --host 127.0.0.1 --port 8080 --prompt-cache-size 5 --no-batch
```
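
To smoke-test the server, a hedged client sketch; it assumes mlx_lm.server's OpenAI-compatible /v1/chat/completions route on the host/port above:

```python
import json
import urllib.request

req = urllib.request.Request(
    "http://127.0.0.1:8080/v1/chat/completions",
    data=json.dumps({
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
        "max_tokens": 128,
    }).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["message"]["content"])
```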

Other fixes

  • Tokenizer fallback for unrecognized model types (AutoTokenizer -> PreTrainedTokenizerFast)
  • Disk cache memory check (skip save when system RAM is low)
  • Stream threading fix (generation on main thread)
  • Multi-turn cache reuse via prefill checkpoints
  • Chunked prefill crash fix for compressed layers

Test plan

  • 4-bit server: streaming, non-streaming, multi-turn
  • 8-bit server: code generation, math, reasoning
  • Chunked prefill (2K+ token prompts)
  • Cache serialization save/restore
  • SparseKVCache trim
  • Unit tests (prefill, continuation, decode, second conversation)
  • Opus audit: all critical/important issues fixed

Implements TurboQuant (arXiv 2504.19874) KV cache compression:
- PolarQuant: randomized Hadamard rotation + Lloyd-Max codebook
- Bit-packed uint32 storage (3-bit: 10 values per word; packing sketch below)
- Fused Metal kernels for quantize and dequantize
- Incremental decode buffer for O(1) per-step cost
- Layer-adaptive mode: FP16 for first/last N layers
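
For reference, a NumPy stand-in for the 3-bit packing layout (10 three-bit codes per uint32, 2 bits unused; the PR itself does this in fused Metal kernels):

```python
import numpy as np

def pack_3bit(codes: np.ndarray) -> np.ndarray:
    # codes: 1-D array of 3-bit codebook indices (0..7), length a multiple of 10
    codes = codes.reshape(-1, 10).astype(np.uint32)
    shifts = np.arange(10, dtype=np.uint32) * 3  # bit offsets 0, 3, ..., 27
    return (codes << shifts).sum(axis=1, dtype=np.uint32)  # disjoint fields: sum == OR

def unpack_3bit(words: np.ndarray) -> np.ndarray:
    shifts = np.arange(10, dtype=np.uint32) * 3
    return ((words[:, None] >> shifts) & 0x7).reshape(-1)
```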

Usage:
  generate_step(prompt, model, turbo_kv_bits=3)

Results (Qwen2.5-32B, M4 Pro 48GB):
- 4.6x compression, 0.98x FP16 speed, identical quality
- 16K context: 4.2GB → 897MB KV cache
@kipanshi

I tried this branch on GLM-4.7-Flash-REAP-23B-A3B-mlx-nvfp4 and it outputs garbage; on the main branch it works fine.

@arozanov (Author)

> I tried this branch on GLM-4.7-Flash-REAP-23B-A3B-mlx-nvfp4 and it outputs garbage; on the main branch it works fine.

That's unexpected: this branch shouldn't change default behavior; it only adds new files and optional parameters. Are you using the default generate(), or did you pass turbo_kv_bits? If the default, there might be a formatting issue from pre-commit touching generate.py; I'll check.

@kipanshi

> That's unexpected: this branch shouldn't change default behavior; it only adds new files and optional parameters. Are you using the default generate(), or did you pass turbo_kv_bits? […]

This is the script I used:

```bash
#!/bin/bash
# Run GLM-4.7-Flash-REAP-23B-A3B with TurboQuant KV cache
# Optimized for M1 Max 32GB

MODEL_DIR="$HOME/my_docs/llms/GLM-4.7-Flash-REAP-23B-A3B-mlx-mxfp4"
MLX_LM_DIR="$HOME/opt/mlx-lm"

TURBO_KV_BITS="${TURBO_KV_BITS:-4}"       # 3-bit = 4.6x compression, 4-bit = safer quality
TURBO_FP16_LAYERS="${TURBO_FP16_LAYERS:-1}" # first/last N layers stay FP16
MAX_TOKENS="${MAX_TOKENS:-4096}"
TEMP="${TEMP:-0.7}"
TOP_P="${TOP_P:-0.9}"

PROMPT="${1:-Hello, who are you?}"

cd "$MLX_LM_DIR" || exit 1

uv run python -c "
from mlx_lm import load, stream_generate
from mlx_lm.generate import make_sampler
import sys

model, tokenizer = load('${MODEL_DIR}')

sampler = make_sampler(temp=${TEMP}, top_p=${TOP_P})

prompt = sys.argv[1]
if tokenizer.has_chat_template:
    messages = [{'role': 'user', 'content': prompt}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
    )

for response in stream_generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=${MAX_TOKENS},
    sampler=sampler,
    turbo_kv_bits=${TURBO_KV_BITS},
    turbo_fp16_layers=${TURBO_FP16_LAYERS},
):
    print(response.text, end='', flush=True)
print()
" "$PROMPT"
```

@arozanov (Author)

> This is the script I used: […]

Ah, got it: you're using turbo_kv_bits=4. That's expected to have quality issues on the K tensor; I've seen the same thing. Try turbo_kv_bits=3, which actually works better (counterintuitively, the 3-bit codebook fits the post-rotation Gaussian distribution better than the 4-bit one for K). Also, for a 23B model, try increasing turbo_fp16_layers=2 or turbo_fp16_layers=4 to keep more layers in full precision.

@arozanov (Author)

> This is the script I used: […]

Found it. Your config (turbo_kv_bits=4, turbo_fp16_layers=1) should work on most models, but MoE architectures like GLM-4.7-Flash might need more FP16 layers. Try turbo_fp16_layers=4 or turbo_fp16_layers=6. In my 7B tests, both 3-bit and 4-bit with turbo_fp16_layers=1 produce clean output.

@kipanshi

> Try turbo_kv_bits=3 […] Also, for a 23B model, try increasing turbo_fp16_layers=2 or turbo_fp16_layers=4 […]

With the params you suggested, same garbage issue:

> Without TurboQuant the branch works fine; the model outputs correctly. So the TurboQuant cache itself is incompatible with the glm4_moe_lite MLA architecture. MLA stores compressed latents (kv_lora_rank=512, qk_rope_head_dim=64) in the cache, not standard key/value tensors, and TurboQuant's PolarQuant rotation likely can't handle that compressed representation correctly.

I'll try testing it with Qwen3.5 35B MoE.

@arozanov (Author)

> With the params you suggested, same garbage issue […] the TurboQuant cache itself is incompatible with the glm4_moe_lite MLA architecture. […]

Yeah, MLA is a different beast; it makes sense that it breaks. Good catch. Qwen3.5 should be fine since it uses standard attention. Let me know how it goes.

@kipanshi

> Qwen3.5 should be fine since it uses standard attention. Let me know how it goes.

Did more testing:

  • GLM-4.7-Flash (MLA): loads but produces garbage (MLA latent cache incompatible)
  • Qwen3.5-35B-A3B (hybrid SSM/attention): crashes (SSM cache not supported)
  • Standard attention models (Llama, Mistral): works correctly

@arozanov (Author)

> Did more testing: […]

Thanks for testing across architectures. The MLA and SSM failures are expected; TurboQuant only works with a standard multi-head attention KV cache. I should add a check that raises a clear error instead of silently producing garbage. Will fix.
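
For concreteness, a hypothetical sketch of such a guard; the detection heuristics and names are illustrative assumptions, not this PR's final code (MambaCache is mlx-lm's real SSM cache class):

```python
def assert_turbo_quant_supported(model):
    # Illustrative MLA check: MLA configs expose kv_lora_rank, and the cache
    # then stores compressed latents rather than standard K/V tensors.
    if getattr(getattr(model, "args", None), "kv_lora_rank", None):
        raise ValueError(
            "turbo_kv_bits is not supported for MLA models; "
            "their cache holds latents, not key/value tensors."
        )
    # Illustrative SSM/hybrid check via the model's own cache constructor.
    caches = model.make_cache() if hasattr(model, "make_cache") else []
    if any(type(c).__name__ == "MambaCache" for c in caches):
        raise ValueError("turbo_kv_bits is not supported for SSM/hybrid models.")
```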

@babhishek21

some thoughts for DX:

  1. Since this is a compression scheme, perhaps it should get the same treatment as KVCache#to_quantized(): the basic unbounded KVCache could have a to_turbo_quantized() (or equivalent) that returns a TurboQuantKVCache.
  2. Possibly a generalized convert_to_turbo_quantized() function (similar to how #1059, "[Experimental] Add TurboQuantKVCache: PolarQuant KV cache compression at 2-4 bits", does it), with supporting cache specializations progressively adopting to_turbo_quantized() where supported.
  3. make_prompt_cache should still be the entry point for cache-related feature enablement, in this case whether to enable TurboQuant compression. Just as max_kv_size switches to a bounded cache, params turboq_kv_bits and turboq_fp16_layers could switch on TurboQuant (a sketch follows this list). This would also help dedupe the logic around if hasattr(model, "make_cache"): return model.make_cache().
  4. Lib users should still be able to pass any custom prompt_cache to generate.
  5. CLI users should be able to pass TurboQuant args the same way they pass --max-kv-size.
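
A minimal sketch of point 3, under stated assumptions: the turbo_* parameters are hypothetical here, TurboQuantKVCache is the class this PR adds (import path unknown), while make_prompt_cache, KVCache, and RotatingKVCache are real mlx-lm names:

```python
from mlx_lm.models.cache import KVCache, RotatingKVCache

def make_prompt_cache(model, max_kv_size=None, turbo_kv_bits=None, turbo_fp16_layers=0):
    # Model-specific caches (sparse, SSM, ...) keep taking precedence.
    if hasattr(model, "make_cache"):
        return model.make_cache()
    n = len(model.layers)
    if turbo_kv_bits is not None:
        # Layer-adaptive: the first/last turbo_fp16_layers stay FP16.
        return [
            KVCache()
            if i < turbo_fp16_layers or i >= n - turbo_fp16_layers
            else TurboQuantKVCache(bits=turbo_kv_bits)  # the PR's new class
            for i in range(n)
        ]
    if max_kv_size is not None:
        return [RotatingKVCache(max_size=max_kv_size) for _ in range(n)]
    return [KVCache() for _ in range(n)]
```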

@arozanov (Author)

> some thoughts for DX: […]

Good points, agree on all of them. Specifically:

  1. to_turbo_quantized() on KVCache - makes sense, will add
  2. Routing through make_prompt_cache instead of a separate function - cleaner, agreed
  3. CLI args --turbo-kv-bits and --turbo-fp16-layers alongside --max-kv-size - will do

I'll rework the PR to follow the existing patterns. Thanks for the detailed review.

@babhishek21

@arozanov I think you'll need to add tests.
@awni @andresy with that, I think this PR will probably supersede #1059

arozanov pushed a commit to arozanov/vllm-mlx that referenced this pull request Mar 29, 2026
Adds --turbo-kv-bits flag (1-4) to compress stored prefix cache entries
using TurboQuant (arXiv 2504.19874). 3-bit gives 4.6x compression vs FP16,
compared to ~2x from the existing 8-bit quantization.

Integration points:
- memory_cache.py: _turbo_quantize_cache/_dequantize_cache, memory estimation,
  trim support, needs_dequantize property, config validation
- scheduler.py: turbo_kv_bits in SchedulerConfig, propagation to MemoryCacheConfig
- cli.py: --turbo-kv-bits for serve and bench commands

Requires mlx-lm with TurboQuant support (ml-explore/mlx-lm#1067).
arozanov added a commit to arozanov/vllm-mlx that referenced this pull request Mar 29, 2026
@arozanov force-pushed the feature/turboquant-kv-cache branch from a087778 to 9315fbc on March 29, 2026
@QROST commented Apr 1, 2026

#1064 #1063 #1059

@deceptech-packet-ninja

Findings from independent testing + potential contributions

I've been working on TurboQuant for MLX-LM independently and have some findings that might be useful for this PR.

1. Bug in MLX core PR #3328 (turboquant_sdpa kernel)

The kernel dispatches N = k_norms.shape(1) at line 450 of scaled_dot_product_attention.cpp, but that reads the head dimension instead of the sequence length; it should be k_norms.shape(2). After the fix, the kernel's output matches the reference exactly (cosine similarity 1.000). I commented on PR #3328 with the fix and benchmarks.

2. Value compression (4-bit alongside 3-bit keys)

The current implementation stores values as FP16. Adding 4-bit affine quantization for values (via standard mx.quantize) doubles the memory savings:

| Context | Keys only compressed | Keys + values compressed |
| --- | --- | --- |
| 50K (Llama-3-8B) | 4.3 GB KV | 2.7 GB KV |
| 200K (Llama-3-8B) | 14.9 GB KV | 10.7 GB KV |

Speed impact is negligible (0.94x vs 0.95x of FP16). Quality is identical; values tolerate 4-bit well.
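
A sketch of that value-compression idea using MLX's standard affine quantization (mx.quantize/mx.dequantize are real APIs; the shapes and group size are illustrative assumptions):

```python
import mlx.core as mx

def compress_values(v: mx.array, bits: int = 4, group_size: int = 64):
    # v: (..., head_dim) FP16 values; quantize along the last axis.
    *lead, d = v.shape
    v_q, scales, biases = mx.quantize(v.reshape(-1, d), group_size=group_size, bits=bits)
    return v_q, scales, biases, lead + [d]

def decompress_values(v_q, scales, biases, shape, bits: int = 4, group_size: int = 64):
    v = mx.dequantize(v_q, scales, biases, group_size=group_size, bits=bits)
    return v.reshape(shape)
```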

3. 200K context proof (32GB Mac)

On a 32GB machine with Llama-3-8B-Instruct-4bit:

  • FP16 at 200K context: the system enters swap death (unusable for over an hour; I had to kill the process)
  • TurboQuant at 200K: completed the prefill and generated tokens successfully; KV cache = 10.7 GB

This is concrete proof that TurboQuant enables context lengths that simply don't fit in FP16 on memory-constrained machines.

4. Quality on real prompts

Tested coding, explanation, and analysis prompts. Outputs are near-identical to standard:

```
Standard:   "quantum computers are special computers that can do some things that regular computers can't.
             Imagine you have a big box of different colored balls..."
TurboQuant: "quantum computers are special computers that can do some really cool things that regular computers can't.
             Imagine you have a big box of different colored balls..."
```

5. Separate infrastructure PRs

We have two related PRs that complement TurboQuant:

Happy to help with testing, benchmarks, or integration work on this PR. Great implementation.

Thump604 pushed a commit to arozanov/vllm-mlx that referenced this pull request Apr 11, 2026
arozanov added a commit to arozanov/vllm-mlx that referenced this pull request Apr 16, 2026
arozanov added a commit to arozanov/vllm-mlx that referenced this pull request Apr 16, 2026
Fail-fast when --turbo-kv-bits is requested without mlx-lm TurboQuant
support: config and CLI now error out with an actionable message
pointing to ml-explore/mlx-lm#1067 instead of silently no-oping.

Reject the --turbo-kv-bits + --kv-cache-quantization combination in
both argparse-fed callers and in MemoryCacheConfig.__post_init__ so
programmatic users get the same guard.

Log a one-time warning when _TurboQuantCacheWrapper.is_trimmable()
degrades to False because the upstream TurboQuantKVCache lacks copy();
prefix-cache trimming (supersequence / LCP reuse) falls back to
full-prefix matching in that case (correct, just less efficient).

Document the dual estimate_kv_cache_memory paths (wrapper vs bare) and
the copy() contract we depend on in _trim_cache_offset.
Saves KV cache at prefill completion and every 32K tokens during
long prefills. Enables prefix matching on follow-up messages in
multi-turn conversations.

Previously cache was only saved after generation with key =
prompt + generated tokens. Next turn re-tokenizes the assistant
response with template wrappers, producing different tokens and
breaking prefix match. Now saves with prompt-only key at prefill
end, so next turn matches the prompt prefix and skips it.

Tested: 46-token prompt cached, follow-up processed only 16 new
tokens instead of full 62.
When all prompt tokens match a cached entry, rest=[] causes
stream_generate to crash with ValueError (empty prompt). Fix:
trim cache by 1 token and re-process the last token.

Tested: exact same prompt sent twice, second request processes
only 1 token with 40 cached. No crash.
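
A hedged sketch of that fix; trim_prompt_cache is real mlx-lm API, while the helper and its inputs are illustrative:

```python
from mlx_lm.models.cache import trim_prompt_cache

def remaining_prompt(tokens, n_matched, cache):
    # n_matched: length of the common prefix between the prompt and the
    # cached entry. If everything matched, stream_generate would receive an
    # empty prompt, so back the cache up by one token and re-process it.
    if n_matched == len(tokens):
        trim_prompt_cache(cache, 1)
        n_matched -= 1
    return tokens[n_matched:]
```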
Disk entries were capped at 2x prompt-cache-size, causing old
caches to be evicted too aggressively. Disk is cheap, RAM is not.

Added --prompt-cache-disk-size (default 100) to control disk
entries independently from --prompt-cache-size (RAM entries).

Tested: 5 RAM entries + 10 disk entries, capped correctly.
@arozanov force-pushed the feature/turboquant-kv-cache branch from 10f2bfc to fda593e on April 30, 2026
JianweiChen2021 pushed a commit to JianweiChen2021/mlx-lm that referenced this pull request May 2, 2026
…latest main)

Resolve mlx_lm/server.py using prior merge artifact 7aeb6df (fda593e + ed1fca4).

Made-with: Cursor
arozanov added 8 commits May 4, 2026 23:10
Full implementation of DeepSeek-V4-Flash architecture:
- Compressed Sparse Attention (CSA ratio=4, HCA ratio=128)
- Lightning Indexer for top-k compressed position selection
- Learned compressor with overlap transform
- Hyper-Connections with Sinkhorn normalization
- Hash routing MoE (256 experts, 6 active)
- Grouped low-rank output projection
- MQA with inverse RoPE

Sparse attention with window + compressed KV:
- Chunked sparse prefill (256 queries/chunk)
- Sparse decode with circular window buffer
- Step-based buffer growth for compressed entries
- RotatingKVCache for pure sliding window layers
- SparseKVCache with full state serialization

Tested on DeepSeek-V4-Flash-8bit (303GB, 6.4 tok/s).
AutoTokenizer.from_pretrained crashes on model types that transformers
doesn't know (e.g. deepseek_v4) due to rope_scaling standardization.
Fall back to PreTrainedTokenizerFast when AutoTokenizer fails.

Also register deepseek_v4 config with transformers AutoConfig.
SparseKVCache.update_and_fetch didn't preserve existing data on
reallocation, causing shape mismatch on continuation prefill chunks.

Also handle continuation prefill correctly: when the server splits
a long prompt into chunks, subsequent chunks now extend the existing
sparse buffers instead of reinitializing them.
Serialization materializes all KV cache arrays, temporarily doubling
memory usage. For large models (400GB+) on machines with limited
headroom this triggers OOM. Check available memory before saving
and skip if less than 8GB free.
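
A sketch of that guard, assuming psutil for the free-RAM check (the commit doesn't say which mechanism is used; save_prompt_cache is real mlx-lm API):

```python
import psutil
from mlx_lm.models.cache import save_prompt_cache

MIN_FREE_BYTES = 8 * 1024**3  # skip the save below 8 GB free, per the commit

def maybe_save_prompt_cache(path, cache):
    # Serialization materializes every KV array, briefly doubling KV memory,
    # so refuse to save when the system is already near its limit.
    if psutil.virtual_memory().available < MIN_FREE_BYTES:
        return False
    save_prompt_cache(path, cache)
    return True
```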
Compressor decode mode expects single-token input (L=1) but
continuation prefill chunks pass L>1. Process chunk tokens
one-by-one through the compressor to maintain correct state.
- Fix Metal resource limit crash: token-by-token compressor loop in
  continuation prefill creates too many buffers. Add periodic mx.eval
  to flush every 32 tokens.
- Fix missing _win_buf on second generate call: use cache.offset==0
  to detect first prefill instead of checking _win_buf existence.
- Add safety init in sparse decode for single-token prompt edge case.
- Cap Sinkhorn iterations at 10 (from 20) for ~12% speed improvement
  with negligible quality difference.
Custom Metal kernels for MoE decode (fused_moe_kernel.py):
- Fused gate+up+SwiGLU: all experts in one dispatch (1.8x MoE speedup)
- Fused down projection: all experts in one dispatch
- Fused grouped output projection: 8 groups in one dispatch
- 4-bit inline dequantization matching MLX qmv pattern

Model optimizations (deepseek_v4.py):
- MoE layer skip on 8/43 layers during decode (quality-validated)
- mx.compile for HC pre/post and MoE modules
- Inverse RoPE simplified to rope(-offset) for decode
- Fused Q+KV projection via weight concatenation
- Step-based buffer growth for compressed KV

Bugfixes from Opus audit:
- Fix inverse RoPE for prefill (L>1) with per-position offsets
- Fix SparseKVCache state serialization alignment
- Fix continuation prefill window buffer (incremental update)
- Fix SparseKVCache.trim to invalidate sparse state
- Fix output dtype mismatch (float32 -> input dtype)
- Add K%512 and N%8 assertions for Metal kernels
- Fix id()-based cache leak in fused_grouped_wo
- Add scale handling in _apply_rope_at_positions
- Remove dead code (fused_qkv_proj)

SwitchGLU decode path (switch_layers.py):
- Sequential per-expert processing with fused Metal kernels
- Automatic fallback for non-4-bit quantization

Result: 6.5 -> 21 tok/s through server (3.2x), 161GB peak memory.
@arozanov changed the title from "Add TurboQuant KV cache compression (3-bit, 4.6x)" to "Add TurboQuant KV cache + DeepSeek V4 support" on May 8, 2026
arozanov added 4 commits May 8, 2026 11:29
- Add test_deepseek_v4.py (28 tests: model creation, prefill/decode,
  continuation, multi-turn, cache serialization, compressor, fused kernels)
- Add --turbo-kv-bits and --turbo-fp16-layers to server CLI
- Add MLA/SSM guard in make_prompt_cache (ValueError instead of garbage)
- Add SparseKVCache to prompt cache save/load allowlist
- Fix stale sparse state across conversations (reset on offset==0)
- Gate MoE layer skip on num_hidden_layers==43
- Assert matching quant params in fused QKV projection
- Accept upstream server.py refactor, re-apply turbo args on top
- Fallback for mx.new_thread_local_stream (older MLX versions)
- FP8 e4m3 block dequant from original HF checkpoints (auto-detect format)
- FP4 packed expert dequant with ue8m0 block scaling
- MTP weight filtering and HF key remapping in sanitize()
- 4-bit affine value compression (--turbo-v-bits) for 2x memory savings
- BatchSparseKVCache with merge/unmerge for server batch mode
- Per-entry sparse decode for compressor state machine compatibility
@arozanov (Author) commented May 8, 2026

Merged upstream, added tests, fixed the reported issues.

Tests: 69 total (test_deepseek_v4.py + test_turboquant.py), all pass.

New stuff:

  • FP8 e4m3 dequant from original HF checkpoints (auto-detects, mlx-community weights still work)
  • --turbo-v-bits for 4-bit value compression
  • BatchSparseKVCache - server works without --no-batch
  • Compat with older MLX (tested on 0.29.3)

Verified 22.3 tok/s on 4-bit V4 Flash, M3 Ultra 512GB.

@babhishek21 - all DX points done (to_turbo_quantized, make_prompt_cache, CLI args in generate + server)

@kipanshi - make_prompt_cache raises ValueError for MLA and SSM models now

@deceptech-packet-ninja - value compression is in as --turbo-v-bits
