
Native MTP speculative decoding (Qwen3.5/3.6 reference implementation)#990

Open
AirRunner wants to merge 26 commits into ml-explore:main from AirRunner:feat/mtp-native

Conversation

@AirRunner

@AirRunner AirRunner commented Mar 13, 2026

Summary

Qwen3.5 checkpoints ship with a built-in Multi-Token Prediction head (mtp_num_hidden_layers: 1 in config) that predicts token t+2 from the backbone hidden state at t and the embedding of token t+1. This PR adds support for using it as a native speculative decoding mechanism: no separate draft model is needed, and the extra compute is minimal (one extra transformer layer).

Changes

  • mlx_lm/generate.py: MTP generation loop with draft/verify and probabilistic acceptance, --mtp CLI flag
  • mlx_lm/models/cache.py: rollback_state slot for conv/SSM snapshot on draft rejection
  • mlx_lm/sample_utils.py: p_draw parameter added to apply_xtc to share the XTC draw across draft and verify
  • mlx_lm/models/qwen3_5.py: MTP head module; self.norm moved to TextModel to expose pre-norm hidden states for MTP; n_confirmed parameter for SSM rollback; sanitize: the norm +1 shift is now triggered only on raw HF checkpoints (unsanitized conv1d), not on the presence of MTP weights; quant_predicate keeps mtp.fc in full precision
  • mlx_lm/models/qwen3_5_moe.py: MTP checkpoint sanitization for MoE variants + handling both Qwen3.5 and Qwen3.6 (fused gate_up_proj)
  • mlx_lm/server.py: --mtp flag, dynamic MTP/batch switching + fix xtc_special_tokens construction
  • tests/test_mtp.py: 11 unit tests

How it works

Each backbone forward pass returns both logits and pre-norm hidden states. The MTP head fuses pre_fc_norm_hidden(h_t) and pre_fc_norm_embedding(embed(t+1)) via a linear projection, runs one full-attention transformer layer, and produces draft logits through the shared lm_head.
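For readers who prefer code, here is a minimal sketch of that fusion step. Module and argument names (pre_fc_norm_hidden, pre_fc_norm_embedding, fc) follow the description above, but the exact class layout in qwen3_5.py may differ:

```python
import mlx.core as mx
import mlx.nn as nn


class MTPHeadSketch(nn.Module):
    """Illustrative sketch: draft logits for token t+2 from h_t and embed(t+1)."""

    def __init__(self, hidden_size: int, decoder_layer, lm_head):
        super().__init__()
        self.pre_fc_norm_hidden = nn.RMSNorm(hidden_size)
        self.pre_fc_norm_embedding = nn.RMSNorm(hidden_size)
        # Kept unquantized by the PR's quant_predicate.
        self.fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.layer = decoder_layer  # one full-attention transformer layer
        self.lm_head = lm_head      # shared with the backbone

    def __call__(self, hidden, next_embed, cache=None):
        # Fuse the backbone hidden state at t with the embedding of token t+1.
        fused = self.fc(
            mx.concatenate(
                [self.pre_fc_norm_hidden(hidden), self.pre_fc_norm_embedding(next_embed)],
                axis=-1,
            )
        )
        return self.lm_head(self.layer(fused, cache=cache))  # draft logits
```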

The generation loop verifies drafts by feeding [confirmed_tok, draft_tok] to the backbone with n_confirmed=1. This causes GatedDeltaNet to snapshot its conv/SSM state after the confirmed token. On acceptance, both tokens are emitted. On rejection, the SSM state is rolled back to the snapshot and KV caches are trimmed.
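And a condensed sketch of that draft/verify round; every helper here (backbone_step, mtp_draft, accept_fn, ...) is a stand-in for the corresponding piece of mtp_generate_step, not its real signature:

```python
def mtp_round(backbone_step, mtp_draft, sample_fn, accept_fn, residual_sample_fn,
              rollback_fn, trim_kv_fn, confirmed_tok, hidden):
    """One draft/verify round, illustrative only.

    backbone_step(tokens, n_confirmed) -> (per-position logits, per-position hidden)
    Returns (tokens emitted this round, next confirmed token, next hidden state).
    """
    # Draft token t+2 from the hidden state at t and the embedding of token t+1.
    draft_tok, draft_lp = mtp_draft(hidden, confirmed_tok)

    # Verify pass over [confirmed, draft]; n_confirmed=1 makes GatedDeltaNet
    # snapshot its conv/SSM state right after the confirmed token.
    logits, hiddens = backbone_step([confirmed_tok, draft_tok], n_confirmed=1)

    if accept_fn(draft_tok, draft_lp, logits[0]):
        # Accept: emit the draft plus the token sampled at the draft's position,
        # i.e. two tokens for a single backbone pass.
        next_tok = sample_fn(logits[1])
        return [draft_tok, next_tok], next_tok, hiddens[1]

    # Reject: restore the SSM/conv snapshot, trim the rejected draft position from
    # the KV caches, and resample t+2 from the residual of the target distribution.
    rollback_fn()
    trim_kv_fn(1)
    resampled = residual_sample_fn(logits[0], draft_lp)
    return [resampled], resampled, hiddens[0]
```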

Results (with Qwen3.5-27B 4-bit on M4 Pro)

| Metric | Standard | MTP |
| --- | --- | --- |
| Throughput | 15.3 tok/s | 23.3 tok/s (1.52x) |
| Acceptance rate | - | 46% avg |
| Identity test | - | Pass (greedy MTP == standard) |

Usage

mlx_lm.generate --model <path> --mtp
mlx_lm.server   --model <path> --mtp

Checkpoint conversion

This requires a checkpoint converted with MTP weights (the default sanitize() previously stripped them). Re-convert from HF with this branch to preserve mtp.* weights.

Important: for it to work on M1, the flag --dtype float16 must be added to the convert command. M1 lacks native BF16 GPU support (MTLDataType.bfloat requires Apple8+), which penalizes mtp.fc since the quant_predicate keeps it in full precision. Without it, MTP may slow down on M1 despite positive acceptance rates.

Questions for reviewers

  1. sampler is None as greedy signal: I use sampler is None to distinguish greedy from stochastic and apply exact-match vs probabilistic acceptance accordingly. Is this the right signal, or would you prefer an explicit greedy: bool parameter for instance?
  2. Dynamic MTP/batch switching: the server now auto-switches based on self.requests.empty(): MTP for solo requests and BatchGenerator for concurrent ones. Is a best-effort queue check the right approach, or is there a preferred pattern in the server architecture?

Addressed in feat/mtp-batched where GenerationBatch supports MTP natively for B > 1

Future work

DRY refactor + SamplerConfig

A follow-up PR independent of MTP would address:

  1. Code duplication across the now three generator functions:
  • _prefill logic: 3 variants across generate_step, speculative_generate_step, and mtp_generate_step
  • _process_and_sample: almost same pattern in speculative_generate_step and mtp_generate_step
  • quantize_cache_fn = functools.partial(...): same pattern in all three
  2. SamplerConfig: currently mtp_generate_step cannot accept a pre-built sampler= callable and produce correct acceptance logprobs simultaneously. A sampler today returns only a token, but for MTP the acceptance criterion also needs the log-probability distribution the token was drawn from. The fix is a richer sampler interface that returns (token, lp_distribution), allowing both generate_step and mtp_generate_step to share the same interface without passing a dozen individual parameters.

Beyond DRY, SamplerConfig unlocks a potential performance gain: sparse residual sampling.
On rejection at temp > 0, the current implementation samples from max(p_target - p_draft, 0) / Z over the full vocabulary (151K tokens for Qwen3.5, 580 µs/call). With top_k > 0, the sampler already computes a top-k partition over the vocabulary, so exposing those indices lets the rejection path work on a K-token support instead.
Without a SamplerConfig, re-running argpartition specifically for the rejection path is slower or equal to the full-vocab path.
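As a concrete illustration of that residual-on-top-k idea (function name, argument names, and the idea of receiving topk_idx from the sampler are all hypothetical, pending the SamplerConfig interface):

```python
import mlx.core as mx


def residual_sample_topk(target_logits, draft_probs, topk_idx):
    """Sample from max(p_target - p_draft, 0) / Z restricted to a K-token support.

    target_logits: (V,) logits from the verify pass
    draft_probs:   (V,) probabilities the draft token was drawn from
    topk_idx:      (K,) indices already computed by the top-k sampler
    """
    p_target = mx.softmax(target_logits, axis=-1)
    # Residual mass on the K-token support only, instead of the full 151K vocabulary.
    residual = mx.maximum(p_target[topk_idx] - draft_probs[topk_idx], 0.0)
    z = residual.sum()
    # If the residual mass is zero on this support, fall back to the target distribution.
    probs = mx.where(z > 0, residual / z, p_target[topk_idx] / p_target[topk_idx].sum())
    choice = mx.random.categorical(mx.log(probs + 1e-12))
    return topk_idx[choice.item()]
```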

Batched MTP

This PR brings MTP for the solo request path only.
However, per-sequence selective rollback (restore SSM state + trim KV only for rejected sequences) is already implemented in AirRunner/mlx-lm · feat/mtp-batched, left out of this PR to keep the diff reviewable.

Test plan

  • Unit tests (11/11 passing) — module existence, cache creation, shapes, pre-norm hidden states, quant predicate, generation identity, end-to-end
  • Manual validation on Qwen3.5-27B, Qwen3.5-0.8B and Qwen3.5-35B-A3B (all 4-bit)

Relates to #872 — cc @janhilgard


Update - probabilistic acceptance and MoE benchmarks

Integrated probabilistic draft acceptance with two cases:

  1. Greedy (sampler=None): exact-match acceptance, mathematically correct for deterministic argmax sampling
  2. Stochastic (temp > 0): min(1, p_target / p_draft) acceptance; recovers the greedy acceptance level at any temperature (sketched below)
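For clarity, the decision itself is tiny; this is a log-space sketch with illustrative names, not the exact code in generate.py:

```python
import mlx.core as mx


def accept_draft(draft_tok: int, draft_logprobs, target_logprobs, greedy: bool):
    """Return True if the draft token is accepted under the two modes above."""
    if greedy:
        # Exact match: accept iff the target argmax equals the draft token.
        return mx.argmax(target_logprobs) == draft_tok
    # Stochastic: accept with probability min(1, p_target / p_draft).
    lp_ratio = target_logprobs[draft_tok] - draft_logprobs[draft_tok]
    u = mx.random.uniform()
    return mx.log(u) < mx.minimum(lp_ratio, 0.0)
```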

Benchmarks on M4 Pro, with 8 diverse prompts:

A reproducible benchmark script is available: bench_mtp.py

Qwen3.5-27B 4-bit

| Condition | Tok/s | Speedup | Acceptance |
| --- | --- | --- | --- |
| No MTP | 15.3 | 1.00x | - |
| MTP, temp=0 | 24.0 | 1.57x | 46% |
| MTP, temp=0.6, exact match | 22.7 | 1.49x | 43% |
| MTP, temp=0.6, probabilistic | 22.9 | 1.51x | 46% |

Qwen3.5-35B-A3B 4-bit

| Condition | Tok/s | Speedup | Acceptance |
| --- | --- | --- | --- |
| No MTP | 85.3 | 1.00x | - |
| MTP, temp=0 | 87.9 | 1.04x | 46% |
| MTP, temp=0.6, exact match | 84.5 | 0.98x | 44% |
| MTP, temp=0.6, probabilistic | 86.5 | 1.03x | 46% |

On M4 Pro, MoE speedup is marginal regardless of acceptance rate. The MTP benefit scales with baseline decode time, so at 85 tok/s (3B active params) the MTP overhead is proportionally too large to yield a meaningful speedup. With probabilistic acceptance, acceptance rates are consistent with the dense model (~46%).

For reference:

  • @Thump604's MoE results (M2 Ultra, 8-bit, temp=0, exact match): 35B-A3B 1.11x, 122B-A10B 1.09x.
  • @sammcj results (M5 Max, 4-bit, temp=0): 9B +11.3%, 27B +35.5%, 122B +12.4%.
  • @Anionex results (M5 Pro, 4-bit, temp=0): 27B +31.4%, 44.3% acceptance.

@vlbosch

vlbosch commented Mar 15, 2026

Great work! Would this also be possible for models like GLM5? As in, does each model require its own implementation of MTP, or can we reuse your mtp_generate_step function for other models? Thanks for your work so far!

@AirRunner
Author

Great work! Would this also be possible for models like GLM5? As in, does each model require its own implementation of MTP, or can we reuse your mtp_generate_step function for other models? Thanks for your work so far!

Thanks!

Yes, mtp_generate_step() is fully reusable, but each model still needs its own model-side interface.

The Qwen3.5-specific part is MTPDecoderLayer, mtp_forward (produce draft logits), make_mtp_cache and the backbone's __call__ (with n_confirmed for SSM state rollback on hybrid models).

So the speculative-decoding logic lives in one place, and adding a new model is just a matter of exposing the right interface.
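Concretely, a new model would need to expose roughly this surface (a sketch with illustrative signatures, not the exact ones in this branch):

```python
from typing import Optional


class SupportsNativeMTP:
    """Model-side hooks that a generic mtp_generate_step could rely on."""

    def __call__(self, inputs, cache=None, n_confirmed: Optional[int] = None,
                 return_hidden: bool = False):
        """Backbone forward: returns logits, plus pre-norm hidden states if requested.
        n_confirmed lets hybrid SSM layers snapshot state after the confirmed token."""
        ...

    def mtp_forward(self, hidden, next_tokens, mtp_cache=None):
        """Run the MTP head on backbone hidden states and the embeddings of next_tokens,
        returning draft logits through the shared lm_head."""
        ...

    def make_mtp_cache(self):
        """Create the cache for the MTP head's single transformer layer."""
        ...
```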

For GLM5 specifically, it would certainly be feasible yeah. But I don't think there is even a glm5.py currently.

@Thump604

Great work on this! We've been using it on M2 Ultra (128GB) with all three Qwen3.5 sizes and it works well.

MoE fix needed

The PR works out of the box for the dense 27B, but MoE models (35B-A3B, 122B-A10B) fail conversion with "768 parameters not in model". The MTP layer's expert weights use unfused per-expert format (mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight) unlike the backbone which uses pre-fused gate_up_proj. The existing sanitize() in qwen3_5_moe.py only handles backbone expert stacking.

Fix (add to qwen3_5_moe.py sanitize(), after the backbone expert stacking loop):

# Stack per-expert MTP weights into switch_mlp format.
mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0)
num_experts = self.language_model.args.num_experts
for l in range(mtp_num):
    prefix = f"language_model.mtp.layers.{l}.mlp"
    test_key = f"{prefix}.experts.0.gate_proj.weight"
    if test_key in new_weights:
        for n in ["gate_proj", "up_proj", "down_proj"]:
            to_join = [
                new_weights.pop(f"{prefix}.experts.{e}.{n}.weight")
                for e in range(num_experts)
            ]
            new_weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join)

Also needs import mlx.core as mx at the top of the file.

Full fix on our fork: Thump604/mlx-lm@04a4383

Benchmark results (M2 Ultra, greedy)

| Model | Baseline | MTP | Speedup |
| --- | --- | --- | --- |
| 27B-8bit (dense) | 20.6 tok/s | 27.1 tok/s | 1.32x |
| 35B-A3B-8bit (MoE) | 74.4 tok/s | 82.3 tok/s | 1.11x |
| 122B-A10B-5bit (MoE) | 43.0 tok/s | 46.7 tok/s | 1.09x |

Pre-converted models with MTP weights: Thump604/Qwen3.5-27B-MLX-8bit, 35B, 122B

@AirRunner
Author

@Thump604 Thanks for the report and the fix! I've integrated it in AirRunner/mlx-lm@8d06796 with a credit.

Also, what acceptance rates did you get with MoE? I'm curious if it's somehow correlated to the speedup.

@Thump604

Thanks for the quick integration!

Here are the acceptance rates derived from our benchmarks (M2 Ultra 128GB, greedy/temp=0):

| Model | Baseline tok/s | MTP tok/s | Speedup | Implied Accept Rate |
| --- | --- | --- | --- | --- |
| 27B dense 8-bit | 20.6 | 27.1 | 1.32x | ~32% |
| 35B-A3B MoE 8-bit | 74.4 | 82.3 | 1.11x | ~11% |
| 122B-A10B MoE 5-bit | 43.0 | 46.7 | 1.09x | ~9% |

At temp=0.6 (production sampling), 122B drops to 1.05x (~5% acceptance).

So yes — it does correlate with architecture. MoE acceptance rates are significantly lower than dense. My hypothesis: the MTP layer contains a full 256-expert MoE routing step (same expert count as the backbone), but with only a single layer of context depth it struggles to predict the correct expert routing. The dense 27B's MTP layer is a standard transformer layer — much simpler prediction task, much higher acceptance.

The fp16 27B was actually 0.61x (slower) — bandwidth-saturated, the MTP overhead exceeds the savings. 8-bit quantization is the sweet spot where MTP helps most.

@Thump604

Hey @AirRunner — thanks for integrating the MoE sanitize fix! The PR has merge conflicts with main now though. Would you be able to rebase? Happy to help if needed.

Also, any thoughts on tagging a maintainer for review? This has been open since March 13 with zero maintainer engagement. The implementation is solid (8 tests, code review feedback addressed, MoE fix integrated), just needs someone to look at it.

@AirRunner
Author

AirRunner commented Mar 21, 2026

Hey @Goekdeniz-Guelmez, would you be able to take a look when you get a chance?

Quick summary: 8 unit tests, code review feedback from @janhilgard and @Thump604, rebased on main.
Results: 1.52x token generation on Qwen3.5-27B dense on M4 Pro, validated independently on M2 Ultra across three Qwen3.5 sizes (MoE and dense).

@layer4down

Subject: Successfully running Qwen3.5-27B locally with workaround

Transparency Note: This comment was drafted with the assistance of an AI assistant to help document the troubleshooting process. All technical details and findings are from actual testing.


Thanks for this PR! I was able to get Qwen3.5-27B working locally with MLX, but encountered an issue that might help others.

The Bug I Was Addressing

When trying to use the model with a client that passes short model IDs, I encountered:

401 Client Error. (Request ID: Root=1-69bfb0a8...)
Repository Not Found for url: https://huggingface.co/api/models/qwen3_5-27b_4bit/revision/main.
Please make sure you specified the correct `repo_id` and `repo_type`.
User Access Token "Claude-flow-ro" is expired

The error message was misleading - it suggested an expired token, but the real issue was a config/weight mismatch described below.

Issue Encountered

The model failed to load with:

ValueError: Missing 15 parameters: 
language_model.mtp.fc.weight,
language_model.mtp.layers.0.input_layernorm.weight,
...

Root Cause

The model's config.json (from mlx-community/Qwen3.5-27B-4bit on HuggingFace) has:

{
  "text_config": {
    "mtp_num_hidden_layers": 1
  }
}

However, the actual .safetensors weights do not contain any MTP parameters. The PR code correctly expects MTP weights when mtp_num_hidden_layers > 0, but this particular model's config claims MTP support that isn't present in the weights.

Workaround

Set mtp_num_hidden_layers to 0 in the model's config:

cat config.json | jq '.text_config.mtp_num_hidden_layers = 0' > config_fixed.json
mv config_fixed.json config.json

Other Configuration Notes

For anyone trying this setup:

  • Context length: Model supports 98K+ context; works with --max-tokens 98304
  • KV cache quantization: Works with MLX_KV_CACHE_QUANT=true environment variable
  • Model path as ID: The server uses the full local path as the model ID in API calls. For example:
    // Request to /v1/chat/completions
    {
      "model": "/path/to/local/models/mlx-community/Qwen3.5-27B-4bit",
      "messages": [...]
    }
    Short names like "Qwen3.5-27B" will trigger a HuggingFace lookup (and fail if the repo doesn't exist or auth is expired).

Suggestion

It might be helpful to add a check/warning when:

  1. mtp_num_hidden_layers > 0 in config
  2. But MTP weights are missing from the loaded model

This would help users identify config/weight mismatches more quickly and avoid confusing auth error messages.
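For illustration, such a guard could be as small as this (hypothetical helper; not the fix that was later pushed to the branch):

```python
def check_mtp_config_matches_weights(config: dict, weights: dict) -> None:
    """Fail fast when config claims MTP support but the checkpoint has no mtp.* weights."""
    mtp_layers = config.get("text_config", {}).get("mtp_num_hidden_layers", 0)
    has_mtp_weights = any(".mtp." in k or k.startswith("mtp.") for k in weights)
    if mtp_layers > 0 and not has_mtp_weights:
        raise ValueError(
            f"config.json sets mtp_num_hidden_layers={mtp_layers} but the checkpoint "
            "contains no mtp.* weights; re-convert the model with MTP weights preserved "
            "or set mtp_num_hidden_layers to 0."
        )
```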

@AirRunner
Author

@layer4down thanks for the write-up!

You're right: mlx-community/Qwen3.5-27B-4bit was quantized without the MTP head weights. The mtp_num_hidden_layers: 1 in the config is inherited from the original Qwen3.5 config, but the MTP parameters were not included when quantizing.

To actually use MTP acceleration, the model needs to be re-quantized including the MTP layers using this branch.

As you suggested I just pushed a fix that raises a clear ValueError instead of the cryptic "Missing N parameters" crash.

@Thump604

@angeloskath -- this PR has been open 11 days with no maintainer review. AirRunner rebased on 2026-03-21, all conflicts resolved, 8 unit tests passing.

We've been running this in production on M2 Ultra 128GB since day one. Qwen3.5-122B-A10B-VLM-MTP-5bit, 24/7 inference serving coding agents. MTP acceptance rates:

  • 27B dense 8-bit: 1.32x (32% acceptance, best fit)
  • 35B MoE 8-bit: 1.11x (11% acceptance)
  • 122B MoE 5-bit: 1.09x (9% acceptance)

MoE acceptance rates are lower because a single MTP layer can't predict expert routing well. Still a net win for the latency-sensitive use case.

The MoE sanitize fix (commit 8d06796) is essential for Qwen3.5 MoE models -- without it, 768 MTP parameters are silently missing. We've also published pre-converted VLM+MTP models on HuggingFace that depend on this code path.

Would be great to get this reviewed and merged so the community models work out of the box.

@cresseelia

cresseelia commented Mar 29, 2026

Can we @ the reviewer again? It's an important update for Qwen3.5.

@Thump604

@angeloskath @awni — this PR has been open 17 days with no maintainer review or feedback. Multiple community members have asked for review (AirRunner, ourselves, cresseelia).

Is there a concern with the approach, scope, or implementation that's blocking review? We're happy to help address any issues — split the PR, rework the API surface, add tests, whatever is needed.

We're running this in production on 122B and have validated it across three Qwen3.5 model sizes. The community is actively hitting the config/weight mismatch that AirRunner already fixed in this branch (layer4down's report above). Without this merged, users have to manually patch config.json to use MTP on Qwen3.5 models.

If the PR needs changes or a different direction, we'd rather know than wait. Let us know how we can help move this forward.

@Goekdeniz-Guelmez
Contributor

As you can see, 6 files have been changed/added alongside 700 lines of added code. This is a PR with big changes in the codebase itself. Reviewing and (correctly) implementing it will take time; 17 days is not long enough. My full weight fine-tuning PR took multiple weeks to be merged. Just keep it open, keep it updated, and please be patient. Adding completely new features takes a while.

@janhilgard

@Goekdeniz-Guelmez — fair point, thanks for the perspective. We appreciate you taking the time to look at it.

To make the review easier, we can split this into two smaller PRs:

PR 1 — Model architecture (~260 lines): MTPModule, MTPDecoderLayer, SSM rollback support in GatedDeltaNet, MoE weight stacking, cache rollback field. Pure model-side changes, reviewable independently.

PR 2 — Generation + tests (~420 lines): mtp_generate_step() function, --mtp CLI flag, 8 unit tests. Depends on PR 1 but much easier to review once the model interface is established.

Would splitting it this way help with the review process? Happy to do the work if so.

@AirRunner — would you be open to splitting the PR this way?

@AirRunner
Author

AirRunner commented Apr 1, 2026

@Goekdeniz-Guelmez — fair point, thanks for the perspective. We appreciate you taking the time to look at it.

To make the review easier, we can split this into two smaller PRs:

PR 1 — Model architecture (~260 lines): MTPModule, MTPDecoderLayer, SSM rollback support in GatedDeltaNet, MoE weight stacking, cache rollback field. Pure model-side changes, reviewable independently.

PR 2 — Generation + tests (~420 lines): mtp_generate_step() function, --mtp CLI flag, 8 unit tests. Depends on PR 1 but much easier to review once the model interface is established.

Would splitting it this way help with the review process? Happy to do the work if so.

@AirRunner — would you be open to splitting the PR this way?

@janhilgard I'm not sure splitting would actually help the review here.

The PRs you suggest wouldn't be reviewable in isolation, because the architecture changes only make sense in the context of how mtp_generate_step uses them. Also the changes in generate would be dead code until the other PR lands, so one would need to review both PRs together anyways.

(Also, 183 of the 683 added lines are just unit tests).

That said, I'm open to whatever helps, happy to reorganize if it does :).

@Thump604

Thump604 commented Apr 1, 2026

@angeloskath @awni — this PR has been open 20+ days with no maintainer review. It is the foundation for MTP speculative decoding on Qwen3.5 models, which several of us are using in production. My PR #1085 (probabilistic acceptance, 2.3x throughput on 122B) builds directly on top of it.

AirRunner's implementation is solid: 8 tests, 80.6% acceptance on M4 Pro. Is there a concern about scope or approach blocking review?

@gyzerok

gyzerok commented Apr 1, 2026

@Thump604 can you stop pinging people? The more annoying you are the less likely anyone is going to respond.

@janhilgard

Great work — I've been running MTP on Qwen3.5 MoE models in production (M3 Ultra, 256 GB) and wanted to share findings that might explain the low MoE acceptance rates.

BF16 MTP weights are critical for MoE acceptance

Your quant_predicate excludes only mtp.fc:

if path.endswith("mtp.fc"):
    return False

But the MTP transformer layer (attention, MLP, norms) still gets quantized. We found that quantized MTP weights give near-0% acceptance on MoE models — the quantization error compounds through the expert routing prediction.

Fix: exclude ALL MTP weights from quantization:

if "mtp." in path:
    return False

Our MoE results with BF16 MTP weights

| Model | Quantization | MTP weights | Acceptance | Speedup |
| --- | --- | --- | --- | --- |
| 35B-A3B | 4-bit | BF16 | 79-85% | 1.18x |
| 122B-A10B | 4-bit | BF16 | 77-78% | 1.12x |
| 35B-A3B | 4-bit | dequantized 4→BF16 | ~0% | - |

vs your MoE benchmarks (quantized MTP weights):

| Model | MTP weights | Implied acceptance | Speedup |
| --- | --- | --- | --- |
| 35B-A3B 8-bit | quantized | ~11% | 1.11x |
| 122B-A10B 5-bit | quantized | ~5% | 1.09x |

The difference is stark: BF16 MTP weights → 79-85% acceptance, quantized → 5-11%.

Batch auto-skip

Your PR sets is_batchable = False when MTP is active. In our vllm-mlx integration (#245 on waybarrios/vllm-mlx) we auto-skip MTP when batch_size > 1:

if len(active_batch) > 1:
    # Skip MTP, fall back to standard generation
    return _orig_step(input_tokens, cache)

This gives the best of both worlds:

  • 1 request: MTP active → 86 tok/s (1.18x)
  • 8 requests: MTP skipped → 307 tok/s (full batching throughput)

Instead of disabling batching entirely, you could dynamically switch.

Weight extraction

We extract BF16 MTP weights from the original HF model (not the quantized MLX model) with a dedicated script. See vllm-mlx PR #245 for the add_mtp_weights_qwen35.py script that:

  • Downloads only MTP-containing shards (not entire model)
  • Stacks per-expert weights into SwitchLinear format
  • Applies RMSNorm +1.0 shift
  • Outputs native BF16

Happy to collaborate on getting BF16 MTP weights into the standard conversion pipeline.

@Thump604

Thump604 commented Apr 2, 2026

I tested your BF16 MTP finding on our models. Sharing the data since it tells a different story on 5-bit and 8-bit backbones.

I extracted fresh BF16 MTP weights from the original HF models (not dequantized from quantized), applied the RMSNorm +1.0 shift, stacked MoE experts into SwitchLinear format, and re-quantized to match the backbone (5-bit gs=64 for 122B, 4-bit gs=64 for 4B). This matches the process you describe in your extraction script.

Results (probabilistic acceptance, temp=0.6):

| Model | Backbone | Original quantized MTP | BF16-source re-quantized MTP |
| --- | --- | --- | --- |
| 4B dense | 4-bit gs=64 | 44.9%, 91.8 tok/s | 43.8%, 86.5 tok/s |
| 122B MoE | 5-bit gs=64 | 47.3%, 21.5 tok/s | 47.3%, 21.0 tok/s |

No measurable difference. Re-quantizing MTP from the BF16 source produces the same acceptance as the original quantized weights on these models.

I also tested with fully unquantized BF16 MTP (no re-quantization, just raw BF16 + norm shift). This gave 0% acceptance across all models. The BF16 MTP forward pass produces a different logit distribution than the quantized backbone expects. Once I re-quantize to match the backbone, the acceptance rate converges to the same ~46%.

Your 79-85% acceptance at 4-bit is significantly higher than what I see. A few questions:

  • Are you running the MTP layer entirely in BF16 (unquantized), or does your script quantize it to match the backbone?
  • Which mlx-lm generate path are you using? Our probabilistic acceptance is from PR feat: probabilistic MTP acceptance (speculative sampling) #1085 (min(1, p_target/p_draft)). Exact match at temp=0.6 gives ~5%.
  • Are your numbers from greedy (temp=0) or sampled (temp=0.6)?

Our acceptance ceiling appears to be ~47% with probabilistic sampling regardless of how the MTP weights are prepared, as long as they match the backbone's quantization. If you are getting 79-85%, there may be a difference in the generation loop or sampling strategy that accounts for the gap.

@AirRunner
Author

AirRunner commented May 6, 2026

@AirRunner You may take a look at this repo

This is indeed interesting work. A few things I've noted:

  1. depth>1

Following up on lawcontinue's question: MTPLX chains k sequential MTP calls, each consuming the hidden state of the previous one, then verifies all k drafts in a single backbone forward pass.

As I noted before, Qwen3.5's MTP head was likely trained for 1-step prediction. At depth>1 it would receive its own output hidden states instead of the backbone's, which would presumably create a progressive distribution drift, especially with depth>=3.

  2. Correctors

This seems to be a great research contribution: they compensate the depth>1 drift with offline-fitted affine transforms at each depth level, which could make depth=3 viable. The code is in the repo, but I couldn't find corrector weights in any of their model checkpoints, so it might be a work in progress.

  3. Adaptive depth

There is an ExpectedValueDepthPolicy that computes an expected-value gate before each draft step, weighing the probability of acceptance against the measured cost of the extra forward pass. Promising, though the cost parameters seem hardware-specific and are exposed as CLI flags with hardcoded defaults.

  4. Metal kernels
  • gdn_norm_gate_stage fuses RMSNorm + SwiGLU for the GDN tail. I prototyped it via mx.fast.metal_kernel and tested it at depth=2, but couldn't get a measurable gain. head_v_dim=128 might be too small to matter against the backbone bandwidth cost.

  • verify_qmv replaces the stock mx.quantized_matmul path for small M (the verify window size), reusing dequantized weight loads across rows.

  5. Model checkpoints

They published 4 checkpoints on HuggingFace for Qwen3.6-27B and Qwen3.5-4B, with AWQ-calibrated MTP weights.


Implementing MTP depth

So, inspired by MTPLX, I implemented a --mtp-depth flag on a separate feat/mtp-depth branch.

The interface is generic, any model implementing mtp_forward(..., return_hidden=True) benefits automatically. (For qwen3_5, on rejection at position j, GDN layers restore the per-position SSM/conv snapshot from rollback_states[j] and KV layers trim by mtp_depth - j).

Benchmarks

Here are some benchmarks on M4 Pro with Qwen3.5-27B 4-bit (8 diverse prompts, 3 runs, script here).

| Condition | Decode tok/s | Speedup | Acceptance |
| --- | --- | --- | --- |
| baseline (temp=0) | 15.67 | 1.00x | - |
| mtp_depth=1 (temp=0) | 24.02 | 1.53x | 45.8% |
| mtp_depth=2 (temp=0) | 22.92 | 1.46x | 59.8% |
| baseline (temp=0.6) | 15.65 | 1.00x | - |
| mtp_depth=1 (temp=0.6) | 23.93 | 1.53x | 45.6% |
| mtp_depth=2 (temp=0.6) | 22.77 | 1.45x | 59.9% |

→ Acceptance at depth=2 rises to 60% as expected, but the structural overhead exceeds the gain on this hardware so the speedup is worse. Also, as said before, the MTP head was likely trained for 1-step prediction, so running it on its own output hidden states at depth>1 presumably adds distribution drift on top of that overhead.

On MTPLX's 2.24x

I don't have data on configs where depth=2 would be profitable, and without MTPLX's type of correctors the practical benefit is unclear.

The key bottleneck at depth>1 is the verify pass. The backbone runs on k+1 tokens (M=k+1 in Metal terms), and stock MLX's qmv kernel is tuned for M=1 (normal single-token decode) and becomes increasingly inefficient as M grows.

At depth=1 (M=2) the overhead is manageable (indeed I got 1.53x). But at depth=2 (M=3) it already drops to 1.46x on my benchmark, and at depth=3 (M=4) the verify cost likely cancels out the acceptance gain entirely. This is exactly what MTPLX's README says: without a kernel tuned for small M, the verify overhead grows faster than the acceptance benefit as depth increases.

MTPLX's custom MLX fork (mlx-mtplx-0.31.2-qmm) retunes qmv for M=3..6 (4-simdgroup, unroll_count(4)), which is what would make depth=3 viable. The fork doesn't appear to be public though, so not currently reproducible.

@youssofal, is that a correct reading? Also do you plan to contribute to ml-explore/mlx with this qmv retuning?

@Goekdeniz-Guelmez, curious what you think about this.

@AirRunner
Author

The native MTP approach here (zero extra model RAM, built-in heads) is a clean win for single-machine Qwen3.5 serving. The 80.6% acceptance rate at 1.52x on M4 Pro matches what we'd expect from 1-layer MTP drafters.

One angle not covered yet: MTP and pipeline parallelism are orthogonal. We're building Hippo — a distributed MLX framework that splits models across Apple Silicon nodes for cases where a single machine can't hold the model. MTP speeds up per-node decode while pipeline parallelism scales the model size ceiling. In theory a 27B with MTP across 2 Mac Minis gets both the 1.5x decode speedup and the memory ceiling lift.

The tricky part is SSM state sync across pipeline boundaries during draft rejection rollback — have you thought about that?

@lawcontinue The rollback in this PR works exactly the way @janhilgard describes for vllm-mlx. ArraysCache carries a rollback_state: Optional[tuple] per GDN layer. During the verify pass, GatedDeltaNet snapshots (conv_state, ssm_state) after the confirmed token, then on rejection _rollback_draft() restores it. Attention layers just trim their KV cache by 1. Each layer restores independently.

For your pipeline parallelism case, you would apply the same principle: each node holds its own snapshot for the layers it owns and restores it independently on rejection. The snapshot is a fixed-size array and never needs to cross node boundaries. The only thing that needs to propagate across the pipeline on rejection is the rejection signal itself and the KV trim count, so rollback coordination reduces to a single broadcast rather than any state transfer.
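A minimal sketch of that per-layer snapshot/restore, with illustrative field and method names (in the PR the rollback_state slot lives on the existing cache class in cache.py):

```python
from typing import Optional, Tuple

import mlx.core as mx


class GDNRollbackCacheSketch:
    """Per-layer conv/SSM state with a one-deep snapshot for draft rejection."""

    def __init__(self):
        self.conv_state: Optional[mx.array] = None
        self.ssm_state: Optional[mx.array] = None
        self.rollback_state: Optional[Tuple[mx.array, mx.array]] = None

    def snapshot(self):
        # Taken right after the confirmed token during the verify pass (n_confirmed=1).
        self.rollback_state = (self.conv_state, self.ssm_state)

    def rollback(self):
        # Restored when the draft is rejected; attention layers trim their KV instead.
        if self.rollback_state is not None:
            self.conv_state, self.ssm_state = self.rollback_state
```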

@AirRunner AirRunner changed the title feat: native MTP speculative decoding for qwen3_5_moe (Qwen3.5-3.6) Native MTP speculative decoding (Qwen3.5/3.6 reference implementation) May 6, 2026
@youssofal

@AirRunner Thanks for the careful read. Yes, that is generally correct

The main clarification I’d make is that depth > 1 is not the win by itself. Once you verify multiple drafts, the target pass becomes a small-M quantized matmul problem: M = depth + 1 for the verify window. So depth=1 means M=2, depth=2 means M=3, depth=3 means M=4, etc.

That cost curve is very hardware/runtime dependent. On the current MTPLX Qwen3.6-27B M5 Max speed lane, depth=3 / verify M=4 is the best recorded setting. But that should not be read as universal. On older or more verify-bound machines, especially where verify latency is high, the extra target rows can cost more than the extra accepted drafts earn. I have seen user reports on M1 Max devices where shallower depth (even depth=1) is the better practical setting. MTP can improve acceptance while still losing on wall-clock.

That is why the small-M verify path matters so much. Stock MLX is excellent for normal single-token decode, but speculative verification changes the requirements, and the acceptance/verify-time ratio becomes this odd balancing game. Without improving that small-M verify path, deeper MTP can look good in acceptance numbers (my MLX acceptance numbers are beating vLLM MTP numbers) and still be the wrong setting for a given machine (worse chips, worse verify time).

One other clarification: the corrector/adaptive-depth code in MTPLX is research infrastructure. It is not the active reason the public 2.24x path works, and there are no corrector weights shipped in the current public checkpoints. The current speed path is mostly high-quality/native MTP proposals, committed MTP history, the draft-only LM head, capture-commit / linear-GDN runtime work, and the small-M verify/kernel stack.

Current stats: D3 acceptance = [100%, 97.96%, 93.88%], with 3 corrections over 49 verify calls.

On the MLX fork: MTPLX itself is open source, including the standalone/probe kernels and runtime code. The specific MLX-source qmv retuning is currently a local fork/patch, not yet published as a clean upstreamable branch. I’m open to extracting that into a minimal ml-explore/mlx PR if maintainers are interested.

@AirRunner
Copy link
Copy Markdown
Author

AirRunner commented May 6, 2026

@youssofal Thanks for the clarifications.

First, the small-M verify bottleneck matches what I observed in our mlx-lm implementation: depth=1 (M=2) gives 1.53x (24 tok/s), depth=2 drops to 1.46x (22.9 tok/s), consistent with verify cost growing faster than acceptance gain as M increases.

With your D3 acceptance numbers ([100%, 97.96%, 93.88%]), it's about 3.9 tokens per cycle vs 1.46 at depth=1, which would be a significant gain if the verify cost stays tractable, and that seems to be exactly where the small-M kernel makes the difference.

To verify this I ran mtplx run --profile performance-cold --depth N on v0.1.5 with Youssofal--Qwen3.6-27B-MTPLX-Optimized-Speed (M4 Pro):

| mode | tok/s | vs AR |
| --- | --- | --- |
| AR | 14.06 | 1.00x |
| MTP D1 | 23.22 | 1.65x |
| MTP D2 | 21.98 | 1.56x |
| MTP D3 | 18.75 | 1.33x |

D3 drops back despite higher acceptance, which points to verify cost growing faster than the acceptance gain at M=4. It looks like without the retuned kernel, the verify overhead absorbs most of the gain at depth>2.

Looking at the distributed package, performance-cold default only sets the lazy/batching env flags, the native MLP kernels are opt-in, and native_mlp.py explicitly says "prior probes showed the first native rowwise primitive is slower than stock." The MLX-source qmv retuning doesn't seem to be in the distributed package.

If I'm reading this correctly, and with the right hardware, approaching 2.24x would mainly require both depth>1 and the custom kernel together, right?
If so, if you don't mind, opening a PR to ml-explore/mlx with the qmv retuning for small M could be really interesting! Or publishing a fork so we can build on it.

Either way, thanks again for your work!

jundot added a commit to jundot/omlx that referenced this pull request May 6, 2026
Wires mlx-lm PR 990 + PR 15 as runtime monkey-patches inside
GenerationBatch. oQ preserves mtp.* via a new -mtp suffix; mlx-vlm
sanitize is patched to keep MTP norm shift + language_model prefix
so quantized VLM checkpoints load with correct per-tensor bits.

Source PRs:
- mlx-lm#990 (Qwen3.5/3.6 MTP) by AirRunner
  ml-explore/mlx-lm#990
- mlx-lm fork PR #15 (DeepSeek-V4 MTP) by 0xClandestine
  Blaizzy/mlx-lm#15
@Goekdeniz-Guelmez
Contributor

@AirRunner your explanation sounds good, and a PR to mlx for the small-M qmv retuning feature would be a good addition. What do you think @angeloskath @zcbenz @jagrit06 ?

@lawcontinue

Thanks for the detailed breakdown @janhilgard @AirRunner — the local-snapshot + rejection-signal-only approach is cleaner than I expected. ~4KB per layer is negligible next to the activation tensors we're already moving across Thunderbolt.

For the pipeline case: the main cost isn't the rollback itself, it's the pipeline bubble when a rejected draft wastes a full forward pass across N nodes. Curious if either of you have measured how that compares to single-node rejection overhead at depth=1 — in our setup the inter-node latency (~0.8ms Thunderbolt) is small enough that the bubble is dominated by compute, but I'd expect it to hurt more at depth>1 where rejection is more frequent.

AirRunner added 2 commits May 7, 2026 01:40
- Remove z.item() sync: z stays in the MLX graph and is evaluated once alongside categorical(), reducing Metal round-trips from 2 to 1.
- Replace if z > 1e-8 guard with mx.where(z > 0, residual, p_target): when the residual mass is zero, sample from p_target instead of keeping verify_pred (argmax). Matches Leviathan et al. 2022 §2.3.
Replace sampler= callable with explicit sampling params (temp, top_p, top_k, min_p, xtc_*) so mtp_generate_step can compute temperature-adjusted lp_accept for correct probabilistic acceptance at temp > 0.

- Extract make_sampler_chain from make_sampler (DRY); mtp uses it directly to build the filter chain without a pre-assembled sampler.
- Compute lp_accept from the filtered+scaled distribution so it matches the distribution the token was drawn from.
- Share the XTC boolean draw across draft and verify steps via xtc_cell, so both steps apply the same XTC mask.
- Draw acceptance coin as mx.random.uniform(), evaluated in parallel with the verify forward pass (amortized Metal dispatch, consistent with mx.random.seed()).
- Fix _xtc_special_tokens: use tokenizer.eos_token_ids (plural) and concatenate properly instead of mixing int and list.
- Update tests: remove sampler= from MTP tests, add top_k variant, extract _collect_rejection_tokens/_assert_residual_varies helpers.
@JJJYmmm
Contributor

JJJYmmm commented May 7, 2026

Hi @AirRunner, great work!

I was comparing this PR with my branch that also prefills the native mtp cache during prompt prefill:
https://github.com/JJJYmmm/mlx-lm/tree/native-mtp/prefill

One question: is there a reason the current implementation only starts populating the mtp cache after decode begins, instead of prefilling it together with the prompt? (mtp prompt prefill is more aligned with training, and vllm/sglang also do it)

In my tests, prompt-prefilling the mtp cache gives a noticeably higher n=1 acceptance rate under the same setup (10 sampled GSM8K prompts, max_tokens=128, temp=0, M5 chip):

  • Qwen3.5-9B 4bit: 86.0% vs 82.4%
  • Qwen3.5-4B bf16: 87.1% vs 83.1%

9B 4bit

| impl | setting | tok/s | speedup | accept |
| --- | --- | --- | --- | --- |
| current baseline | - | 24.20 | 1.00x | - |
| current native MTP | n=1 | 36.77 | 1.52x | 590/686 = 86.0% |
| current native MTP | n=2 | 33.03 | 1.36x | 767/1018 = 75.3% |
| current native MTP | n=3 | 27.42 | 1.13x | 845/1282 = 65.9% |
| current native MTP | n=5 | 19.38 | 0.80x | 912/1806 = 50.5% |
| PR990 baseline | - | 23.45 | 1.00x | - |
| PR990 MTP | n=1 | 35.44 | 1.51x | 576/699 = 82.4% |

4B bf16

| impl | setting | tok/s | speedup | accept |
| --- | --- | --- | --- | --- |
| current baseline | - | 12.36 | 1.00x | - |
| current native MTP | n=1 | 13.43 | 1.09x | 595/683 = 87.1% |
| current native MTP | n=2 | 15.19 | 1.23x | 779/991 = 78.6% |
| current native MTP | n=3 | 14.11 | 1.14x | 856/1261 = 67.9% |
| current native MTP | n=5 | 11.27 | 0.91x | 908/1836 = 49.5% |
| PR990 baseline | - | 11.65 | 1.00x | - |
| PR990 MTP | n=1 | 12.40 | 1.06x | 580/698 = 83.1% |

Current Per-Position Accept

| model | n | p1 | p2 | p3 | p4 | p5 |
| --- | --- | --- | --- | --- | --- | --- |
| 9B 4bit | 1 | 86.0% | - | - | - | - |
| 9B 4bit | 2 | 83.6% | 67.1% | - | - | - |
| 9B 4bit | 3 | 83.5% | 64.9% | 49.1% | - | - |
| 9B 4bit | 5 | 82.5% | 64.0% | 47.6% | 33.9% | 23.4% |
| 4B bf16 | 1 | 87.1% | - | - | - | - |
| 4B bf16 | 2 | 86.9% | 70.2% | - | - | - |
| 4B bf16 | 3 | 87.5% | 65.6% | 50.4% | - | - |
| 4B bf16 | 5 | 84.4% | 62.6% | 45.7% | 31.2% | 22.6% |

AirRunner added 2 commits May 7, 2026 16:29
Previously _prefill only populated the backbone cache, leaving the MTP KVCache cold at the start of decode. The MTP head was trained with full prefix context, so starting from an empty cache is misaligned with training.

Now each prefill chunk passes return_hidden=True and immediately calls mtp_forward(hidden, y[1:n+1], mtp_cache). The hidden tensor is transient: consumed within the same iteration before mx.clear_cache().
@AirRunner
Author

@JJJYmmm Thanks, you're right! To simplify _prefill I constrained the last backbone forward pass to exactly 1 token so that the hidden state [1, N, H] kept alive by return_hidden=True is minimal.
From there it was a short step to not prefilling the MTP cache at all during prompt processing. The MTP cache warms up naturally over the first few decode steps, but I hadn't measured the difference starting from a cold cache. So thanks for your numbers!

It's now fixed, model.mtp_forward() is called immediately (discarding the logits).

It's worth noting though that in multi-turn usage the MTP cache from the previous turn is already populated, so the cold-start effect would only apply to the very first turn. Also, at temp>0 with probabilistic acceptance, the criterion would partially compensate for cold-cache predictions, so the real-world delta will likely be small.
That said, prefilling is still the correct behavior.

VRAM considerations

This adds a bit of overhead, but the [1, N, H] hidden tensor is transient, as it's consumed by mtp_forward in the same iteration and freed before the next chunk.

For Qwen3.5 it just adds about 4 KB/token permanent VRAM cost for the MTP KV cache (4 KV heads × 256 head_dim × BF16), so about 40 MB for a 10K-token prompt.

generate_step calls mx.clear_cache() every 256 tokens to bound the Metal allocator's free list.

Introduce _CACHE_CLEAR_INTERVAL = 256 shared by both generate_step and mtp_generate_step to add the equivalent cache-clearing logic to the MTP decode loop.  The block-based counter (ntoks // _CACHE_CLEAR_INTERVAL) handles MTP iterations that could emit multiple tokens at once, where a '% interval == 0' check could skip a boundary.
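A sketch of that block-based check (constant and function name as described above, otherwise illustrative):

```python
import mlx.core as mx

_CACHE_CLEAR_INTERVAL = 256


def maybe_clear_cache(ntoks: int, prev_block: int) -> int:
    """Clear the Metal allocator's free list once per 256-token block."""
    # MTP rounds can emit two tokens at once, so `ntoks % interval == 0` could skip
    # a boundary; comparing block indices cannot.
    block = ntoks // _CACHE_CLEAR_INTERVAL
    if block != prev_block:
        mx.clear_cache()
    return block
```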
@atelepov

atelepov commented May 8, 2026

@AirRunner
I tested it on M1 MAX 32Gb, but there's no acceleration. It's actually slowing down.
Could you tell me if this is a limitation of M1 MAX 32Gb specifically? Or are some additional parameters incorrectly specified?

Configuration
M1 MAX 32Gb

Convert

uv run python -m mlx_lm convert --hf-path Qwen/Qwen3.6-27B \
  --mlx-path Qwen3.6-27B-mtp \
  -q --q-mode affine --q-bits 4 --q-group-size 64

no mtp

uv run python -m mlx_lm generate --model Qwen3.6-27B-mtp \
  --prompt "Write a quicksort in Python." --max-tokens 100

result

Prompt: 17 tokens, 5.002 tokens-per-sec
Generation: 100 tokens, 19.447 tokens-per-sec
Peak memory: 15.715 GB

mtp

uv run python -m mlx_lm generate --model Qwen3.6-27B-mtp \  
  --prompt "Write a quicksort in Python." --max-tokens 100 --mtp

result

Prompt: 17 tokens, 5.552 tokens-per-sec
Generation: 100 tokens, 17.553 tokens-per-sec
Peak memory: 15.814 GB

@AirRunner
Author

AirRunner commented May 8, 2026

@atelepov Your setup looks correct. Are the MTP weights present in the quantized version?
M1 Max shouldn't be a limitation, as we tested it on M2 Ultra, M4 Pro, M5 Pro/Max, etc.

It's known that on MoE there is not much gain, especially on the 35B-A3B where it could even drop by 1 or 2%.
But the 27B should definitely benefit from MTP.

It might simply be the benchmark length (might be too much variance at 100 tokens).
That said, I just tested on M4 Pro with the same model and prompt with --max-tokens 100, and I still get +51.4%.

Could you retry with --max-tokens 256, or even more (like 2048)?

@s-n-t

s-n-t commented May 8, 2026

For M1/M2 you always want "--dtype float16" when quantizing, as bfloat16 goes through a software path I think? I still don't see any improvement at 4-bit though, but 8-bit shows a decent boost.

python3 -m mlx_lm convert --hf-path Qwen/Qwen3.6-27B \
  --mlx-path Qwen3.6-27B-mtp-8-bit \
  -q --q-mode affine --q-bits 8 --q-group-size 64 --dtype float16
python3 -m mlx_lm convert --hf-path Qwen/Qwen3.6-27B \
  --mlx-path Qwen3.6-27B-mtp-4-bit \
  -q --q-mode affine --q-bits 4 --q-group-size 64 --dtype float16
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-4-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024 --mtp
...
  ==========
Prompt: 17 tokens, 32.623 tokens-per-sec
Generation: 1024 tokens, 20.213 tokens-per-sec
Peak memory: 15.900 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-4-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024
...
 ==========
Prompt: 17 tokens, 30.828 tokens-per-sec
Generation: 1024 tokens, 20.601 tokens-per-sec
Peak memory: 15.737 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-8-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024 --mtp
...
==========
Prompt: 17 tokens, 23.865 tokens-per-sec
Generation: 1024 tokens, 14.406 tokens-per-sec
Peak memory: 29.513 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-8-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024
...  
==========
Prompt: 17 tokens, 22.529 tokens-per-sec
Generation: 1024 tokens, 11.346 tokens-per-sec
Peak memory: 29.351 GB

@AirRunner
Author

AirRunner commented May 8, 2026

@s-n-t Interesting findings. So collecting different data points across the comments of this PR, we have:

| Hardware | Model | dtype | Baseline | MTP | Delta |
| --- | --- | --- | --- | --- | --- |
| M1 Max | 27B 4-bit | fp16 | 20.6 tok/s | 20.2 tok/s | -2% |
| M1 Max | 27B 8-bit | fp16 | 11.3 tok/s | 14.4 tok/s | +27% |
| M2 Ultra | 27B 8-bit | bf16 | 20.6 tok/s | 27.1 tok/s | +32% |
| M4 Pro | 27B 4-bit | bf16 | 15.6 tok/s | 23.7 tok/s | +51% |
| M5 Max | 27B 4-bit | bf16 | 32 tok/s | 44 tok/s | +35% |

Important note: the explicit --dtype float16 override seems to be important for M1; otherwise the default bf16 goes through a software path.
mtp.fc is kept in full precision by our quant_predicate, so if BF16 ops go through a software path on M1, that layer adds disproportionate overhead.


The M1 Max result is the outlier. The speedup differences across chips could be better explained by the β+δ framework from my reply to Anionex: speedup = (1+p) / (β+δ), where p is the per-round acceptance probability, β is the 2-token backbone overhead, and δ is the MTP head cost, both relative to a single baseline step.

Also from what I understand, there is no dual-datapath execution (FP32/BF16 and int4 ops serialized) or dedicated matrix accelerators on M1. M3 introduced both. So the M1 GPU generation likely adds disproportionate compute overhead on the MTP head forward pass (short sequence, more compute-bound than the memory-bound backbone).

With p≈0.85 (from α=0.46 via α=p/(1+p)), breakeven is at β+δ = 1+p ≈ 1.85. On M4 Pro β+δ = 1.190, well below that. On M1 Max, the MTP head forward pass likely pushes β+δ closer to 1.85, which would explain the near-zero result. Running this bench script would give a measured β+δ for M1 Max.
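For anyone who wants to sanity-check those numbers, the arithmetic is just:

```python
def mtp_speedup(alpha: float, beta_plus_delta: float) -> float:
    """speedup = (1 + p) / (beta + delta), with p = alpha / (1 - alpha)."""
    p = alpha / (1.0 - alpha)
    return (1.0 + p) / beta_plus_delta


print(round(mtp_speedup(0.46, 1.190), 2))  # M4 Pro: ~1.56x, close to the measured 1.52x
print(round(mtp_speedup(0.46, 1.85), 2))   # breakeven: ~1.00x
```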

@s-n-t

s-n-t commented May 8, 2026

Without the "--dtype float16" on the same hardware (confirms @atelepov numbers):

python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-4-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024 --mtp
...
==========
Prompt: 17 tokens, 24.655 tokens-per-sec
Generation: 1024 tokens, 15.337 tokens-per-sec
Peak memory: 15.895 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-4-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024
...
==========
Prompt: 17 tokens, 23.971 tokens-per-sec
Generation: 1024 tokens, 19.700 tokens-per-sec
Peak memory: 15.737 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-8-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024 --mtp
...
==========
Prompt: 17 tokens, 19.169 tokens-per-sec
Generation: 1024 tokens, 14.032 tokens-per-sec
Peak memory: 29.513 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-8-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024
...
==========
Prompt: 17 tokens, 18.781 tokens-per-sec
Generation: 1024 tokens, 11.231 tokens-per-sec
Peak memory: 29.351 GB

@atelepov

atelepov commented May 8, 2026

@AirRunner
Thank you very much.
Specifying the --dtype float16 parameter during conversion increases generation speed.

--dtype float16
MTP

--max-tokens 256 --temp 0
Prompt: 17 tokens, 5.502 tokens-per-sec
Generation: 256 tokens, 21.729 tokens-per-sec
Peak memory: 15.833 GB

--max-tokens 2048 --temp 0
Prompt: 17 tokens, 5.758 tokens-per-sec
Generation: 1628 tokens, 22.117 tokens-per-sec
Peak memory: 15.945 GB

NO MTP

--max-tokens 256 --temp 0
Prompt: 17 tokens, 5.501 tokens-per-sec
Generation: 256 tokens, 20.285 tokens-per-sec
Peak memory: 15.715 GB

--max-tokens 2048 --temp 0
Prompt: 17 tokens, 5.690 tokens-per-sec
Generation: 1628 tokens, 19.700 tokens-per-sec
Peak memory: 15.777 GB

--dtype DEFAULT

MTP

--max-tokens 256
Prompt: 17 tokens, 4.490 tokens-per-sec
Generation: 256 tokens, 18.103 tokens-per-sec
Peak memory: 15.834 GB

--max-tokens 256 --temp 0
Prompt: 17 tokens, 5.511 tokens-per-sec
Generation: 256 tokens, 17.529 tokens-per-sec
Peak memory: 15.833 GB

--max-tokens 256 --temp 0.6
Prompt: 17 tokens, 5.454 tokens-per-sec
Generation: 256 tokens, 17.555 tokens-per-sec
Peak memory: 15.833 GB

--max-tokens 2048
Prompt: 17 tokens, 6.118 tokens-per-sec
Generation: 1806 tokens, 15.917 tokens-per-sec
Peak memory: 15.957 GB

--max-tokens 2048 --temp 0
Prompt: 17 tokens, 5.797 tokens-per-sec
Generation: 1806 tokens, 15.092 tokens-per-sec
Peak memory: 15.961 GB

NO MTP

--max-tokens 256
Prompt: 17 tokens, 5.449 tokens-per-sec
Generation: 256 tokens, 11.092 tokens-per-sec
Peak memory: 15.715 GB

--max-tokens 256 --temp 0
Prompt: 17 tokens, 4.499 tokens-per-sec
Generation: 256 tokens, 19.318 tokens-per-sec
Peak memory: 15.715 GB

--max-tokens 256 --temp 0.6
Prompt: 17 tokens, 5.136 tokens-per-sec
Generation: 256 tokens, 19.694 tokens-per-sec
Peak memory: 15.715 GB

--max-tokens 2048
Prompt: 17 tokens, 5.194 tokens-per-sec
Generation: 1830 tokens, 12.277 tokens-per-sec
Peak memory: 15.796 GB

--max-tokens 2048 --temp 0
Prompt: 17 tokens, 5.472 tokens-per-sec
Generation: 1830 tokens, 17.354 tokens-per-sec
Peak memory: 15.797 GB

@AirRunner
Author

AirRunner commented May 8, 2026

@s-n-t Ok then your -22% (bf16) vs -2% (fp16) gap is almost entirely explained by BF16 emulation on M1.

M1 has no native BF16 GPU support, MTLDataType.bfloat requires Apple8+ (M1 is Apple7), confirmed by hw.optional.arm.FEAT_BF16: 0 on M1.
The quant_predicate keeps mtp.fc in full precision, so that layer runs through a software path on M1, adding disproportionate overhead to every MTP step.

So for M1 users: --dtype float16 is strongly recommended.

The 8-bit case is much less affected by dtype (+25% vs +27%) because the baseline is slower. Same absolute BF16 overhead on mtp.fc, smaller fraction of a slower step.

With fp16, the residual -2% on 4-bit likely reflects the β+δ compute overhead on M1's GPU architecture (see here).


@atelepov Nice, around +10% on M1 Max seems to be the expected range on this architecture.

@heykb

heykb commented May 9, 2026

M4 Pro 48GB. Qwen3.6-27B 4-bit
============================================================================
SUMMARY  (decode tok/s, first token excluded)
============================================================================
Condition                     Pooled        Mean+-SD    Accept     Prefill
----------------------------------------------------------------------------
  baseline  temp=0             15.26    15.3+-0.2        --        30.0
  MTP       temp=0             22.02    22.1+-1.3       46.9%      31.2
----------------------------------------------------------------------------

Speedup (pooled decode tok/s, MTP vs baseline at matching temperature):
  MTP       temp=0        1.444x  (alpha=0.469)  beta+delta=1.3038

Model  : /Users/a1-6/models/Qwen3.6-27B-mtp
Config : max_tokens=256, runs=3, warmup=1, prompts=8

@TomLucidor

TomLucidor commented May 9, 2026

@heykb did MTP manage to accelerate prefill by accident?

Declare u = mx.random.uniform() immediately before its first use (mx.eval) rather than before the unrelated _step_backbone call.
@crazyi

crazyi commented May 10, 2026

@s-n-t Ok then your -22% (bf16) vs -2% (fp16) gap is almost entirely explained by BF16 emulation on M1.

M1 has no native BF16 GPU support, MTLDataType.bfloat requires Apple8+ (M1 is Apple7), confirmed by hw.optional.arm.FEAT_BF16: 0 on M1. The quant_predicate keeps mtp.fc in full precision, so that layer runs through a software path on M1, adding disproportionate overhead to every MTP step.

So for M1 users: --dtype float16 is strongly recommended.

The 8-bit case is much less affected by dtype (+25% vs +27%) because the baseline is slower. Same absolute BF16 overhead on mtp.fc, smaller fraction of a slower step.

With fp16, the residual -2% on 4-bit likely reflects the β+δ compute overhead on M1's GPU architecture (see here).

@atelepov Nice, around +10% on M1 Max seems to be the expected range on this architecture.

I found that hw.optional.arm.FEAT_BF16: 1 on the M2 chip, so BF16 is supported there. If the M2 chip supports BF16, why does converting to FP16 still result in performance improvements?
jundot/omlx#604
Why the M2 is more advanced than it seemed

@deepsweet

If the M2 chip supports BF16, why does converting to FP16 still result in performance improvements?

See my detailed benchmark and conclusions.
