Add Mistral Small 4 (119B MoE) support with absorbed MLA #1075
Conversation
Follow-up: eliminate O(L²) memory spike during prefill

After opening this PR, I copied the files into MLX and kept using them with opencode. Testing with longer prompts revealed periodic RAM spikes: roughly 40 GB when reaching 25k tokens. It could cause system crashes on very long contexts: it did for my 128 GB Mac. My friend Claude and I investigated this, and here's the improvement.

Root cause: the RoPE score bias was pre-computed as an explicit (B, H, L, S) float tensor.

Fix (commit 1fcaacc): replace the bias with a unified Q/K concatenation, so we can pass a plain causal mask to SDPA and let its tiled flash-attention kernel handle the score computation without materialising the score matrix.

Updated numbers (M5 Max, 128 GB): memory growth across all runs is about 180 MB, which is (ideally) only the compressed KV cache filling. No more spikes 😅. Small throughput cost (~3–4 tok/s vs previous) from the extra concatenation.

Test command:

python -m mlx_lm.generate \
  --model ~/.lmstudio/models/sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit \
  --prompt "Write ..."

Now it works much more consistently, though for agentic coding at least, I'm not very impressed by this 4-bit quant of Mistral, but that's unrelated. See details below.

Details

python -m mlx_lm.generate --model ~/.lmstudio/models/sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit --prompt "Write a detailed technical explanation of how Multi-head Latent Attention works in modern large language models. Cover the compression of key-value pairs into a low-rank latent space, the role of rotary position embeddings on a separate subspace, and the weight absorption trick that allows inference systems to cache only the compressed representation. Compare this approach to standard multi-head attention and grouped-query attention in terms of memory efficiency and computational trade-offs. Be thorough and precise." --max-tokens 512
<frozen runpy>:128: RuntimeWarning: 'mlx_lm.generate' found in sys.modules after import of package 'mlx_lm', but prior to execution of 'mlx_lm.generate'; this may result in unpredictable behaviour
Calling `python -m mlx_lm.generate...` directly is deprecated. Use `mlx_lm.generate...` or `python -m mlx_lm generate ...` instead.
==========
### **Multi-Head Latent Attention (MLA) in Modern Large Language Models: A Detailed Technical Explanation**
Multi-Head Latent Attention (MLA) is a recent innovation in attention mechanisms designed to improve memory efficiency and computational scalability in large language models (LLMs), particularly for long-context scenarios. It builds upon **Multi-Head Attention (MHA)** and **Grouped-Query Attention (GQA)** by introducing a **low-rank latent space** for key-value (KV) projection, enabling more efficient inference through **weight absorption** and **separate rotary position embeddings (RoPE)**. Below, we dissect its mechanics, compare it to standard MHA and GQA, and analyze its trade-offs.
---
## **1. Background: Standard Multi-Head Attention (MHA)**
Before explaining MLA, let’s recap **Multi-Head Attention (MHA)**, introduced in the original **Transformer** architecture (Vaswani et al., 2017). MHA decomposes attention into multiple parallel heads, each with its own set of query (Q), key (K), and value (V) projections, allowing the model to attend to different parts of the input simultaneously.
### **Key Steps in MHA:**
1. **Projection into Heads:**
- For an input tensor \( X \in \mathbb{R}^{n \times d} \) (where \( n \) = sequence length, \( d \) = hidden dimension), MHA computes:
\[
Q = X W_Q, \quad K = X W_K, \quad V = X W_V
\]
where \( W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k} \) are learnable weight matrices, and \( d_k \) is the head dimension (typically \( d_k = d / h \), where \( h \) is the number of heads).
- Each head independently computes attention scores:
\[
A_i = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i
\]
where \( Q_i, K_i, V_i \) are the \( i \)-th head’s projections.
2. **Concatenation & Final Projection:**
- The outputs of all heads are concatenated:
\[
O = \text{Concat}(A_1, A_2
==========
Prompt: 103 tokens, 90.167 tokens-per-sec
Generation: 512 tokens, 91.355 tokens-per-sec
Peak memory: 67.059 GB
@graelo thanks for the improvement. It definitely slipped my mind! I found a bug when running with the 4-bit MLX checkpoint (mlx-community/Mistral-Small-4-119B-2603-4bit).

Error: a ValueError from nn.quantize when it hits the MoE gate's raw weight.

Root cause: the checkpoint's quantization config has per-layer entries that quantize the gate at 8-bit, and MoEGate's raw mx.zeros weight is not something nn.quantize can handle.

Fix (tested and working): replace the raw weight with an nn.Linear, remap gate.weight → gate.linear.weight in sanitize, and dequantize the gate back to full precision so nn.quantize skips it.
Verified on M5 Max 128 GB. Model loads and generates correctly with this fix applied.
Hi ❤️. I've also noticed a couple of issues when trying to load the 16-bit model for quantization, and fixed them in the last couple of hours (we're in sync), but I haven't pushed the commit yet. My change seems much larger than yours, so I'll stash it and try using yours for quant and inference. I'll keep you updated. Thanks!
Three more commits since the last update.

FP8 checkpoint support (f2c448e): trying to convert from the original FP8 checkpoint failed, because it uses per-tensor scalar scale_inv values and pre-stacked fused expert weights that the existing weight_scale_inv loop doesn't match. Conversion now works, and so does inference using my initial testing model sachin-sith/.. and my quants: 4-bit and 5-bit affine (group_size 32 and 64) and mxfp4.

MoEGate fix (0386313): loading mlx-community's 4-bit checkpoint failed because the gate is quantized at 8-bit and MoEGate's raw weight can't be handled by nn.quantize. Fix: swap the raw weight for an nn.Linear and dequantize the gate in sanitize. Credit to @ProducerGuy for finding the testing model, making the diagnosis, and offering the fix: I applied what he suggested in his comment above.

Sanitize cleanup (02d9cd7): the method had grown to a large size (100+ lines) with FP8 dequant logic split across two places. Extracted into 4 helpers: _dequant_fp8_weights, _sanitize_experts, _sanitize_gate, and _sanitize_kv_b_proj.
PS: here are some truncated runs of the different checkpoints.

Details

time python -m mlx_lm generate --model ~/.lmstudio/models/graelo/Mistral-Small-4-119B-2603-MLX-mxfp4 --prompt "Write a detailed technical explanation of how Multi-head Latent Attention works in modern large language models. Cover the compression of key-value pairs into a low-rank latent space, the role of rotary position embeddings on a separate subspace, and the weight absorption trick that allows inference systems to cache only the compressed representation. Compare this approach to standard multi-head attention and grouped-query attention in terms of memory efficiency and computational trade-offs. Be thorough and precise." --max-tokens 512
==========
<model beautiful output>
==========
Prompt: 103 tokens, 92.624 tokens-per-sec
Generation: 512 tokens, 92.757 tokens-per-sec
Peak memory: 63.341 GB
python -m mlx_lm generate --model --prompt --max-tokens 512  3.96s user 9.66s system 104% cpu 13.063 total

python -m mlx_lm generate --model ~/.lmstudio/models/sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit --prompt "Write a detailed technical explanation of how Multi-head Latent Attention works in modern large language models. Cover the compression of key-value pairs into a low-rank latent space, the role of rotary position embeddings on a separate subspace, and the weight absorption trick that allows inference systems to cache only the compressed representation. Compare this approach to standard multi-head attention and grouped-query attention in terms of memory efficiency and computational trade-offs. Be thorough and precise." --max-tokens 512
==========
<amazing response>
==========
Prompt: 103 tokens, 74.720 tokens-per-sec
Generation: 512 tokens, 91.488 tokens-per-sec
Peak memory: 67.059 GB

python -m mlx_lm generate --model ~/.lmstudio/models/mlx-community/Mistral-Small-4-119B-2603-4bit --prompt "Write a detailed technical explanation of how Multi-head Latent Attention works in modern large language models. Cover the compression of key-value pairs into a low-rank latent space, the role of rotary position embeddings on a separate subspace, and the weight absorption trick that allows inference systems to cache only the compressed representation. Compare this approach to standard multi-head attention and grouped-query attention in terms of memory efficiency and computational trade-offs. Be thorough and precise." --max-tokens 512
==========
<fantastic generation>
==========
Prompt: 103 tokens, 394.978 tokens-per-sec
Generation: 512 tokens, 90.854 tokens-per-sec
Peak memory: 67.059 GB

Don't believe the last prefill token rate (speed of light); it's simply because the model's loading time is shorter and it bypasses part of the sanitization process. The wall time is the same as for the other models, about 13 s (load, prefill, generation).
Applied graelo's absorbed MLA from PR ml-explore#1075 (3 files). Fixed MoEGate quantization crash: nn.Linear + dequant in sanitize. Benchmark: 101.9 tok/s generation, 181 tok/s prompt, 66.98 GB peak.
Moved query scaling from Python (separate dispatch) into Metal kernel (applied at query load time, sdpa_vector.h pattern). Eliminated 0.181ms of dispatch overhead per step per layer. 107.3 tok/s — exceeds Phase 1 absorbed (101.9) and nearly matches original Phase 0 (108.3), while keeping INT4 cache (57x compression). Beats graelo's PR ml-explore#1075 (90-92 tok/s) by 17%. Two-line kernel change + call site update. Zero risk.
Mistral Small 4 uses Multi-head Latent Attention (MLA) like DeepSeek V2/V3. This implementation uses weight absorption: kv_b_proj is decomposed into embed_q (W_UK) and unembed_out (W_UV) MultiLinear weights at load time (sketched below), so the KV cache stores only the compressed latent (320 floats/token/layer) instead of the full decompressed K/V (8192 floats/token/layer) — a ~25x reduction.

Tested with sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit:

- 8 tokens: 107 tok/s, 66.97 GB peak
- 512 tokens: 93 tok/s, 67.04 GB peak
- 2048 tokens: 90 tok/s, 67.13 GB peak
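For illustration, here is a minimal sketch of what the load-time decomposition looks like; the dimensions, reshape layout, and variable names are assumptions for the example, not the PR's literal code:

```python
import mlx.core as mx

# Toy dimensions (not the real Mistral Small 4 config).
n_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 16, 16, 8

# kv_b_proj maps the compressed latent (kv_lora_rank) to per-head [k_nope; v].
kv_b = mx.random.normal((n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank))

# Split it per head into W_UK (key up-projection) and W_UV (value up-projection).
kv_b = kv_b.reshape(n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
w_uk = kv_b[:, :qk_nope_head_dim, :]  # absorbed into the query path ("embed_q")
w_uv = kv_b[:, qk_nope_head_dim:, :]  # absorbed into the output path ("unembed_out")

# With absorption, queries are mapped into the latent space once per step, so
# attention can run directly against the cached compressed latent instead of
# decompressed K/V.
q_nope = mx.random.normal((n_heads, 1, qk_nope_head_dim))
q_latent = q_nope @ w_uk  # (n_heads, 1, kv_lora_rank)
```

Only the small latent (plus the RoPE component) ever needs to be cached, which is where the ~25x figure above comes from.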
Replace the pre-computed RoPE score bias (a full (B,H,L,S) float tensor) with a unified Q/K concatenation. dot(q_nope,k_nope)+dot(q_pe,k_pe) equals dot(concat(q_nope,q_pe), concat(k_nope,k_pe)), so passing a plain causal mask to mx.fast.scaled_dot_product_attention lets it use its tiled flash-attention kernel and never materialise the score matrix in RAM. For a 25K-token prefill on M5 Max this eliminates ~40 GB of transient allocation, removing the memory spikes observed during inference. Small throughput cost (~3-4 tok/s) vs previous approach.
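As a quick sanity check of the identity this relies on (toy shapes, nothing model-specific):

```python
import mlx.core as mx

B, H, L, S, d_nope, d_pe = 1, 2, 4, 4, 8, 4

q_nope = mx.random.normal((B, H, L, d_nope))
q_pe = mx.random.normal((B, H, L, d_pe))
k_nope = mx.random.normal((B, H, S, d_nope))
k_pe = mx.random.normal((B, H, S, d_pe))

# Previous approach: the RoPE term computed separately becomes an explicit
# (B, H, L, S) score bias.
scores_split = q_nope @ k_nope.transpose(0, 1, 3, 2) + q_pe @ k_pe.transpose(0, 1, 3, 2)

# Equivalent single dot product over the concatenated subspaces, which is what
# lets the fused SDPA kernel compute scores tile by tile.
q = mx.concatenate([q_nope, q_pe], axis=-1)
k = mx.concatenate([k_nope, k_pe], axis=-1)
scores_joint = q @ k.transpose(0, 1, 3, 2)

print(mx.allclose(scores_split, scores_joint, atol=1e-5))  # True

# In the model, this is what allows calling
# mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")
# on the concatenated tensors instead of passing a precomputed bias.
```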
The original Mistral Small 4 checkpoint uses per-tensor FP8 scales (scalar scale_inv) and pre-stacked fused expert weights (experts.gate_up_proj / experts.down_proj) with their own _scale_inv keys that the existing weight_scale_inv loop does not match. Add a scalar fast path in _dequant_fp8 and a new branch in sanitize to dequant, split gate_up_proj into gate/up, and map to switch_mlp keys. The existing per-expert path (pre-quantized MLX checkpoints) is preserved in the else branch.
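Roughly, the scalar fast path amounts to the following sketch (names, dtype handling, and key layout are illustrative assumptions, not the actual implementation):

```python
import mlx.core as mx

def dequant_fp8_sketch(weight: mx.array, scale_inv: mx.array) -> mx.array:
    # Scalar fast path: the original checkpoint stores one scale per tensor,
    # so dequantization is a single broadcast multiply with no block tiling.
    if scale_inv.size == 1:
        return weight.astype(mx.bfloat16) * scale_inv.astype(mx.bfloat16)
    # Block-wise scales (the existing per-expert path) keep the tiled logic;
    # omitted in this sketch.
    raise NotImplementedError("block-wise scale_inv handled elsewhere")

# The fused expert tensors are then split into gate/up halves and remapped to
# switch_mlp keys, along the lines of (keys and split axis are assumptions):
#   gate_up = dequant_fp8_sketch(weights.pop(f"{p}.experts.gate_up_proj"),
#                                weights.pop(f"{p}.experts.gate_up_proj_scale_inv"))
#   gate, up = mx.split(gate_up, 2, axis=-2)
```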
MoEGate used a raw mx.zeros weight that MLX's nn.quantize cannot handle, causing a ValueError when loading checkpoints like mlx-community/Mistral-Small-4-119B-2603-4bit which quantize the gate at 8-bit. Replace the raw weight with nn.Linear, remap gate.weight → gate.linear.weight in sanitize, and dequantize the gate back to full precision (~1 MB/layer) so nn.quantize skips it entirely. Keeps routing accuracy at negligible memory cost. Credit to ProducerGuy for the diagnosis and fix approach.
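A minimal sketch of the shape of that change (the real gate also does top-k expert selection, which is omitted here):

```python
import mlx.core as mx
import mlx.nn as nn

class MoEGate(nn.Module):
    # Before: a raw mx.zeros((n_routed_experts, dim)) parameter, which
    # nn.quantize cannot wrap. After: a bias-free nn.Linear that the
    # quantization machinery understands and that sanitize keeps in full
    # precision (~1 MB/layer) so routing stays exact.
    def __init__(self, dim: int, n_routed_experts: int):
        super().__init__()
        self.linear = nn.Linear(dim, n_routed_experts, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return mx.softmax(self.linear(x), axis=-1)  # routing scores (simplified)

# In sanitize, the checkpoint key is remapped accordingly:
#   weights["...gate.linear.weight"] = weights.pop("...gate.weight")
```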
sanitize grew to 140 lines handling three checkpoint formats with FP8 dequant split across two places. Extract into:

- _dequant_fp8_weights: single pass matching all _scale_inv keys
- _sanitize_experts: fused and per-expert formats → switch_mlp
- _sanitize_gate: remap + dequant for nn.Linear gate
- _sanitize_kv_b_proj: absorbed MLA decomposition

Also fix stale docstring on Mistral4Attention and drop unused List import.
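The resulting control flow looks roughly like the sketch below (signatures are a guess at the structure, not the literal code):

```python
def sanitize(self, weights):
    # Single pass over every *_scale_inv key: dequantize FP8 tensors
    # (scalar or block-wise scales) back to a float dtype.
    weights = self._dequant_fp8_weights(weights)
    # Map both fused (experts.gate_up_proj / experts.down_proj) and
    # per-expert checkpoint layouts onto switch_mlp keys.
    weights = self._sanitize_experts(weights)
    # gate.weight -> gate.linear.weight, dequantized to full precision.
    weights = self._sanitize_gate(weights)
    # kv_b_proj -> embed_q (W_UK) / unembed_out (W_UV) for absorbed MLA.
    weights = self._sanitize_kv_b_proj(weights)
    return weights
```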
Force-pushed from 02d9cd7 to 7e1e1c0.
Simple rebase on
I can now close this PR as @ProducerGuy built upon it and pushed it further 🚀. All good!
Summary
This adds support for Mistral Small 4 (119B MoE) using absorbed Multi-head Latent Attention, following the same pattern already established in deepseek_v3.py.

Mistral Small 4 ships its KV decompression as a single kv_b_proj linear layer. During sanitize, we decompose it into embed_q (W_UK) and unembed_out (W_UV) MultiLinear weights, which lets the KV cache store only the compressed latent and RoPE component — 320 floats/token/layer instead of 8,192 with full decompression (~25× smaller).

The rest of the model (MoE gating, shared experts, SwitchGLU routing) is straightforward and closely follows the existing codebase conventions.
Changes
- mlx_lm/models/mistral4.py — new model file with absorbed MLA attention, MoE, and FP8 dequantization support
- mlx_lm/models/mistral3.py — route the mistral4 text config (or configs with n_routed_experts) to the new module

Tested with sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit on M5 Max (128 GB). Outputs are coherent and factually reasonable. Memory growth stays minimal as the cache fills, which is consistent with the compressed KV cache working as expected.
Also tested successfully with inferencerlabs/Mistral-Small-4-119B-2603-MLX-4.5bit.

Context
Related to #1037, which also adds Mistral Small 4 support but uses non-absorbed MLA (caches full decompressed K/V). See the discussion there for more background on the trade-off.
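To make the memory side of that trade-off concrete, here is a back-of-the-envelope comparison using the per-layer figures quoted above (the context length and cache dtype are illustrative assumptions):

```python
# Floats cached per token per layer, from the figures above.
compressed = 320   # absorbed MLA: compressed latent + RoPE component
full_kv = 8192     # non-absorbed MLA: decompressed K and V

tokens = 32_000         # hypothetical context length
bytes_per_float = 2     # assuming a 16-bit cache

for name, floats in [("absorbed", compressed), ("non-absorbed", full_kv)]:
    gib = tokens * floats * bytes_per_float / 1024**3
    print(f"{name:>12}: {gib:.3f} GiB per layer at {tokens} tokens")

# The compressed cache is ~25x smaller; the flip side is extra matmuls per
# decode step to map queries into (and outputs back out of) the latent space.
```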