
fix: 8-bit dequant for MLX mixed-precision gate quantization #14

Open
userFRM wants to merge 1 commit into danveloper:main from userFRM:fix/8bit-gate-dequant

Conversation


@userFRM userFRM commented Mar 23, 2026

Problem

MLX 4-bit quantized models use 8-bit precision for routing gates, specified per-tensor in config.json:

"quantization": {
    "bits": 4, "group_size": 64,
    "model.layers.0.mlp.gate": {"group_size": 64, "bits": 8},
    "model.layers.0.mlp.shared_expert_gate": {"group_size": 64, "bits": 8}
}

The 4-bit dequant kernel extracts 8 nibbles per uint32, but these tensors pack 4 bytes per uint32 (8-bit). This corrupts the routing gate scores, selecting the wrong experts at every layer and producing nonsensical output.
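A minimal sketch of the packing mismatch, assuming MLX's little-endian packing of quantized values into uint32 words; the function names are illustrative, not the PR's actual kernels:

```c
#include <stdint.h>

/* 4-bit layout: one uint32 carries 8 quantized values (nibbles). */
void unpack_4bit(uint32_t word, uint8_t out[8]) {
    for (int i = 0; i < 8; i++)
        out[i] = (word >> (4 * i)) & 0xF;
}

/* 8-bit layout: the same uint32 carries only 4 values (bytes). */
void unpack_8bit(uint32_t word, uint8_t out[4]) {
    for (int i = 0; i < 4; i++)
        out[i] = (word >> (8 * i)) & 0xFF;
}
```

Running unpack_4bit on a word that actually holds the bytes 0x11, 0x22, 0x33, 0x44 yields the nibbles 1,1,2,2,3,3,4,4 instead of the four intended values, which is exactly how an 8-bit gate tensor gets scrambled by the 4-bit path.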

Verification

Compared gate output against MLX Python reference for mlx-community/Qwen3-Coder-Next-4bit:

  • Without fix: gate scores have wrong magnitudes and signs (RMS 1.2 vs MLX 6.8)
  • With fix (CPU path): gate scores match MLX exactly (same top expert indices, same score range)

Forcing full CPU computation (g_metal = NULL) confirmed coherent output: "2 + 2 = 4", correct code generation, and proper EOS handling.

Changes

shaders.metal: Added dequant_matvec_8bit kernel — same tiled ROWS_PER_TG=8 structure as dequant_matvec_4bit_v3, but extracts 4 bytes per uint32 with & 0xFF instead of 8 nibbles with & 0xF. FMA-optimized with precomputed scale*x and bias*x.

infer.m:

  • Added int bits field to BatchMatvecSpec for per-tensor bit-width dispatch
  • Added matvec_8bit pipeline state to MetalCtx
  • Added cpu_dequant_matvec_8bit CPU fallback
  • Updated gpu_encode_batch_matvec and gpu_batch_matvec to select 8-bit kernel when bits == 8
  • Marked gate_w and seg_w (shared_expert_gate) as bits=8 in all 7 BatchMatvecSpec initialization sites
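The dispatch the bullets describe can be sketched as follows; the struct and function names echo the PR's description but are illustrative stand-ins, not the actual infer.m code:

```c
#include <stdint.h>
#include <stddef.h>

/* Illustrative stand-in for BatchMatvecSpec with the new `bits`
 * field; the real struct carries more members. */
typedef struct {
    const uint32_t *packed;  /* quantized weight words */
    int rows, cols, group_size;
    int bits;                /* 4 by default; 8 for gate_w / seg_w */
} BatchMatvecSpec;

typedef void (*matvec_fn)(const BatchMatvecSpec *, const float *, float *);

static void cpu_dequant_matvec_4bit_stub(const BatchMatvecSpec *s,
                                         const float *x, float *y) {
    (void)s; (void)x; (void)y;  /* 4-bit path elided */
}
static void cpu_dequant_matvec_8bit_stub(const BatchMatvecSpec *s,
                                         const float *x, float *y) {
    (void)s; (void)x; (void)y;  /* 8-bit path elided */
}

/* Selection point mirroring the `bits == 8` check in the batch
 * matvec entry points. */
matvec_fn select_matvec(const BatchMatvecSpec *s) {
    return (s->bits == 8) ? cpu_dequant_matvec_8bit_stub
                          : cpu_dequant_matvec_4bit_stub;
}
```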

Impact

Affects any MLX quantized model with per-tensor bit-width overrides in the quantization config; this is standard for Qwen3-family models.

Fixes #10

Test plan

  • ./infer --prompt "Hello" --tokens 20 --k 4 produces coherent output
  • Gate scores match MLX reference (--timing shows reasonable routing)
  • No regression on 2-bit mode (--2bit)
  • Shader compiles on M1/M2/M3/M4

MLX 4-bit models quantize routing gates (mlp.gate, mlp.shared_expert_gate)
at 8-bit precision, specified per-tensor in config.json. The inference
engine treated all tensors as 4-bit, extracting 8 nibbles per uint32 from
data that actually packs 4 bytes per uint32. This corrupts routing scores,
selecting wrong experts and producing nonsensical output.

Changes:
- Add dequant_matvec_8bit Metal kernel (4 bytes/uint32, FMA-optimized)
- Add cpu_dequant_matvec_8bit CPU fallback
- Add BatchMatvecSpec.bits field for per-tensor bit-width dispatch
- Mark gate and shared_expert_gate as 8-bit in all dispatch sites

Fixes danveloper#10

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Closes: Nonsensical output on Apple M4 Pro (Mac Mini 64GB) — 14.5 tok/s but garbage generation (#10)