
[QUARK-403] Add MiniMax-2.1 support #237

Open
thpereir wants to merge 1 commit into ROCm:main from thpereir:thpereir/minimax21_mxfp4

Conversation

@thpereir
Contributor

Motivation

Add support for MiniMax-2.1 to ATOM.

Technical Details

MiniMax-2.1 uses a sigmoid on the expert routing scores instead of plain top-k selection.
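A rough NumPy sketch of what sigmoid-based top-k expert routing looks like (illustrative only, not ATOM's code; the function name is made up, and the detail that the per-expert bias affects selection but not the final routing weights is an assumption based on DeepSeek-style routers):

```python
import numpy as np

def sigmoid_route(router_logits, expert_bias, top_k=8):
    """Illustrative sigmoid router: 256 experts, top-8, per-expert bias.

    Scores come from sigmoid(logits) in fp32; the routing bias shifts which
    experts are *selected*, but the final routing weights are taken from the
    unbiased sigmoid scores and renormalized (assumed DeepSeek-style scheme).
    """
    scores = 1.0 / (1.0 + np.exp(-router_logits.astype(np.float32)))
    biased = scores + expert_bias                     # bias used for selection only
    topk_idx = np.argsort(-biased, axis=-1)[..., :top_k]
    topk_w = np.take_along_axis(scores, topk_idx, axis=-1)
    topk_w = topk_w / topk_w.sum(axis=-1, keepdims=True)  # renormalize over top-k
    return topk_idx, topk_w
```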

Test Plan

Run server:

python -m atom.entrypoints.openai_server \
    --model /scratch/models/MiniMax-M2.1-MXFP4/ \
    -tp 4 \
    --kv_cache_dtype fp8 \
    --max-model-len 32768 \
    --gpu-memory-utilization 0.90

Run lm-eval:

lm_eval --model local-completions \
    --model_args model=/scratch/models/MiniMax-M2.1-MXFP4/,base_url=http://localhost:8000/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False \
    --tasks gsm8k \
    --num_fewshot 5

Test Result

Results are below what we obtained on vLLM, so we need to debug further.

|Tasks|Version|     Filter     |n-shot|  Metric   |Value |Stderr |
|-----|------:|----------------|-----:|-----------|-----:|------:|
|gsm8k|      3|flexible-extract|     5|exact_match|0.2972|±0.0126|
|     |       |strict-match    |     5|exact_match|0.2585|±0.0121|

Submission Checklist

@thpereir thpereir force-pushed the thpereir/minimax21_mxfp4 branch from 0c962ff to e2a847f Compare February 26, 2026 16:58
@thpereir
Contributor Author

@valarLip can you please take a look at this one?

@thpereir thpereir force-pushed the thpereir/minimax21_mxfp4 branch from e2a847f to 881b16f Compare March 12, 2026 22:18
@thpereir
Contributor Author

Found a big issue with this model:
q_norm and k_norm were initialized with per-TP-rank sizes ([768] and [128] at TP=8) instead of the full checkpoint sizes ([6144] and [1024]). The weight loader silently skipped loading them due to the shape mismatch, leaving both norms with all-ones weights regardless of calibration — effectively disabling qk_norm entirely across all 62 layers.

Fix: weights are now sized to match the checkpoint (total_num_heads * head_dim, replicated across TP ranks). The forward pass computes global variance via an all-reduce of per-rank sum-of-squares, then each rank multiplies by its contiguous head-slice of the weight vector.
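The fix described above can be sketched as follows. This is a minimal NumPy illustration (function and variable names are hypothetical, and the all-reduce is simulated with a Python sum over simulated TP ranks); the key point is that the variance is computed over the global hidden dimension while each rank only multiplies by its own contiguous slice of the full-size weight:

```python
import numpy as np

def tp_rmsnorm(x_shards, w_shards, eps=1e-6):
    """Illustrative TP-distributed global RMSNorm.

    Each rank holds a contiguous head-slice of the activations and the
    matching slice of the full checkpoint-sized weight vector. The variance
    must be global, so ranks exchange per-rank sums of squares (an
    all-reduce, simulated here by summing over the shard list).
    """
    total_dim = sum(s.shape[-1] for s in x_shards)
    # "all-reduce" of per-rank sum-of-squares -> global sum of squares
    global_sq = sum((s.astype(np.float64) ** 2).sum(axis=-1) for s in x_shards)
    inv_rms = 1.0 / np.sqrt(global_sq / total_dim + eps)
    # each rank normalizes its shard and applies its slice of the weight
    return [s * inv_rms[..., None] * w for s, w in zip(x_shards, w_shards)]
```

Concatenating the per-rank outputs reproduces a single-rank RMSNorm over the full hidden dimension, which is why the all-ones weights (from the silently skipped load) disabled the norm entirely rather than crashing.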

After the fix, this is the GSM8K result:

|Tasks|Version|     Filter     |n-shot|  Metric   |Value |Stderr |
|-----|------:|----------------|-----:|-----------|-----:|------:|
|gsm8k|      3|flexible-extract|     5|exact_match|0.8878|±0.0087|
|     |       |strict-match    |     5|exact_match|0.8863|±0.0087|

@thpereir thpereir force-pushed the thpereir/minimax21_mxfp4 branch 2 times, most recently from 88a25cd to e11e392 Compare March 12, 2026 22:39
@haoyangli0109
Contributor

haoyangli0109 commented Mar 13, 2026

> Found a big issue with this model: q_norm and k_norm were initialized with per-TP-rank sizes ([768] and [128] at TP=8) instead of the full checkpoint sizes ([6144] and [1024]). The weight loader silently skipped loading them due to the shape mismatch, leaving both norms with all-ones weights regardless of calibration — effectively disabling qk_norm entirely across all 62 layers.
>
> Fix: weights are now sized to match the checkpoint (total_num_heads * head_dim, replicated across TP ranks). The forward pass computes global variance via an all-reduce of per-rank sum-of-squares, then each rank multiplies by its contiguous head-slice of the weight vector.
>
> After the fix, this is the GSM8K result:
>
> |Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr |
> |-----|------:|----------------|-----:|-----------|---|-----:|---|------:|
> |gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8878|±  |0.0087 |
> |     |       |strict-match    |     5|exact_match|↑  |0.8863|±  |0.0087 |

Thank you for your great work.
Actually, this value is still not high enough. Could you try a comparison accuracy test using vLLM? That is what we have always hoped for.
You can refer to the accuracy results at this link:
https://huggingface.co/amd/MiniMax-M2.1-MXFP4
We also tested GSM8K from this link, and the score reached 0.9+.

@thpereir thpereir force-pushed the thpereir/minimax21_mxfp4 branch from e11e392 to 6bb2ea5 Compare March 16, 2026 20:06
@ZhangLirong-amd
Contributor

@thpereir Hi, I tested your code for MiniMax-M2.1 and MiniMax-M2.1-MXFP4, and the accuracy is lower.

|Tasks|Version|     Filter     |n-shot|  Metric   |Value|Stderr |
|-----|------:|----------------|-----:|-----------|----:|------:|
|gsm8k|      3|flexible-extract|     5|exact_match| 0.72|±0.0641|
|     |       |strict-match    |     5|exact_match| 0.70|±0.0655|

Can you help check?

@lihaoyang-amd

lihaoyang-amd commented Mar 17, 2026

> @thpereir Hi, I tested your code for MiniMax-M2.1 and MiniMax-M2.1-MXFP4, and the accuracy is lower.
>
> |Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr |
> |-----|------:|----------------|-----:|-----------|---|----:|---|------:|
> |gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.72|±  |0.0641 |
> |     |       |strict-match    |     5|exact_match|↑  | 0.70|±  |0.0655 |
>
> Can you help check?

Hi Lirong, could you post the command you used?

@ZhangLirong-amd
Contributor

ZhangLirong-amd commented Mar 17, 2026

python3 -m atom.entrypoints.openai_server \
    --model /data/MiniMax-M2.5/ \
    -tp 2 \
    --port 5678 \
    --server-port 7777 \
    --kv_cache_dtype fp8

lm_eval --model local-completions \
    --model_args model=/data/MiniMax-M2.5/,base_url=http://localhost:7777/v1/completions,num_concurrent=8,max_retries=3,tokenized_requests=False \
    --tasks gsm8k \
    --limit 100

The FP4 model behaves the same, on both tp=2 and tp=4, and aiter is the latest version.

@thpereir
Contributor Author

thpereir commented Mar 17, 2026

@ZhangLirong-amd If you check our Quark MXFP4 model card here, you will see they are not using plain lm-eval to get their 0.9348 score! They are using vLLM's own gsm8k script, which does extra things: it uses stop tokens stop=["Question", "Assistant:", "<|separator|>"] (instead of until=["\n\n"] from lm-eval) and extracts answers with re.findall(r"\d+", ...)[-1]. The scores are not directly comparable!
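To illustrate the extraction difference (a hedged sketch; the exact vLLM gsm8k script is not reproduced here, only the regex it reportedly uses):

```python
import re

# Hypothetical model completion, for illustration only.
completion = "Step 1: 6 * 3 = 18.\n#### The answer is 42."

# vLLM-style extraction: take the last number appearing anywhere in the
# (stop-token-truncated) completion.
vllm_answer = re.findall(r"\d+", completion)[-1]

# lm-eval, by contrast, first truncates at its own stop sequence
# (until=["\n\n"]) and then applies its flexible-extract/strict-match
# filters, so the two pipelines can score the same completion differently.
print(vllm_answer)
```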

Regarding the original model, I don't see a GSM8K score on its model card. What command did you use to serve this model using vllm and compare to ATOM?

@ZhangLirong-amd
Contributor

VLLM_ROCM_USE_AITER=1 vllm serve /data/MiniMax-M2.5/ \
    -tp 2 \
    --enforce-eager \
    --port 1234 \
    --trust-remote-code

lm_eval --model local-completions \
    --model_args model=/data/MiniMax-M2.5/,base_url=http://localhost:1234/v1/completions,num_concurrent=8,max_retries=3,tokenized_requests=False \
    --tasks gsm8k \
    --limit 100

This is the script I tested on vLLM, and I got a 0.95 score. I used the same lm_eval script for both ATOM and vLLM.

I also checked the GSM8K results for ATOM, and there were indeed calculation errors in the math problems. I believe this issue needs to be addressed. So far, the only difference I can identify is in the Q,K values produced by MiniMaxText01RMSNormTP before attention.

@niuxjamd niuxjamd requested a review from valarLip March 19, 2026 07:44
Introduces atom/models/minimax_m2.py with full support for MiniMax-M2.1
under ATOM's TP-parallel serving stack.

Architecture support:
- MiniMaxM2Attention: GQA with rotary embeddings (partial rotary, 50%)
  and optional qk_norm (enabled in M2.1)
- MiniMaxM2SparseMoeBlock: 256-expert sparse MoE, top-8, sigmoid routing
  with per-expert routing bias (use_routing_bias)
- MiniMaxM2DecoderLayer / MiniMaxM2Model / MiniMaxM2ForCausalLM
- Packed QKV mapping (q_proj/k_proj/v_proj -> qkv_proj) for weight loading
- Pipeline-parallel (PP) support via PPMissingLayer / IntermediateTensors
- Expert weight mapping: w1/w2/w3 -> gate/down/up proj

qk_norm: correct TP-distributed global RMSNorm
- q_norm and k_norm weights sized to match checkpoint:
    q_norm: [total_num_heads * head_dim] = [6144]  (replicated across TP ranks)
    k_norm: [total_num_kv_heads * head_dim] = [1024] (replicated across TP ranks)
- Forward computes global variance via all-reduce of per-rank sum-of-squares,
  then each rank applies its contiguous head-slice of the weight vector
- Handles kv_heads < tp_size by using the full k_norm weight without slicing
- Wrong implementation used per-rank sizes ([768]/[128] with TP=8), causing
  weight loading to silently skip the norms (shape mismatch) and leaving
  them at all-ones, which reduced GSM8K from ~0.87 to 0.10

Fix MoE routing to match vLLM: use grouped_topk, fp32 gate weights and router
logits, fix FusedMoE has_bias default to False, fix SwiGLU branch condition.
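The weight-name remappings listed in the commit message (q_proj/k_proj/v_proj -> qkv_proj; expert w1/w2/w3 -> gate/down/up proj) could be sketched like this. The mappings themselves come from the commit message, but the function and constant names are hypothetical, and a real loader would also track shard ids when fusing q/k/v into the packed parameter:

```python
# Hypothetical sketch of the checkpoint-name remapping; ATOM's actual loader
# names and mechanics may differ.
STACKED_PARAMS = [
    # (fused target param, checkpoint shard name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
EXPERT_PARAMS = {
    "w1": "gate_proj",  # expert gate projection
    "w2": "down_proj",  # expert down projection
    "w3": "up_proj",    # expert up projection
}

def remap(name: str) -> str:
    """Map a checkpoint parameter name to its fused/renamed model parameter."""
    for target, shard, _shard_id in STACKED_PARAMS:
        if f".{shard}." in name:
            return name.replace(f".{shard}.", f".{target}.")
    for src, dst in EXPERT_PARAMS.items():
        if f".{src}." in name:
            return name.replace(f".{src}.", f".{dst}.")
    return name  # everything else loads under its original name
```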
@thpereir thpereir force-pushed the thpereir/minimax21_mxfp4 branch from 6bb2ea5 to 01e6c1b Compare March 19, 2026 23:24
@thpereir
Copy link
Contributor Author

@ZhangLirong-amd I fixed the issue for tp=2. Please take a look again.

The problem was still in weight loading for the q_norm/k_norm weights (per-layer RMSNorm): the sharding was incorrect.

I was only able to fix the issue for tp=2. When using tp=4 or tp=8, padding needs to be added to the shards, and this padding and the weight shuffle together are causing problems. I am still investigating this.

Now we have parity with vLLM in this case.

Server:

python -m atom.entrypoints.openai_server --model /mnt/dcgpuval/huggingface/amd/MiniMax-M2.1-MXFP4/ --trust-remote-code -tp 2

Run lm-eval:

lm_eval --model local-completions \
   --model_args model=/mnt/dcgpuval/huggingface/amd/MiniMax-M2.1-MXFP4/,base_url=http://localhost:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False \
   --tasks gsm8k \
   --num_fewshot 5

Results:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9378|±  |0.0067|
|     |       |strict-match    |     5|exact_match|↑  |0.9333|±  |0.0069|

Let's merge this, and then I will continue working on tp=4/tp=8.
