refactor: unify RMSNorm fusion with DualRMSNorm + master switch #342
Draft
Conversation
…MSNorm

Phase 1 of RMSNorm fusion unification:
- Move `_fuse_rmsnorm_fp4_quant`, `_fused_rms_fp8_group_quant`, and their fake-tensor versions from `deepseek_v2.py` to `model_ops/layernorm.py`
- Add public `fuse_rmsnorm_group_quant()` dispatcher
- Extend `RMSNorm` with quant_config-based auto-routing: the group-quant path is selected automatically from `quant_type` + `params_dtype`
- Add `transpose_scale` and `shuffle` constructor parameters
- Add `DualRMSNorm` class for fused dual-norm (q_a + kv_a in MLA)
Phase 2 of RMSNorm fusion unification:
- `DecoderLayer.forward()`: replace the 44-line `_fuse_rmsnorm_quant` bypass with a 7-line `self.input_layernorm()` call (`RMSNorm` handles routing internally)
- `DecoderLayer.__init__()`: pass `fused_quant` + `quant_config` to `RMSNorm`
- `MLAAttention`: add `DualRMSNorm` for non-Triton-GEMM QK-norm fusion, replacing the 15-line `_fuse_rmsnorm_quant(x2=kv_c)` with a 3-line call
- Triton GEMM QKV-projection fusion path unchanged (MLA-specific)
…dead code

Phase 3+4 of RMSNorm fusion unification:
- `envs.py`: add `ATOM_ENABLE_RMSNORM_QUANT_FUSION` master switch (default ON); old per-model vars fall back to the master switch when not explicitly set
- `deepseek_v2.py`: delete ~210 lines of private `_fuse_rmsnorm_quant` functions and fake-tensor versions (now in `layernorm.py`)
- Remove unused `self.fuse_rmsnorm_quant` flag from `DecoderLayer`
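The master-switch fallback described above (per-model vars defaulting to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset) can be sketched roughly as follows. `env_flag` and `rmsnorm_fusion_enabled` are hypothetical helper names for illustration, not the actual `envs.py` API:

```python
import os

def env_flag(name: str, default: bool) -> bool:
    # Read a boolean env var ("0"/"1"); fall back to `default` when unset.
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip() == "1"

def rmsnorm_fusion_enabled(per_model_var: str) -> bool:
    # Master switch defaults ON; a deprecated per-model var wins only
    # when it is set explicitly in the environment.
    master = env_flag("ATOM_ENABLE_RMSNORM_QUANT_FUSION", default=True)
    if os.environ.get(per_model_var) is None:
        return master
    return env_flag(per_model_var, default=master)
```

With these semantics, unsetting everything keeps fusion on, the master switch can disable it globally, and a per-model var set explicitly still overrides the master.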
…ights

`DualRMSNorm` previously created `weight1`/`weight2` parameters that didn't match checkpoint keys (`qk_layernorm.weight1` vs `q_a_layernorm.weight`). This caused weights to stay at their initial 1.0 values, producing garbage output (GSM8K 0.0).

Fix: `DualRMSNorm` now takes existing `RMSNorm` modules as constructor arguments and uses their `.weight` and `.eps` directly. No duplicate parameters, and checkpoint loading works correctly.
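The fix boils down to holding references to the existing norm modules instead of registering fresh parameters. A torch-free sketch of the pattern, with illustrative names (`RMSNormStub`, `DualRMSNormSketch`, `norm1`/`norm2` are not the actual class or attribute names):

```python
class RMSNormStub:
    # Minimal stand-in for an RMSNorm module: a learnable weight plus eps.
    def __init__(self, dim: int, eps: float = 1e-6):
        self.weight = [1.0] * dim  # overwritten by checkpoint loading
        self.eps = eps

class DualRMSNormSketch:
    # Wraps two existing norms and reuses their weights/eps directly,
    # so no new parameter names appear in the state dict.
    def __init__(self, norm1: RMSNormStub, norm2: RMSNormStub):
        self.norm1 = norm1
        self.norm2 = norm2

    def params(self):
        # The fused kernel reads the wrapped modules' weights, which are
        # the same objects checkpoint loading writes into.
        return (self.norm1.weight, self.norm1.eps,
                self.norm2.weight, self.norm2.eps)
```

Because the wrapper aliases `q_a_layernorm.weight` and `kv_a_layernorm.weight`, any checkpoint load into the original modules is immediately visible to the fused path.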
- Remove the `use_triton_gemm()` guard for `DualRMSNorm`: FP4 models now use fused QK-norm + quant via `DualRMSNorm` regardless of GEMM backend
- Update comments/docs to match: clarify that `input_layernorm` still requires the Triton GEMM while QK-norm `DualRMSNorm` does not
- Add `DualRMSNorm` to model_ops_guide (source table, normalization section, fused kernel chains table)
- Add `ATOM_ENABLE_RMSNORM_QUANT_FUSION` master switch to env var docs; mark old per-model vars as deprecated
- Fix `DualRMSNorm` docstring terminology consistency

Verified: DeepSeek BF16 GSM8K 0.957, DeepSeek MXFP4 GSM8K 0.948
Pull request overview
This PR consolidates DeepSeek-specific RMSNorm+quant fusion logic into shared atom/model_ops/layernorm.py abstractions (DualRMSNorm + shared group-quant dispatch) and introduces a repo-wide master env switch (ATOM_ENABLE_RMSNORM_QUANT_FUSION) to control all RMSNorm+quant fusion paths.
Changes:
- Add `DualRMSNorm` and move FP8/FP4 group-quant fused RMSNorm dispatch/wrappers into `layernorm.py`, extending `RMSNorm` quant routing.
- Refactor `deepseek_v2.py` to use `DualRMSNorm` / unified `RMSNorm.forward()` paths, removing large blocks of model-local fusion code.
- Add `ATOM_ENABLE_RMSNORM_QUANT_FUSION` and deprecate older per-model toggles with fallback behavior; update docs accordingly.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `atom/model_ops/layernorm.py` | Introduces DualRMSNorm, adds group-quant fusion dispatcher, and refactors RMSNorm fused-quant routing. |
| `atom/models/deepseek_v2.py` | Migrates DeepSeek MLA QK-norm and input-layernorm fusion to DualRMSNorm/RMSNorm APIs and deletes model-local wrappers. |
| `atom/utils/envs.py` | Adds ATOM_ENABLE_RMSNORM_QUANT_FUSION master switch and makes older DS/Llama switches deprecated overrides. |
| `docs/environment_variables.md` | Documents the new master switch and DS deprecated overrides. |
| `docs/model_ops_guide.md` | Adds DualRMSNorm to operator docs and fusion table. |
Comments suppressed due to low confidence (2)
docs/environment_variables.md:72
- The Llama RMSNorm fusion env var docs still show `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT` as non-deprecated with a fixed default of 1 (true), but `atom/utils/envs.py` now makes it a deprecated override that falls back to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset. Please update this section to reflect the new defaulting/deprecation behavior so users don't get surprised when the master switch disables Llama fusion too.
| Variable | Type | Default | Description |
|----------|------|---------|-------------|
| **ATOM_ENABLE_RMSNORM_QUANT_FUSION** | bool | 1 (true) | Master switch for all RMSNorm + quantization fusion paths (all models). |
| **ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION** | bool | (master switch) | *Deprecated.* Override for DeepSeek input layernorm fusion. Falls back to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset. |
| **ATOM_ENABLE_DS_QKNORM_QUANT_FUSION** | bool | (master switch) | *Deprecated.* Override for DeepSeek QK-norm fusion. Falls back to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset. |
### Qwen3-MoE style
| Variable | Type | Default | Description |
|----------|------|---------|-------------|
| **ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION** | bool | 0 (false) | If set to `1`, fuse QK norm, RoPE, and cache quantization into one kernel. **Enable this for Qwen3-MoE models for better performance.** |
### Llama-style
| Variable | Type | Default | Description |
|----------|------|---------|-------------|
| **ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT** | bool | 1 (true) | If set to `1`, use Triton kernel to fuse RMSNorm with quantization. |
| **ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT** | bool | 1 (true) | If set to `1`, use Triton kernel to fuse SiLU and mul with quantization in MLP module. |
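For example, assuming the fallback semantics documented above, fusion can be disabled globally and then re-enabled for a single model family via its (deprecated) per-model override:

```shell
# Master switch: disables every RMSNorm + quant fusion path
export ATOM_ENABLE_RMSNORM_QUANT_FUSION=0

# Deprecated per-model override still wins when set explicitly
export ATOM_ENABLE_DS_QKNORM_QUANT_FUSION=1
```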
atom/model_ops/layernorm.py:448
`RMSNorm.forward()` can return `(x, residual)` (plain/allreduce/pad paths) and `((x_quant, x_scale), residual)` (fused quant + residual), but the return type annotation only allows `Tensor | tuple[Tensor, Tensor]`. Update the annotation to reflect the actual nested-tuple variants (or refactor to a consistent return type) so callers and type-checkers don't get misled.
```python
@mark_trace(prefix="rmsnorm", torch_compile=True)
def forward(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
    x_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
```
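One way to widen the annotation as the reviewer suggests, sketched here without importing torch (the `torch.Tensor` entries are string forward references, and `ForwardReturn`/`QuantPair` are hypothetical alias names, not existing code):

```python
from typing import Union

# (x_quant, x_scale) pair produced by the fused-quant paths.
QuantPair = tuple["torch.Tensor", "torch.Tensor"]

ForwardReturn = Union[
    "torch.Tensor",                          # plain path, no residual
    tuple["torch.Tensor", "torch.Tensor"],   # (x, residual)
    tuple[QuantPair, "torch.Tensor"],        # ((x_quant, x_scale), residual)
]
```

The alias keeps the signature readable while letting type-checkers see the nested-tuple variant callers actually receive.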
Comment on lines 333 to 354
`RMSNorm` supports multiple forward paths depending on configuration flags:

| Condition | Kernel / Path | Returns |
|---|---|---|
| `x_pad_to_multiple > 0`, no residual | `fused_rmsnorm_pad_` (Triton `fused_add_rmsnorm_pad`) | Padded output |
| `x_pad_to_multiple > 0`, with residual | `fused_add_rmsnorm_pad_` | (output, residual) |
| `fused_allreduce=True` and `tp_size > 1` | `tensor_model_parallel_fused_allreduce_rmsnorm` | (output, residual) |
| `fused_quant=True` and `x_scale` provided | `fused_rms_fp8_per_tensor_static_quant` | (FP8 output, scale) |
| `fused_quant=True` and `per_1x32` | `fused_rms_mxfp4_quant` | (MXFP4 output, scale) |
| Default, no residual | `rmsnorm2d_fwd` | Output |
| Default, with residual | `rmsnorm2d_fwd_with_add` | (output, residual) |

Constructor parameters:

```python
RMSNorm(
    dim: int,
    eps: float = 1e-6,
    x_pad_to_multiple: int = 0,
    fused_allreduce: bool = False,
    fused_quant: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
)
```
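The routing in the table above can be illustrated as a plain-Python decision function. The flag and kernel names mirror the table, but `select_rmsnorm_path` itself is a hypothetical sketch, not the real `forward()`:

```python
from typing import Optional

def select_rmsnorm_path(
    fused_quant: bool,
    quant_type: Optional[str] = None,
    x_pad_to_multiple: int = 0,
    fused_allreduce: bool = False,
    tp_size: int = 1,
    has_residual: bool = False,
    has_x_scale: bool = False,
) -> str:
    # Priority mirrors the table: padding, fused allreduce, fused quant,
    # then the default rmsnorm2d kernels.
    if x_pad_to_multiple > 0:
        return "fused_add_rmsnorm_pad_" if has_residual else "fused_rmsnorm_pad_"
    if fused_allreduce and tp_size > 1:
        return "tensor_model_parallel_fused_allreduce_rmsnorm"
    if fused_quant and has_x_scale:
        return "fused_rms_fp8_per_tensor_static_quant"
    if fused_quant and quant_type == "per_1x32":
        return "fused_rms_mxfp4_quant"
    return "rmsnorm2d_fwd_with_add" if has_residual else "rmsnorm2d_fwd"
```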
```python
from atom.model_ops.base_attention import Attention
from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding
from atom.model_ops.layernorm import DualRMSNorm, LayerNorm, RMSNorm  # noqa: F401
```
Comment on lines 1213 to 1221
```python
self.fuse_qknorm_quant = True
# DualRMSNorm: fused dual norm + quant for both FP8 and FP4 paths
self.qk_layernorm = DualRMSNorm(
    self.q_a_layernorm,
    self.kv_a_layernorm,
    quant_config=quant_config,
    transpose_scale=True,
    shuffle=False,
)
```
| Operator | File | Kernels | Description |
|---|---|---|---|
| `MLAAttention` | `attention_mla.py` | `mla_decode_fwd`, `mla_prefill_fwd`, `concat_and_cache_mla`, `fused_qk_rope_concat_and_cache_mla` | Multi-head latent attention |
| `FusedMoE` | `moe.py` | `aiter.fused_moe.fused_moe`, `asm_moe` | Mixture of experts |
| `RMSNorm` | `layernorm.py` | `rmsnorm2d_fwd`, `rmsnorm2d_fwd_with_add`, `fused_add_rmsnorm_pad` | RMS normalization |
| `DualRMSNorm` | `layernorm.py` | `fuse_rmsnorm_group_quant` | Fused dual RMSNorm + quant (MLA q/kv norms) |
Summary
Unify RMSNorm + quantization fusion into a `DualRMSNorm` class with an `ATOM_ENABLE_RMSNORM_QUANT_FUSION` master switch.

Changes
- `layernorm.py`: add DualRMSNorm, move group-quant kernel wrappers
- `deepseek_v2.py`: migrate to RMSNorm/DualRMSNorm, remove 300+ lines
- `envs.py`: add ATOM_ENABLE_RMSNORM_QUANT_FUSION
- `docs/`: update environment_variables.md + model_ops_guide.md

5 commits, 5 files. No CI/workflow changes.

Split from #334 (model code part only).