
Add Mistral Small 4 (119B MoE) support with absorbed MLA #1075

Closed

graelo wants to merge 5 commits into ml-explore:main from graelo:mistral4-absorbed-mla

Conversation

graelo commented Mar 30, 2026

Summary

This adds support for Mistral Small 4 (119B MoE) using absorbed Multi-head Latent Attention, following the same pattern already established in deepseek_v3.py.

Mistral Small 4 ships its KV decompression as a single kv_b_proj linear layer. During sanitize, we decompose it into embed_q (W_UK) and unembed_out (W_UV) MultiLinear weights, which lets the KV cache store only the compressed latent and RoPE component — 320 floats/token/layer instead of 8,192 with full decompression (~25× smaller).
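
For illustration, the decomposition during sanitize might look roughly like the following sketch; the dimension names follow the deepseek_v3.py convention and are assumptions, not the exact PR code:

import mlx.core as mx

def split_kv_b_proj(kv_b_weight, num_heads, qk_nope_head_dim, v_head_dim):
    # Sketch only. kv_b_weight has shape
    # (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank).
    kv_lora_rank = kv_b_weight.shape[-1]
    w = kv_b_weight.reshape(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
    w_uk = w[:, :qk_nope_head_dim, :]  # absorbed into the query path (embed_q / W_UK)
    w_uv = w[:, qk_nope_head_dim:, :]  # applied to the attention output (unembed_out / W_UV)
    return w_uk, w_uv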

The rest of the model (MoE gating, shared experts, SwitchGLU routing) is straightforward and closely follows the existing codebase conventions.

Changes

  • mlx_lm/models/mistral4.py — new model file with absorbed MLA attention, MoE, and FP8 dequantization support
  • mlx_lm/models/mistral3.py — route mistral4 text config (or configs with n_routed_experts) to the new module
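
As a rough illustration of the routing criterion in the second bullet (the exact dispatch code in mistral3.py may differ):

def _uses_mistral4(text_config: dict) -> bool:
    # Illustrative check only: Mistral Small 4 text configs carry MoE fields
    # such as n_routed_experts, which Mistral Small 3 configs do not.
    return text_config.get("model_type") == "mistral4" or "n_routed_experts" in text_config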

Tested with

sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit on M5 Max (128 GB):

| Generation length | tok/s | Peak memory |
|-------------------|-------|-------------|
| 8 tokens          | 107   | 66.97 GB    |
| 512 tokens        | 93    | 67.04 GB    |
| 2048 tokens       | 90    | 67.13 GB    |

Outputs are coherent and factually reasonable. Memory growth stays minimal as the cache fills, which is consistent with the compressed KV cache working as expected.

Also tested successfully with inferencerlabs/Mistral-Small-4-119B-2603-MLX-4.5bit.

Context

Related to #1037, which also adds Mistral Small 4 support but uses non-absorbed MLA (caches full decompressed K/V). See the discussion there for more background on the trade-off.

graelo (author) commented Mar 31, 2026

Follow-up: eliminate O(L²) memory spike during prefill

After opening this PR, I copied the files into oMLX and used opencode. Testing with longer prompts revealed periodic RAM spikes of roughly 40 GB when reaching 25k tokens, which could crash the system on very long contexts (it did on my 128 GB Mac). My friend Claude and I investigated, and here's the improvement.

Root cause: the RoPE score bias was pre-computed as an explicit (B, H, L, S) float tensor (pe_scores) and passed as an additive mask to mx.fast.scaled_dot_product_attention. That tensor has to live in RAM before the Metal kernel runs, scaling quadratically with sequence length (~40 GB at 25K tokens).

Fix (commit 1fcaacc): replace pe_scores with a unified Q/K concatenation. Since

dot(q_nope, k_nope) + dot(q_pe, k_pe) = dot(concat(q_nope, q_pe), concat(k_nope, k_pe))

we can pass a plain causal mask to SDPA and let its tiled flash-attention kernel handle the score computation without materialising the (B, H, L, S) matrix in RAM at all.
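
A minimal sketch of the attention call after this change (variable names are illustrative, not the exact PR code):

import mlx.core as mx

def absorbed_attention(q_nope, q_pe, k_nope, k_pe, values, scale):
    # Fold the RoPE sub-space into a single Q/K pair so SDPA computes
    # dot(q_nope, k_nope) + dot(q_pe, k_pe) implicitly, tile by tile.
    q = mx.concatenate([q_nope, q_pe], axis=-1)  # (B, H, L, d_nope + d_rope)
    k = mx.concatenate([k_nope, k_pe], axis=-1)  # (B, H, S, d_nope + d_rope)
    # A plain causal mask keeps everything inside the fused Metal kernel;
    # the (B, H, L, S) score matrix is never materialised in RAM.
    return mx.fast.scaled_dot_product_attention(q, k, values, scale=scale, mask="causal")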

Updated numbers (M5 Max, sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit):

| Prompt tokens | Generation tok/s | Peak memory |
|---------------|------------------|-------------|
| 22            | 103.6            | 66.979 GB   |
| 103           | 91.3             | 67.059 GB   |
| 205           | 86.1             | 67.161 GB   |

Memory growth across all runs is about 180 MB, which is (ideally) only the compressed KV cache filling. No more spikes 😅. Small throughput cost (~3–4 tok/s vs previous) from the extra concatenate ops.

test command:

python -m mlx_lm.generate \
    --model ~/.lmstudio/models/sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit \
    --prompt "Write ..."

See details below.

Now it works much more consistently. That said, for agentic coding at least, I'm not very impressed by this 4-bit quant of Mistral, but that's unrelated.

Details
python -m mlx_lm.generate --model ~/.lmstudio/models/sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit --prompt "Write a detailed technical explanation of how Multi-head Latent Attention works in modern large language models. Cover the compression of key-value pairs into a low-rank latent space, the role of rotary position embeddings on a separate subspace, and the weight absorption trick that allows inference systems to cache only the compressed representation. Compare this approach to standard multi-head attention and grouped-query attention in terms of memory efficiency and computational trade-offs. Be thorough and precise." --max-tokens 512
<frozen runpy>:128: RuntimeWarning: 'mlx_lm.generate' found in sys.modules after import of package 'mlx_lm', but prior to execution of 'mlx_lm.generate'; this may result in unpredictable behaviour
Calling `python -m mlx_lm.generate...` directly is deprecated. Use `mlx_lm.generate...` or `python -m mlx_lm generate ...` instead.
==========
### **Multi-Head Latent Attention (MLA) in Modern Large Language Models: A Detailed Technical Explanation**

Multi-Head Latent Attention (MLA) is a recent innovation in attention mechanisms designed to improve memory efficiency and computational scalability in large language models (LLMs), particularly for long-context scenarios. It builds upon **Multi-Head Attention (MHA)** and **Grouped-Query Attention (GQA)** by introducing a **low-rank latent space** for key-value (KV) projection, enabling more efficient inference through **weight absorption** and **separate rotary position embeddings (RoPE)**. Below, we dissect its mechanics, compare it to standard MHA and GQA, and analyze its trade-offs.

---

## **1. Background: Standard Multi-Head Attention (MHA)**
Before explaining MLA, let’s recap **Multi-Head Attention (MHA)**, introduced in the original **Transformer** architecture (Vaswani et al., 2017). MHA decomposes attention into multiple parallel heads, each with its own set of query (Q), key (K), and value (V) projections, allowing the model to attend to different parts of the input simultaneously.

### **Key Steps in MHA:**
1. **Projection into Heads:**
   - For an input tensor \( X \in \mathbb{R}^{n \times d} \) (where \( n \) = sequence length, \( d \) = hidden dimension), MHA computes:
     \[
     Q = X W_Q, \quad K = X W_K, \quad V = X W_V
     \]
     where \( W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k} \) are learnable weight matrices, and \( d_k \) is the head dimension (typically \( d_k = d / h \), where \( h \) is the number of heads).
   - Each head independently computes attention scores:
     \[
     A_i = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i
     \]
     where \( Q_i, K_i, V_i \) are the \( i \)-th head’s projections.

2. **Concatenation & Final Projection:**
   - The outputs of all heads are concatenated:
     \[
     O = \text{Concat}(A_1, A_2
==========
Prompt: 103 tokens, 90.167 tokens-per-sec
Generation: 512 tokens, 91.355 tokens-per-sec
Peak memory: 67.059 GB

ProducerGuy commented

@graelo thanks for the improvement. It definitely slipped my mind! I found a bug when running with the 4-bit MLX checkpoint (mlx-community/Mistral-Small-4-119B-2603-4bit):

Error:

ValueError: Unable to quantize model of type <class 'mlx_lm.models.mistral4.MoEGate'>

Root cause: The checkpoint's quantization config has per-layer entries like "mlp.gate": {"group_size": 64, "bits": 8}. MLX's quantizer matches this path and tries to call MoEGate.to_quantized(), which doesn't exist on the raw mx.zeros weight.

Fix (tested and working):

  1. MoEGate.__init__: replace self.weight = mx.zeros(...) with self.linear = nn.Linear(..., bias=False)
  2. MoEGate.__call__: replace x @ self.weight.T with self.linear(x)
  3. In sanitize(): remap checkpoint keys from mlp.gate.weight to mlp.gate.linear.weight and dequantize from INT8 to float16 (gate is ~1MB/layer, negligible)

Verified on M5 Max 128GB. Model loads and generates correctly with this fix applied.
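
A minimal sketch of the change described above (dimensions and the scoring/top-k logic are omitted; this is not the exact patch):

import mlx.core as mx
import mlx.nn as nn

class MoEGate(nn.Module):
    def __init__(self, dim: int, n_routed_experts: int):
        super().__init__()
        # Before: self.weight = mx.zeros((n_routed_experts, dim))
        # nn.Linear gives nn.quantize a module it knows how to handle (or skip).
        self.linear = nn.Linear(dim, n_routed_experts, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        # Before: x @ self.weight.T
        return self.linear(x)

# In sanitize(), the checkpoint key would be remapped along these lines:
#   weights[f"{p}.mlp.gate.linear.weight"] = weights.pop(f"{p}.mlp.gate.weight")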

graelo (author) commented Apr 1, 2026

Hi ❤️. I've also noticed a couple of issues when trying to load the 16-bit model for quantization, and fixed them in the last couple of hours (we're in sync), but I haven't pushed the commit yet. My change seems much larger than yours, so I'll stash it and try to use yours for quantization and inference. I'll keep you updated. Thanks!

graelo (author) commented Apr 1, 2026

Three more commits since the last update.

FP8 checkpoint support (f2c448e): converting from the original mistralai/Mistral-Small-4-119B-2603 crashed for two reasons:

  • Some FP8 scale_inv tensors are scalar (per-tensor scaling), but _dequant_fp8 assumed 2D block-scaled arrays.
  • The expert weights use a pre-stacked fused format (experts.gate_up_proj / experts.down_proj) with _scale_inv keys that don't contain weight_scale_inv, so the FP8 dequant loop just skipped them.

Conversion now works, and so does inference using my initial testing model sachin-sith/.. and my quants: 4-bit and 5-bit affine (group_size 32 and 64) and mxfp4.
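
A rough sketch of the two scaling layouts handled in _dequant_fp8 (this is not the exact PR code, and it assumes the FP8 values have already been loaded as a float array):

import mlx.core as mx

def dequant_fp8(weight, scale_inv, block_size=128):
    # Rough sketch; `weight` is assumed to already hold the raw FP8 values as floats.
    w = weight.astype(mx.float32)
    if scale_inv.ndim == 0:
        # Per-tensor (scalar) scaling, as in the original checkpoint.
        return w * scale_inv
    # Block-scaled layout: one scale per (block_size x block_size) tile.
    s = mx.repeat(mx.repeat(scale_inv, block_size, axis=0), block_size, axis=1)
    return w * s[: w.shape[0], : w.shape[1]]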

MoEGate fix (0386313): loading mlx-community/Mistral-Small-4-119B-2603-4bit was crashing with ValueError: Unable to quantize model of type MoEGate. That checkpoint quantizes the routing gate at 8-bit, but MoEGate used a raw mx.zeros weight that nn.quantize can't handle.

Fix: swap the raw weight for nn.Linear, remap the keys in sanitize, and dequantize the gate back to full precision (~1 MB/layer, negligible) so nn.quantize does not touch the gate's weights. Routing stays at float16 accuracy.

Credit to @ProducerGuy for finding the testing model, making the diagnosis and offering the fix: I applied what he suggested in his comment above.

Sanitize cleanup (02d9cd7): the method had grown to 100+ lines, with FP8 dequant logic split across two places. Extracted into four helpers:

  • _dequant_fp8_weights: single pass matching all _scale_inv keys
  • _sanitize_experts: fused and per-expert formats → switch_mlp
  • _sanitize_gate: key remap + dequant for the nn.Linear gate
  • _sanitize_kv_b_proj: absorbed MLA decomposition
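
For illustration, the trimmed method might read roughly like this (helper signatures are assumptions, not the exact code):

def sanitize(self, weights):
    # Rough sketch of the refactored flow; exact signatures are assumptions.
    weights = self._dequant_fp8_weights(weights)  # FP8 -> float, all _scale_inv keys
    weights = self._sanitize_experts(weights)     # fused / per-expert -> switch_mlp
    weights = self._sanitize_gate(weights)        # gate.weight -> gate.linear.weight
    weights = self._sanitize_kv_b_proj(weights)   # absorbed MLA decomposition
    return weights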

sanitize itself is now ~13 lines because it merely calls these helpers. No functional change: all these checkpoint formats now work:

  • original mistralai/Mistral-Small-4-119B-2603 (for quantization, I "only" have 128 GB of RAM 😅 )
  • sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit
  • the mlx-community/Mistral-Small-4-119B-2603-4bit model mentioned above
  • my quants listed above

PS: here are some truncated runs of python -m mlx_lm generate ..., though I imagine testing it from the branch is more convincing!

Details
time python -m mlx_lm generate --model ~/.lmstudio/models/graelo/Mistral-Small-4-119B-2603-MLX-mxfp4 --prompt "Write a detailed technical explanation of how Multi-head Latent Attention works in modern large language models. Cover the compression of key-value pairs into a low-rank latent space, the role of rotary position embeddings on a separate subspace, and the weight absorption trick that allows inference systems to cache only the compressed representation. Compare this approach to standard multi-head attention and grouped-query attention in terms of memory efficiency and computational trade-offs. Be thorough and precise." --max-tokens 512
==========
<model beautiful output>
==========
Prompt: 103 tokens, 92.624 tokens-per-sec
Generation: 512 tokens, 92.757 tokens-per-sec
Peak memory: 63.341 GB
python -m mlx_lm generate --model  --prompt  --max-tokens 512  3.96s user 9.66s system 104% cpu 13.063 total
python -m mlx_lm generate --model ~/.lmstudio/models/sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit --prompt "Write a detailed technical explanation of how Multi-head Latent Attention works in modern large language models. Cover the compression of key-value pairs into a low-rank latent space, the role of rotary position embeddings on a separate subspace, and the weight absorption trick that allows inference systems to cache only the compressed representation. Compare this approach to standard multi-head attention and grouped-query attention in terms of memory efficiency and computational trade-offs. Be thorough and precise." --max-tokens 512
==========
<amazing response>
==========
Prompt: 103 tokens, 74.720 tokens-per-sec
Generation: 512 tokens, 91.488 tokens-per-sec
Peak memory: 67.059 GB
python -m mlx_lm generate --model ~/.lmstudio/models/mlx-community/Mistral-Small-4-119B-2603-4bit --prompt "Write a detailed technical explanation of how Multi-head Latent Attention works in modern large language models. Cover the compression of key-value pairs into a low-rank latent space, the role of rotary position embeddings on a separate subspace, and the weight absorption trick that allows inference systems to cache only the compressed representation. Compare this approach to standard multi-head attention and grouped-query attention in terms of memory efficiency and computational trade-offs. Be thorough and precise." --max-tokens 512
==========
<fantastic generation>
==========
Prompt: 103 tokens, 394.978 tokens-per-sec
Generation: 512 tokens, 90.854 tokens-per-sec
Peak memory: 67.059 GB

Don't believe the last prefill token rate (speed of light); it's simply because the model's loading time is shorter and it bypasses part of the sanitization process. The wall time is the same as for the other models, about 13 seconds (load, prefill, generation).

ProducerGuy pushed a commit to ProducerGuy/mlx-lm that referenced this pull request on Apr 2, 2026
Applied graelo's absorbed MLA from PR ml-explore#1075 (3 files).
Fixed MoEGate quantization crash: nn.Linear + dequant in sanitize.
Benchmark: 101.9 tok/s generation, 181 tok/s prompt, 66.98 GB peak.

ProducerGuy pushed a commit to ProducerGuy/mlx-lm that referenced this pull request on Apr 2, 2026
Moved query scaling from Python (separate dispatch) into Metal kernel
(applied at query load time, sdpa_vector.h pattern). Eliminated 0.181ms
of dispatch overhead per step per layer.

107.3 tok/s — exceeds Phase 1 absorbed (101.9) and nearly matches
original Phase 0 (108.3), while keeping INT4 cache (57x compression).

Beats graelo's PR ml-explore#1075 (90-92 tok/s) by 17%.

Two-line kernel change + call site update. Zero risk.

graelo added 5 commits on April 4, 2026 at 20:43

Mistral Small 4 uses Multi-head Latent Attention (MLA) like DeepSeek V2/V3.
This implementation uses weight absorption: kv_b_proj is decomposed into
embed_q (W_UK) and unembed_out (W_UV) MultiLinear weights at load time,
so the KV cache stores only the compressed latent (320 floats/token/layer)
instead of the full decompressed K/V (8192 floats/token/layer) — a ~25x
reduction.

Tested with sachin-sith/Mistral-Small-4-119B-2603-MLX-4bit:
- 8 tokens:    107 tok/s, 66.97 GB peak
- 512 tokens:   93 tok/s, 67.04 GB peak
- 2048 tokens:  90 tok/s, 67.13 GB peak

Replace the pre-computed RoPE score bias (a full (B,H,L,S) float tensor)
with a unified Q/K concatenation. dot(q_nope,k_nope)+dot(q_pe,k_pe) equals
dot(concat(q_nope,q_pe), concat(k_nope,k_pe)), so passing a plain causal
mask to mx.fast.scaled_dot_product_attention lets it use its tiled
flash-attention kernel and never materialise the score matrix in RAM.

For a 25K-token prefill on M5 Max this eliminates ~40 GB of transient
allocation, removing the memory spikes observed during inference.
Small throughput cost (~3-4 tok/s) vs previous approach.

The original Mistral Small 4 checkpoint uses per-tensor FP8 scales
(scalar scale_inv) and pre-stacked fused expert weights
(experts.gate_up_proj / experts.down_proj) with their own _scale_inv
keys that the existing weight_scale_inv loop does not match.

Add a scalar fast path in _dequant_fp8 and a new branch in sanitize
to dequant, split gate_up_proj into gate/up, and map to switch_mlp
keys. The existing per-expert path (pre-quantized MLX checkpoints)
is preserved in the else branch.

MoEGate used a raw mx.zeros weight that MLX's nn.quantize cannot
handle, causing a ValueError when loading checkpoints like
mlx-community/Mistral-Small-4-119B-2603-4bit which quantize the
gate at 8-bit.

Replace the raw weight with nn.Linear, remap gate.weight →
gate.linear.weight in sanitize, and dequantize the gate back to
full precision (~1 MB/layer) so nn.quantize skips it entirely.
Keeps routing accuracy at negligible memory cost.

Credit to ProducerGuy for the diagnosis and fix approach.

sanitize grew to 140 lines handling three checkpoint formats
with FP8 dequant split across two places. Extract into:

- _dequant_fp8_weights: single pass matching all _scale_inv keys
- _sanitize_experts: fused and per-expert formats → switch_mlp
- _sanitize_gate: remap + dequant for nn.Linear gate
- _sanitize_kv_b_proj: absorbed MLA decomposition

Also fix stale docstring on Mistral4Attention and drop unused
List import.

graelo force-pushed the mistral4-absorbed-mla branch from 02d9cd7 to 7e1e1c0 on April 4, 2026 at 18:50
graelo (author) commented Apr 4, 2026

Simple rebase on main, after checking everything works the same.

graelo (author) commented Apr 4, 2026

I can now close this PR as @ProducerGuy built upon it and pushed it further 🚀. All good!

graelo closed this on Apr 4, 2026