mla nps fp8 mode get_meta_param avoid kv_tail_len < max_seqlen_q and fix nhead=128 reduce mgc#2319

Open
minmengdie wants to merge 3 commits into main from mmd/fix/mla_nps_fp8_128_2

Conversation

@minmengdie
Contributor

@minmengdie minmengdie commented Mar 18, 2026

Motivation

In MLA nps fp8 mode, make get_meta_param avoid cases where kv_tail_len < max_seqlen_q.

Technical Details

Test Plan

python3 op_tests/test_mla.py -b 1 --nhead 128,4 -d fp8 -kvd fp8 -splits 16 -c $(seq 4 200 | tr '\n' ' ')
Docker image: rocm/atom:nightly_202601190317
After starting a container from this image, run cd /app; the aiter and ATOM code is already there.
cd ATOM
Fetch the latest changes, then:
git checkout zlr/mtp_dp

pip install .
MORI_SHMEM_MODE=ISOLATION python3 -m atom.examples.simple_inference --model /data/DeepSeek-R1-0528/ -tp 8 --port 5678 --kv_cache_dtype fp8 --torch-profiler-dir ./log --method mtp --num-speculative-tokens 3 --block-size 10000 --gpu-memory-utilization 0.41 --enable-dp-attention --enable-expert-parallel

Test Result

(three screenshots attached)

Submission Checklist

@minmengdie minmengdie requested review from a team and Copilot March 18, 2026 05:54
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label       Tests
ci:sglang   SGLang integration tests
ci:atom     ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm     vLLM benchmark
ci:all      All of the above

Add labels via the sidebar or gh pr edit 2319 --add-label <label>

Contributor

Copilot AI left a comment

Pull request overview

This PR updates MLA decode heuristics for FP8/persistent-split (PS) mode to better bound num_kv_splits when KV tail is short relative to max_seqlen_q, and tweaks decode reduction granularity (mgc) for the nhead==128 FP8 case. It also adjusts the MLA op test to exercise an additional head configuration and a padded kv_indices layout.

Changes:

  • Add an FP8-specific upper bound on num_kv_splits in get_meta_param based on total_kv/bs - max_seqlen_q.
  • Change mgc selection to use mgc=32 for nhead==128 with FP8 KV buffer.
  • Update op_tests/test_mla.py to (a) pad kv_indices length and (b) allow (128, 4) in CLI --nhead choices.
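The diff itself is not part of this conversation, so the exact bound added to get_meta_param is not visible here. As a rough illustration of the idea only, the sketch below assumes the KV range is cut into equal ceil-sized chunks and lowers the split count until the last chunk holds at least max_seqlen_q tokens. The function names, the chunking rule, and the default_mgc argument are all hypothetical and are not taken from aiter/mla.py; only the value mgc=32 for nhead==128 with an FP8 KV buffer comes from the summary above.

```python
def kv_tail_len(kv_len: int, num_kv_splits: int) -> int:
    """Length of the last chunk when kv_len tokens are cut into
    num_kv_splits chunks of ceil(kv_len / num_kv_splits) tokens each.
    A non-positive result means the last chunk would be empty."""
    chunk = -(-kv_len // num_kv_splits)  # ceil division
    return kv_len - chunk * (num_kv_splits - 1)


def cap_num_kv_splits(num_kv_splits: int, total_kv: int, bs: int,
                      max_seqlen_q: int) -> int:
    """Hypothetical cap: reduce the split count until the tail chunk of an
    average-length sequence is at least max_seqlen_q tokens long."""
    avg_kv_len = total_kv // bs  # average KV length per batch entry
    splits = max(1, num_kv_splits)
    while splits > 1 and kv_tail_len(avg_kv_len, splits) < max_seqlen_q:
        splits -= 1
    return splits


def select_mgc(nhead: int, kv_is_fp8: bool, default_mgc: int) -> int:
    """mgc=32 for nhead==128 with an FP8 KV buffer (per the review summary);
    default_mgc stands in for whatever the other cases use (assumption)."""
    return 32 if (nhead == 128 and kv_is_fp8) else default_mgc


if __name__ == "__main__":
    # Short KV (10 tokens) with max_seqlen_q=4: 16 requested splits are reduced.
    print(cap_num_kv_splits(16, total_kv=10, bs=1, max_seqlen_q=4))
    # Long KV (200 tokens): the requested split count is kept.
    print(cap_num_kv_splits(16, total_kv=200, bs=1, max_seqlen_q=4))
```

Under these assumptions, a short KV length collapses toward a single split, which is the regime the test plan exercises by sweeping context lengths upward from 4.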

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File                   Description
aiter/mla.py           Adjusts FP8 split-count meta-parameter logic and mgc heuristic used by MLA decode stage2 reduction.
op_tests/test_mla.py   Modifies MLA test inputs (padded kv_indices) and expands allowed --nhead choices.
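For readers unfamiliar with a padded kv_indices layout, here is a minimal sketch of what such a test input might look like. It assumes a page size of 1, contiguous page ids, and a zero fill value; the helper name build_padded_kv_indices and the padding amount are hypothetical, and the actual change in op_tests/test_mla.py may differ.

```python
from itertools import accumulate


def build_padded_kv_indices(kv_lens, pad_len, pad_value=0):
    """Build (kv_indptr, kv_indices) for a batch, then append pad_len extra
    entries to kv_indices. Only kv_indices[:kv_indptr[-1]] is meaningful."""
    # Prefix sums: sequence i owns kv_indices[kv_indptr[i]:kv_indptr[i + 1]].
    kv_indptr = [0, *accumulate(kv_lens)]
    # Assumption: page size 1 and pages laid out contiguously.
    kv_indices = list(range(kv_indptr[-1]))
    # Padding makes the buffer longer than the valid range.
    kv_indices += [pad_value] * pad_len
    return kv_indptr, kv_indices


if __name__ == "__main__":
    indptr, indices = build_padded_kv_indices([3, 5], pad_len=4)
    print(indptr)        # valid range ends at indptr[-1]
    print(len(indices))  # buffer length includes the padding
```

With such an input, code that derives the total KV length from the length of kv_indices rather than from kv_indptr[-1] would see a larger value than the true one, which is the kind of mismatch a padded test input can expose.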

@minmengdie minmengdie force-pushed the mmd/fix/mla_nps_fp8_128_2 branch from 011c7f7 to e212b0a, March 18, 2026 06:36
@minmengdie minmengdie force-pushed the mmd/fix/mla_nps_fp8_128_2 branch from 502b9d0 to 9deb64b, March 18, 2026 06:56
@minmengdie minmengdie changed the title from "mla ps fp8 mode get_meta_param avoid kv_tail_len < max_seqlen_q and reduce mgc" to "mla ps fp8 mode get_meta_param avoid kv_tail_len < max_seqlen_q and fix nhead=128 reduce mgc", Mar 18, 2026
@minmengdie minmengdie requested review from shengnxu and valarLip March 18, 2026 10:23
@minmengdie minmengdie changed the title from "mla ps fp8 mode get_meta_param avoid kv_tail_len < max_seqlen_q and fix nhead=128 reduce mgc" to "mla nps fp8 mode get_meta_param avoid kv_tail_len < max_seqlen_q and fix nhead=128 reduce mgc", Mar 20, 2026