
[QUARK-493] Fix Qwen3 MXFP4 MoE weight loading with TP 4/8 #309

Open

thpereir wants to merge 1 commit into main from thpereir/qwen3_235

Conversation

@thpereir
Contributor

@thpereir thpereir commented Mar 11, 2026

When loading MXFP4-quantized Qwen3-235B MoE weights with tensor parallelism (TP=4 or TP=8), the _load_w13 and _load_w2 methods crashed with:

RuntimeError: start (1536) + length (256) exceeds dimension size (1536)

Root cause: The .narrow() calls used expert_data.shape (which is padded to MXFP4 block alignment) to compute the per-TP-rank shard offset and size. For MXFP4, expert_data is rounded up to the next block boundary, making shard_size larger than the actual loaded weight dimension. Multiplying this padded shard_size by tp_rank produced an out-of-bounds start index.
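For concreteness, here is a minimal sketch of the failure mode. All sizes below are hypothetical and chosen only to illustrate the arithmetic; they are not taken from the actual model:

```python
import torch

tp_size, tp_rank = 4, 3   # the last TP rank is where the overflow shows up
loaded_dim = 60           # hypothetical: actual size of the checkpoint shard axis
padded_shard = 16         # hypothetical: expert_data shard axis, rounded up to an MXFP4 block

loaded_weight = torch.zeros(loaded_dim)

# An offset computed from the padded per-rank size walks past the real data:
start = padded_shard * tp_rank          # 48
# loaded_weight.narrow(0, start, padded_shard)
# -> RuntimeError: start (48) + length (16) exceeds dimension size (60)
```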

Fix: Compute loaded_shard_size from the loaded_weight tensor's actual dimension (loaded_weight.shape[shard_dim] // tp_size) instead of the padded expert_data dimension. Use loaded_shard_size for the .narrow() on loaded_weight. When copying into expert_data, narrow expert_data to loaded_shard_size if it exceeds the loaded shard, ensuring only the valid (unpadded) region is written.

Applied to both _load_w13 (gate/up projection) and _load_w2 (down projection) code paths.
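A minimal sketch of the corrected narrowing logic described above. The helper name `load_shard` and the tensor sizes are illustrative assumptions; the real `_load_w13`/`_load_w2` methods carry additional bookkeeping:

```python
import torch

def load_shard(expert_data: torch.Tensor, loaded_weight: torch.Tensor,
               shard_dim: int, tp_rank: int, tp_size: int) -> None:
    # Shard size comes from the tensor actually being sliced,
    # not from the block-padded expert_data buffer.
    loaded_shard_size = loaded_weight.shape[shard_dim] // tp_size
    shard = loaded_weight.narrow(shard_dim, loaded_shard_size * tp_rank,
                                 loaded_shard_size)

    # expert_data may be padded past the valid region; narrow it so only
    # the unpadded part is written (the padding is left untouched).
    dst = expert_data
    if dst.shape[shard_dim] > loaded_shard_size:
        dst = dst.narrow(shard_dim, 0, loaded_shard_size)
    dst.copy_(shard)

# Hypothetical sizes: per-rank buffer padded to 16 rows, checkpoint has 60.
expert = torch.zeros(16, 8)
ckpt = torch.randn(60, 8)
load_shard(expert, ckpt, shard_dim=0, tp_rank=3, tp_size=4)
```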

Motivation

Technical Details

Test Plan

Server

python -m atom.entrypoints.openai_server --model /data/huggingface/models/amd/Qwen3-235B-A22B-Instruct-2507-MXFP4/ --trust-remote-code -tp 4 --kv_cache_dtype fp8

lm-eval

lm_eval --model local-completions --model_args "model=/data/huggingface/models/amd/Qwen3-235B-A22B-Instruct-2507-MXFP4/,base_url=http://0.0.0.0:8000/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=32" --tasks gsm8k --num_fewshot 5 --batch_size 1

Test Result

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9067 | ± 0.0080 |
|       |         | strict-match     | 5      | exact_match | 0.8961 | ± 0.0084 |

Submission Checklist

valarLip previously approved these changes Mar 12, 2026
@thpereir
Contributor Author

Looks like the same fix was already merged into main before this PR. I'll keep just the unit tests, if that's ok.

@haoyangli0109
Contributor

> Looks like the same fix was already merged into main before this PR. I'll keep just the unit tests, if that's ok.

Hi Thiago, you can click the re-run button to restart the test. The CI failure doesn't appear to be caused by this PR.

