
refactor: unify RMSNorm fusion with DualRMSNorm + master switch#342

Draft
valarLip wants to merge 5 commits into main from refactor/rmsnorm-fusion-v2

Conversation

@valarLip
Collaborator

Summary

Unify RMSNorm + quantization fusion into DualRMSNorm class with ATOM_ENABLE_RMSNORM_QUANT_FUSION master switch.

Changes

  • layernorm.py — add DualRMSNorm, move group-quant kernel wrappers
  • deepseek_v2.py — migrate to RMSNorm/DualRMSNorm, remove 300+ lines
  • envs.py — add ATOM_ENABLE_RMSNORM_QUANT_FUSION
  • docs/ — update environment_variables.md + model_ops_guide.md

5 commits, 5 files. No CI/workflow changes.

Split from #334 (model code part only).

…MSNorm

Phase 1 of RMSNorm fusion unification:
- Move _fuse_rmsnorm_fp4_quant, _fused_rms_fp8_group_quant, and their
  fake-tensor versions from deepseek_v2.py to model_ops/layernorm.py
- Add public fuse_rmsnorm_group_quant() dispatcher
- Extend RMSNorm with quant_config-based auto-routing: group-quant path
  selected automatically from quant_type + params_dtype
- Add transpose_scale and shuffle constructor parameters
- Add DualRMSNorm class for fused dual-norm (q_a + kv_a in MLA)
Phase 2 of RMSNorm fusion unification:
- DecoderLayer.forward(): replace 44-line _fuse_rmsnorm_quant bypass with
  7-line self.input_layernorm() call (RMSNorm handles routing internally)
- DecoderLayer.__init__(): pass fused_quant + quant_config to RMSNorm
- MLAAttention: add DualRMSNorm for non-triton-GEMM QK-norm fusion,
  replacing 15-line _fuse_rmsnorm_quant(x2=kv_c) with 3-line call
- Triton GEMM QKV-projection fusion path unchanged (MLA-specific)
…dead code

Phase 3+4 of RMSNorm fusion unification:
- envs.py: add ATOM_ENABLE_RMSNORM_QUANT_FUSION master switch (default ON)
  Old per-model vars fall back to the master switch when not explicitly set
- deepseek_v2.py: delete ~210 lines of private _fuse_rmsnorm_quant functions
  and fake-tensor versions (now in layernorm.py)
- Remove unused self.fuse_rmsnorm_quant flag from DecoderLayer
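
The fallback behavior described above can be sketched as follows (a minimal sketch with assumed helper names, not the actual `envs.py` code):

```python
import os

def env_bool(name: str, default: bool) -> bool:
    """Read a boolean env var, treating '1'/'true' as True."""
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() in ("1", "true")

def rmsnorm_fusion_enabled(per_model_var: str) -> bool:
    """Deprecated per-model switches still win when explicitly set;
    otherwise everything falls back to the master switch (default ON)."""
    master = env_bool("ATOM_ENABLE_RMSNORM_QUANT_FUSION", default=True)
    if per_model_var in os.environ:
        return env_bool(per_model_var, default=master)
    return master
```

With this shape, setting only `ATOM_ENABLE_RMSNORM_QUANT_FUSION=0` disables fusion everywhere, while an explicitly set per-model var still overrides it.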
…ights

DualRMSNorm previously created weight1/weight2 parameters that didn't
match checkpoint keys (qk_layernorm.weight1 vs q_a_layernorm.weight).
This caused weights to stay at initial 1.0 values, producing garbage
output (GSM8K 0.0).

Fix: DualRMSNorm now takes existing RMSNorm modules as constructor
arguments and uses their .weight and .eps directly. No duplicate
parameters, checkpoint loading works correctly.
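
The shape of the fix can be illustrated with a pure-Python sketch (stub classes only; the real modules are `torch.nn.Module` subclasses and the names here are stand-ins):

```python
class RMSNormStub:
    """Stand-in for the real RMSNorm module in this sketch."""
    def __init__(self, weight, eps=1e-6):
        self.weight = weight
        self.eps = eps

class DualRMSNormSketch:
    """Sketch of the fix: wrap two existing RMSNorm modules and reuse
    their parameters, so checkpoint keys like q_a_layernorm.weight
    still resolve. No new weight1/weight2 parameters are created."""
    def __init__(self, norm1, norm2, quant_config=None,
                 transpose_scale=False, shuffle=False):
        self.norm1 = norm1          # e.g. q_a_layernorm
        self.norm2 = norm2          # e.g. kv_a_layernorm
        self.quant_config = quant_config
        self.transpose_scale = transpose_scale
        self.shuffle = shuffle

    def weights_and_eps(self):
        # Reads the wrapped modules' tensors directly, never copies.
        return (self.norm1.weight, self.norm1.eps,
                self.norm2.weight, self.norm2.eps)
```

Because the wrapper only holds references, loading a checkpoint into `q_a_layernorm`/`kv_a_layernorm` is immediately visible to the fused path.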
- Remove use_triton_gemm() guard for DualRMSNorm: FP4 models now use
  fused QK-norm + quant via DualRMSNorm regardless of GEMM backend
- Update comments/docs to match: clarify input_layernorm still requires
  triton GEMM while QK-norm DualRMSNorm does not
- Add DualRMSNorm to model_ops_guide (source table, normalization
  section, fused kernel chains table)
- Add ATOM_ENABLE_RMSNORM_QUANT_FUSION master switch to env var docs,
  mark old per-model vars as deprecated
- Fix DualRMSNorm docstring terminology consistency

Verified: DeepSeek BF16 GSM8K 0.957, DeepSeek MXFP4 GSM8K 0.948
Copilot AI review requested due to automatic review settings March 16, 2026 05:58
@valarLip valarLip marked this pull request as draft March 16, 2026 05:59
Contributor

Copilot AI left a comment


Pull request overview

This PR consolidates DeepSeek-specific RMSNorm+quant fusion logic into shared atom/model_ops/layernorm.py abstractions (DualRMSNorm + shared group-quant dispatch) and introduces a repo-wide master env switch (ATOM_ENABLE_RMSNORM_QUANT_FUSION) to control all RMSNorm+quant fusion paths.

Changes:

  • Add DualRMSNorm and move FP8/FP4 group-quant fused RMSNorm dispatch/wrappers into layernorm.py, extending RMSNorm quant routing.
  • Refactor deepseek_v2.py to use DualRMSNorm / unified RMSNorm.forward() paths, removing large blocks of model-local fusion code.
  • Add ATOM_ENABLE_RMSNORM_QUANT_FUSION and deprecate older per-model toggles with fallback behavior; update docs accordingly.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

| File | Description |
|---|---|
| `atom/model_ops/layernorm.py` | Introduces DualRMSNorm, adds group-quant fusion dispatcher, and refactors RMSNorm fused-quant routing. |
| `atom/models/deepseek_v2.py` | Migrates DeepSeek MLA QK-norm and input-layernorm fusion to DualRMSNorm/RMSNorm APIs and deletes model-local wrappers. |
| `atom/utils/envs.py` | Adds ATOM_ENABLE_RMSNORM_QUANT_FUSION master switch and makes older DS/Llama switches deprecated overrides. |
| `docs/environment_variables.md` | Documents the new master switch and DS deprecated overrides. |
| `docs/model_ops_guide.md` | Adds DualRMSNorm to operator docs and fusion table. |
Comments suppressed due to low confidence (2)

docs/environment_variables.md:72

  • The Llama RMSNorm fusion env var docs still show ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT as non-deprecated with a fixed default of 1 (true), but atom/utils/envs.py now makes it a deprecated override that falls back to ATOM_ENABLE_RMSNORM_QUANT_FUSION when unset. Please update this section to reflect the new defaulting/deprecation behavior so users don’t get surprised when the master switch disables Llama fusion too.
| **ATOM_ENABLE_RMSNORM_QUANT_FUSION** | bool | 1 (true) | Master switch for all RMSNorm + quantization fusion paths (all models). |
| **ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION** | bool | (master switch) | *Deprecated.* Override for DeepSeek input layernorm fusion. Falls back to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset. |
| **ATOM_ENABLE_DS_QKNORM_QUANT_FUSION** | bool | (master switch) | *Deprecated.* Override for DeepSeek QK-norm fusion. Falls back to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset. |

### Qwen3-MoE style

| Variable | Type | Default | Description |
|----------|------|---------|-------------|
| **ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION** | bool | 0 (false) | If set to `1`, fuse QK norm, RoPE, and cache quantization into one kernel. **Enable this for Qwen3-MoE models for better performance.** |

### Llama-style

| Variable | Type | Default | Description |
|----------|------|---------|-------------|
| **ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT** | bool | 1 (true) | If set to `1`, use Triton kernel to fuse RMSNorm with quantization. |
| **ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT** | bool | 1 (true) | If set to `1`, use Triton kernel to fuse SiLU and mul with quantization in MLP module. |

atom/model_ops/layernorm.py:448

  • RMSNorm.forward() can return (x, residual) (plain/allreduce/pad paths) and ((x_quant, x_scale), residual) (fused quant + residual), but the return type annotation only allows Tensor | tuple[Tensor, Tensor]. Update the annotation to reflect the actual nested-tuple variants (or refactor to a consistent return type) so callers and type-checkers don't get misled.
```python
@mark_trace(prefix="rmsnorm", torch_compile=True)
def forward(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
    x_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
```


Comment on lines 333 to 354
`RMSNorm` supports multiple forward paths depending on configuration flags:

| Condition | Kernel / Path | Returns |
|---|---|---|
| `x_pad_to_multiple > 0`, no residual | `fused_rmsnorm_pad_` (Triton `fused_add_rmsnorm_pad`) | Padded output |
| `x_pad_to_multiple > 0`, with residual | `fused_add_rmsnorm_pad_` | (output, residual) |
| `fused_allreduce=True` and `tp_size > 1` | `tensor_model_parallel_fused_allreduce_rmsnorm` | (output, residual) |
| `fused_quant=True` and `x_scale` provided | `fused_rms_fp8_per_tensor_static_quant` | (FP8 output, scale) |
| `fused_quant=True` and `per_1x32` | `fused_rms_mxfp4_quant` | (MXFP4 output, scale) |
| Default, no residual | `rmsnorm2d_fwd` | Output |
| Default, with residual | `rmsnorm2d_fwd_with_add` | (output, residual) |
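
A runnable illustration of this routing (pure-Python sketch: the flag names follow the table, but the real kernels are replaced by placeholders):

```python
class RMSNormRoutingSketch:
    """Mimics the return shapes RMSNorm.forward() produces per the table."""

    def __init__(self, fused_quant: bool = False):
        self.fused_quant = fused_quant

    def forward(self, x, residual=None, x_scale=None):
        # Real code calls rmsnorm2d_fwd / rmsnorm2d_fwd_with_add here.
        y = x
        if self.fused_quant and x_scale is not None:
            quant = (y, x_scale)  # stands in for (FP8 output, scale)
            return (quant, residual) if residual is not None else quant
        return (y, residual) if residual is not None else y
```

For example, `RMSNormRoutingSketch(fused_quant=True).forward(x, x_scale=s)` returns a `(quant, scale)` pair, while the default path returns the tensor alone (or `(output, residual)` when a residual is supplied).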

Constructor parameters:
```python
RMSNorm(
    dim: int,
    eps: float = 1e-6,
    x_pad_to_multiple: int = 0,
    fused_allreduce: bool = False,
    fused_quant: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
)
```

```diff
 from atom.model_ops.base_attention import Attention
 from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding
-from atom.model_ops.layernorm import LayerNorm, RMSNorm
+from atom.model_ops.layernorm import DualRMSNorm, LayerNorm, RMSNorm  # noqa: F401
```
Comment on lines 1213 to +1221
```python
self.fuse_qknorm_quant = True
# DualRMSNorm: fused dual norm + quant for both FP8 and FP4 paths
self.qk_layernorm = DualRMSNorm(
    self.q_a_layernorm,
    self.kv_a_layernorm,
    quant_config=quant_config,
    transpose_scale=True,
    shuffle=False,
)
```
| `MLAAttention` | `attention_mla.py` | `mla_decode_fwd`, `mla_prefill_fwd`, `concat_and_cache_mla`, `fused_qk_rope_concat_and_cache_mla` | Multi-head latent attention |
| `FusedMoE` | `moe.py` | `aiter.fused_moe.fused_moe`, `asm_moe` | Mixture of experts |
| `RMSNorm` | `layernorm.py` | `rmsnorm2d_fwd`, `rmsnorm2d_fwd_with_add`, `fused_add_rmsnorm_pad` | RMS normalization |
| `DualRMSNorm` | `layernorm.py` | `fuse_rmsnorm_group_quant` | Fused dual RMSNorm + quant (MLA q/kv norms) |
