From e8900aa2b37f18ea5b0f338d371004c3c99f8123 Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Sun, 15 Mar 2026 05:13:04 +0000 Subject: [PATCH 1/5] refactor: move group-quant kernel wrappers to layernorm.py, add DualRMSNorm Phase 1 of RMSNorm fusion unification: - Move _fuse_rmsnorm_fp4_quant, _fused_rms_fp8_group_quant, and their fake-tensor versions from deepseek_v2.py to model_ops/layernorm.py - Add public fuse_rmsnorm_group_quant() dispatcher - Extend RMSNorm with quant_config-based auto-routing: group-quant path selected automatically from quant_type + params_dtype - Add transpose_scale and shuffle constructor parameters - Add DualRMSNorm class for fused dual-norm (q_a + kv_a in MLA) --- atom/model_ops/layernorm.py | 429 +++++++++++++++++++++++++++++++----- 1 file changed, 371 insertions(+), 58 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 7f556fcc5..6c2501255 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -25,6 +25,12 @@ from aiter import ( QuantType, ) +from atom.model_ops.utils import MXFP4_QUANT_BLOCK_SIZE + +try: + from aiter import dtypes as _aiter_dtypes +except ImportError: + _aiter_dtypes = None def silu(input: Tensor, inplace: bool = False) -> Tensor: @@ -169,6 +175,216 @@ def mxfp4_rms_quant_fuse( return x_quant, x_scale, residual_out +# --------------------------------------------------------------------------- +# Group-quant fused RMSNorm kernels (moved from deepseek_v2.py) +# --------------------------------------------------------------------------- + + +def _fuse_rmsnorm_fp4_quant_fake( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + shuffle: bool = True, + scale_shuffle_padding: bool = True, + output_unquantized_inp1: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + m, n1 = x1.shape + n2 = x2.shape[1] if x2 is not None else 0 + + out1_quantized = torch.empty((m, n1 // 2), dtype=torch.uint8, device=x1.device) + scale_n_valid = (n1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + scale_m = ((m + 255) // 256) * 256 + scale_n = ((scale_n_valid + 7) // 8) * 8 + out1_bs = torch.empty((scale_m, scale_n), dtype=torch.uint8, device=x1.device) + + out2 = None + if x2 is not None: + out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) + out_res1 = None + if res1 is not None: + out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) + out1_unquantized = None + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + +def _fused_rms_fp8_group_quant_fake( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + dtype_quant: "torch.dtype | None" = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if dtype_quant is None and _aiter_dtypes is not None: + dtype_quant = _aiter_dtypes.fp8 + m, n1 = x1.shape + out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=x1.device) + out1_bs = torch.empty( + (m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=x1.device + ) + if transpose_scale: + out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape) + out1_unquantized = None + if output_unquantized_inp1: + out1_unquantized = torch.empty_like(x1) + out2 = None + if x2 is not None: + _, n2 = x2.shape + out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) + out_res1 = None + if res1 is not None: + out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + +@torch_compile_guard(gen_fake=_fuse_rmsnorm_fp4_quant_fake) +def _fuse_rmsnorm_fp4_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + shuffle: bool = True, + scale_shuffle_padding: bool = True, + output_unquantized_inp1: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant + + m = x1.shape[0] + shuffle_bool = shuffle and (m >= MXFP4_QUANT_BLOCK_SIZE) + + (out1_quantized, out1_bs), _out1_unquantized, out2, out_res1 = ( + fused_rms_mxfp4_quant( + x1=x1, + x1_weight=x1_weight, + x1_epsilon=x1_epsilon, + x2=x2, + x2_weight=x2_weight, + x2_epsilon=0.0 if x2_epsilon is None else x2_epsilon, + res1=res1, + shuffle=shuffle_bool, + scale_shuffle_padding=scale_shuffle_padding, + output_unquantized_inp1=output_unquantized_inp1, + ) + ) + out1_unquantized = None + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + +@torch_compile_guard(gen_fake=_fused_rms_fp8_group_quant_fake) +def _fused_rms_fp8_group_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + dtype_quant: "torch.dtype | None" = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + if dtype_quant is None and _aiter_dtypes is not None: + dtype_quant = _aiter_dtypes.fp8 + + (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = ( + fused_rms_fp8_group_quant( + x1, + x1_weight, + x1_epsilon, + x2, + x2_weight, + x2_epsilon, + group_size, + dtype_quant, + res1, + output_unquantized_inp1, + transpose_scale, + ) + ) + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + +def fuse_rmsnorm_group_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + dtype_quant: "torch.dtype | None" = None, + shuffle: bool = True, + scale_shuffle_padding: bool = False, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = False, +): + """Dispatch fused RMSNorm + group quantization to the correct kernel. + + Supports FP8 per-group and MXFP4 quantization. Optionally normalizes a + second tensor (x2) and fuses residual-add (res1) in the same kernel call. + + Returns: (out1_quantized, out1_scale), out1_unquantized, out2_normed, residual_out + """ + if _aiter_dtypes is None: + raise RuntimeError("aiter.dtypes not available") + + if dtype_quant is None: + dtype_quant = _aiter_dtypes.fp8 + + if dtype_quant == _aiter_dtypes.fp4x2: + out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = ( + _fuse_rmsnorm_fp4_quant( + x1, + x1_weight, + x1_epsilon, + x2, + x2_weight, + x2_epsilon, + res1, + shuffle, + scale_shuffle_padding, + output_unquantized_inp1, + ) + ) + elif dtype_quant == _aiter_dtypes.fp8: + out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = ( + _fused_rms_fp8_group_quant( + x1, + x1_weight, + x1_epsilon, + x2, + x2_weight, + x2_epsilon, + res1, + dtype_quant, + group_size, + output_unquantized_inp1, + transpose_scale, + ) + ) + else: + raise ValueError( + f"No fused rmsnorm quant kernel available for quant dtype: {dtype_quant}." + ) + return (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 + + class RMSNorm(nn.Module): def __init__( self, @@ -178,6 +394,8 @@ def __init__( fused_allreduce: bool = False, fused_quant: bool = False, quant_config: Optional[QuantizationConfig] = None, + transpose_scale: bool = False, + shuffle: bool = True, ) -> None: super().__init__() self.dim = dim @@ -187,6 +405,8 @@ def __init__( self.fused_allreduce = fused_allreduce self.use_fused_quant = fused_quant self.tp_size = get_tensor_model_parallel_world_size() + self.transpose_scale = transpose_scale + self.shuffle = shuffle layer_quant_config = ( LayerQuantConfig() @@ -198,6 +418,27 @@ def __init__( self.quant_type = quant_type self.params_dtype = params_dtype + # Determine the fused quant path based on quant_config: + # - "group": FP8 per-group (per_1x128/per_Token) or MXFP4 group (fp4x2) + # - "per_tensor": FP8 per-tensor static (requires x_scale at forward time) + # - "simple_mxfp4": MXFP4 simple (per_1x32 without group quant) + # - None: no quantization fusion + self._quant_path = None + if fused_quant and _aiter_dtypes is not None: + if params_dtype == _aiter_dtypes.fp8 and quant_type in ( + QuantType.per_1x128, + QuantType.per_Token, + ): + self._quant_path = "group" + elif ( + params_dtype == _aiter_dtypes.fp8 and quant_type == QuantType.per_Tensor + ): + self._quant_path = "per_tensor" + elif params_dtype == _aiter_dtypes.fp4x2: + self._quant_path = "group" + elif quant_type == QuantType.per_1x32: + self._quant_path = "simple_mxfp4" + @mark_trace(prefix="rmsnorm", torch_compile=True) def forward( self, @@ -228,66 +469,138 @@ def forward( self.eps, ) return x, residual - else: - if x_scale is not None and self.use_fused_quant: - from aiter.ops.triton.fused_fp8_quant import ( - fused_rms_fp8_per_tensor_static_quant, + + # --- Fused quant paths (dispatched by _quant_path) --- + if self._quant_path == "group": + # FP8 per-group or MXFP4 group quantization + (x_quant, x_scale), _, _, res_out = fuse_rmsnorm_group_quant( + x, + self.weight, + self.eps, + res1=residual, + dtype_quant=self.params_dtype, + shuffle=self.shuffle, + scale_shuffle_padding=self.shuffle, + group_size=128, + transpose_scale=self.transpose_scale, + ) + if residual is None: + return (x_quant, x_scale) + else: + return (x_quant, x_scale), res_out + + if self._quant_path == "per_tensor" and x_scale is not None: + from aiter.ops.triton.fused_fp8_quant import ( + fused_rms_fp8_per_tensor_static_quant, + ) + import aiter as rocm_aiter + + rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8 + if residual is None: + x, _, _, _ = fused_rms_fp8_per_tensor_static_quant( + x, + self.weight, + self.eps, + x_scale, + None, + None, + self.eps, + dtype_quant=rocm_aiter_fp8_dtype, + res1=None, ) - import aiter as rocm_aiter - - rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8 - - # static FP8 quantization - if residual is None: - x, _, _, _ = fused_rms_fp8_per_tensor_static_quant( - x, - self.weight, - self.eps, - x_scale, - None, - None, - self.eps, - dtype_quant=rocm_aiter_fp8_dtype, - res1=None, - ) - return (x, x_scale) - else: - x, _, _, residual = fused_rms_fp8_per_tensor_static_quant( - x, - self.weight, - self.eps, - x_scale, - None, - None, - self.eps, - dtype_quant=rocm_aiter_fp8_dtype, - res1=residual, - ) - return (x, x_scale), residual - elif self.use_fused_quant and ( - x_scale is None and self.quant_type.value == QuantType.per_1x32.value - ): - if residual is None: - x, x_scale, _ = mxfp4_rms_quant_fuse( - x, self.weight, self.eps, shuffle=True - ) - return x, x_scale - else: - x, x_scale, residual = mxfp4_rms_quant_fuse( - x, self.weight, self.eps, shuffle=True, res1=residual - ) - return (x, x_scale), residual + return (x, x_scale) + else: + x, _, _, residual = fused_rms_fp8_per_tensor_static_quant( + x, + self.weight, + self.eps, + x_scale, + None, + None, + self.eps, + dtype_quant=rocm_aiter_fp8_dtype, + res1=residual, + ) + return (x, x_scale), residual + + if self._quant_path == "simple_mxfp4": + if residual is None: + x, x_scale, _ = mxfp4_rms_quant_fuse( + x, self.weight, self.eps, shuffle=self.shuffle + ) + return x, x_scale else: - if residual is None: - # return rmsnorm2d_fwd(x, self.weight, self.eps).view(ori_shape) - x = rmsnorm2d_fwd_(x, self.weight, self.eps, self.dim) - return x - else: - # return self.add_rms_forward(x, residual) - x, residual = rmsnorm2d_fwd_with_add_( - x, self.weight, residual, self.eps, self.dim - ) - return x, residual + x, x_scale, residual = mxfp4_rms_quant_fuse( + x, self.weight, self.eps, shuffle=self.shuffle, res1=residual + ) + return (x, x_scale), residual + + # --- Plain RMSNorm (no fusion) --- + if residual is None: + x = rmsnorm2d_fwd_(x, self.weight, self.eps, self.dim) + return x + else: + x, residual = rmsnorm2d_fwd_with_add_( + x, self.weight, residual, self.eps, self.dim + ) + return x, residual + + +class DualRMSNorm(nn.Module): + """Fused dual RMSNorm + quantization for two inputs (e.g. q_c + kv_c). + + Uses a single AITER kernel call to normalize both tensors and quantize + the first, reducing kernel launch overhead vs two separate RMSNorm calls. + Typically used in MLA attention for the q_a + kv_a layernorms. + """ + + def __init__( + self, + dim1: int, + dim2: int, + eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + transpose_scale: bool = False, + shuffle: bool = False, + ) -> None: + super().__init__() + self.dim1 = dim1 + self.dim2 = dim2 + self.eps = eps + self.weight1 = nn.Parameter(torch.ones(dim1)) + self.weight2 = nn.Parameter(torch.ones(dim2)) + self.transpose_scale = transpose_scale + self.shuffle = shuffle + + layer_quant_config = ( + LayerQuantConfig() + if quant_config is None + else quant_config.global_quant_config + ) + self.params_dtype = layer_quant_config["quant_dtype"] + + def forward( + self, + x1: torch.Tensor, + x2: torch.Tensor, + ) -> tuple: + """Normalize x1 and x2, quantize x1. + + Returns: (x1_quant, x1_scale), x2_normed + """ + (x1_quant, x1_scale), _, x2_normed, _ = fuse_rmsnorm_group_quant( + x1, + self.weight1, + self.eps, + x2=x2, + x2_weight=self.weight2, + x2_epsilon=self.eps, + dtype_quant=self.params_dtype, + shuffle=self.shuffle, + group_size=128, + transpose_scale=self.transpose_scale, + ) + return (x1_quant, x1_scale), x2_normed class RMSNormGated(nn.Module): From 39bea884d7ce15c6e6d3cf71f3fa0705423d2783 Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Sun, 15 Mar 2026 05:15:53 +0000 Subject: [PATCH 2/5] refactor: migrate DecoderLayer + MLAAttention to use RMSNorm/DualRMSNorm Phase 2 of RMSNorm fusion unification: - DecoderLayer.forward(): replace 44-line _fuse_rmsnorm_quant bypass with 7-line self.input_layernorm() call (RMSNorm handles routing internally) - DecoderLayer.__init__(): pass fused_quant + quant_config to RMSNorm - MLAAttention: add DualRMSNorm for non-triton-GEMM QK-norm fusion, replacing 15-line _fuse_rmsnorm_quant(x2=kv_c) with 3-line call - Triton GEMM QKV-projection fusion path unchanged (MLA-specific) --- atom/models/deepseek_v2.py | 90 ++++++++++---------------------------- 1 file changed, 22 insertions(+), 68 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index d1da9f052..849f781e7 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -61,7 +61,7 @@ from atom.model_ops.attention_mla import MLAModules, is_rocm_aiter_fp4bmm_enabled from atom.model_ops.base_attention import Attention from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding -from atom.model_ops.layernorm import LayerNorm, RMSNorm +from atom.model_ops.layernorm import DualRMSNorm, LayerNorm, RMSNorm # noqa: F401 from atom.model_ops.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, @@ -1430,6 +1430,15 @@ def __init__( ): self.quant_dtype = layer_quant_dtype self.fuse_qknorm_quant = True + # DualRMSNorm for non-triton-GEMM path (fused dual norm + quant) + self.qk_layernorm = DualRMSNorm( + self.q_lora_rank, + self.kv_lora_rank, + eps=config.rms_norm_eps, + quant_config=quant_config, + transpose_scale=True, + shuffle=False, + ) def forward( self, @@ -1477,24 +1486,8 @@ def forward( if self.fuse_qknorm_quant: ( (hidden_states_or_q_c, hidden_states_or_q_c_scale), - _, kv_c_normed, - _, - ) = _fuse_rmsnorm_quant( - q_c, - self.q_a_layernorm.weight, - self.q_a_layernorm.eps, - kv_c, - self.kv_a_layernorm.weight, - self.kv_a_layernorm.eps, - None, - dtype_quant=self.quant_dtype, - shuffle=False, - scale_shuffle_padding=False, - group_size=128, - output_unquantized_inp1=False, - transpose_scale=True, - ) + ) = self.qk_layernorm(q_c, kv_c) else: hidden_states_or_q_c = self.q_a_layernorm(q_c) else: @@ -1617,6 +1610,10 @@ def __init__( fused_allreduce=self.fuse_ar_input_norm and self.layer_idx > 0 and not is_mtp_block, + fused_quant=self.fuse_input_norm_quant, + quant_config=quant_config, + transpose_scale=True, + shuffle=True, ) self.post_attention_layernorm = RMSNorm( config.hidden_size, @@ -1634,57 +1631,14 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> torch.Tensor: - # Self Attention - if self.fuse_input_norm_quant: - assert self.quant_dtype is not None - weight = self.input_layernorm.weight - eps = self.input_layernorm.eps - if residual is None: - residual = hidden_states - (hidden_states_quant, hidden_states_quant_scale), _, _, _ = ( - _fuse_rmsnorm_quant( - hidden_states, - weight, - eps, - None, - None, - None, - None, - dtype_quant=self.quant_dtype, - shuffle=True, - scale_shuffle_padding=True, - group_size=128, - output_unquantized_inp1=False, - transpose_scale=True, - ) - ) - else: - (hidden_states_quant, hidden_states_quant_scale), _, _, residual = ( - _fuse_rmsnorm_quant( - hidden_states, - weight, - eps, - None, - None, - None, - residual, - dtype_quant=self.quant_dtype, - shuffle=True, - scale_shuffle_padding=True, - group_size=128, - output_unquantized_inp1=False, - transpose_scale=True, - ) - ) - - hidden_states = (hidden_states_quant, hidden_states_quant_scale) - + # Self Attention — unified through RMSNorm.forward() + # fused_quant path returns (quant, scale) or ((quant, scale), residual) + # plain/allreduce path returns tensor or (tensor, residual) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, From 6ac81915d14e956ccc4f23ea377a2c3fe4014e36 Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Sun, 15 Mar 2026 05:18:43 +0000 Subject: [PATCH 3/5] refactor: add ATOM_ENABLE_RMSNORM_QUANT_FUSION master switch, delete dead code Phase 3+4 of RMSNorm fusion unification: - envs.py: add ATOM_ENABLE_RMSNORM_QUANT_FUSION master switch (default ON) Old per-model vars fallback to master switch when not explicitly set - deepseek_v2.py: delete ~210 lines of private _fuse_rmsnorm_quant functions and fake-tensor versions (now in layernorm.py) - Remove unused self.fuse_rmsnorm_quant flag from DecoderLayer --- atom/models/deepseek_v2.py | 219 ------------------------------------- atom/utils/envs.py | 16 ++- 2 files changed, 13 insertions(+), 222 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 849f781e7..36f062c72 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -43,11 +43,9 @@ from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits from aiter.ops.triton.fused_fp8_quant import ( fused_reduce_rms_fp8_group_quant, - fused_rms_fp8_group_quant, ) from aiter.ops.triton.fused_mxfp4_quant import ( fused_reduce_rms_mxfp4_quant, - fused_rms_mxfp4_quant, ) from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits from aiter.rotary_embedding import get_rope @@ -118,220 +116,6 @@ ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION -def _fuse_rmsnorm_fp4_quant_fake( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - shuffle: bool = True, - scale_shuffle_padding: bool = True, - output_unquantized_inp1: bool = False, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - m, n1 = x1.shape - n2 = x2.shape[1] if x2 is not None else 0 - - out1_quantized = torch.empty((m, n1 // 2), dtype=torch.uint8, device=x1.device) - - scale_n_valid = (n1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE - - scale_m = ((m + 255) // 256) * 256 - scale_n = ((scale_n_valid + 7) // 8) * 8 - - out1_bs = torch.empty((scale_m, scale_n), dtype=torch.uint8, device=x1.device) - - out2 = None - if x2 is not None: - out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) - - out_res1 = None - if res1 is not None: - out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) - - out1_unquantized = None - return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 - - -def _fused_rms_fp8_group_quant_fake( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - dtype_quant: torch.dtype = dtypes.fp8, - group_size: int = 128, - output_unquantized_inp1: bool = False, - transpose_scale: bool = False, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - m, n1 = x1.shape - out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=x1.device) - out1_bs = torch.empty( - (m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=x1.device - ) - if transpose_scale: - out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape) - out1_unquantized = None - if output_unquantized_inp1: - out1_unquantized = torch.empty_like(x1) - out2 = None - if x2 is not None: - _, n2 = x2.shape - out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) - out_res1 = None - if res1 is not None: - out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) - return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 - - -@torch_compile_guard(gen_fake=_fuse_rmsnorm_fp4_quant_fake) -def _fuse_rmsnorm_fp4_quant( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - shuffle: bool = True, - scale_shuffle_padding: bool = True, - output_unquantized_inp1: bool = False, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - m = x1.shape[0] - - shuffle_bool = shuffle and (m >= MXFP4_QUANT_BLOCK_SIZE) - - (out1_quantized, out1_bs), _out1_unquantized, out2, out_res1 = ( - fused_rms_mxfp4_quant( - x1=x1, - x1_weight=x1_weight, - x1_epsilon=x1_epsilon, - x2=x2, - x2_weight=x2_weight, - x2_epsilon=0.0 if x2_epsilon is None else x2_epsilon, - res1=res1, - shuffle=shuffle_bool, - scale_shuffle_padding=scale_shuffle_padding, - output_unquantized_inp1=output_unquantized_inp1, - ) - ) - - out1_unquantized = None - return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 - - -@torch_compile_guard(gen_fake=_fused_rms_fp8_group_quant_fake) -def _fused_rms_fp8_group_quant( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - dtype_quant: torch.dtype = dtypes.fp8, - group_size: int = 128, - output_unquantized_inp1: bool = False, - transpose_scale: bool = False, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = ( - fused_rms_fp8_group_quant( - x1, - x1_weight, - x1_epsilon, - x2, - x2_weight, - x2_epsilon, - group_size, - dtype_quant, - res1, - output_unquantized_inp1, - transpose_scale, - ) - ) - return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 - - -def _fuse_rmsnorm_quant( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - dtype_quant: torch.dtype = dtypes.fp8, - shuffle: bool = True, - scale_shuffle_padding: bool = False, - group_size: int = 128, - output_unquantized_inp1: bool = False, - transpose_scale: bool = False, -): - if dtype_quant == dtypes.fp4x2: - out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = ( - _fuse_rmsnorm_fp4_quant( - x1, - x1_weight, - x1_epsilon, - x2, - x2_weight, - x2_epsilon, - res1, - shuffle, - scale_shuffle_padding, - output_unquantized_inp1, - ) - ) - elif dtype_quant == dtypes.fp8: - out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = ( - _fused_rms_fp8_group_quant( - x1, - x1_weight, - x1_epsilon, - x2, - x2_weight, - x2_epsilon, - res1, - dtype_quant, - group_size, - output_unquantized_inp1, - transpose_scale, - ) - ) - else: - raise ValueError( - f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}." - ) - return (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 - - def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4_fake( hidden_states_quant: torch.Tensor, weight_qkv_a_proj: torch.Tensor, @@ -1621,9 +1405,6 @@ def __init__( fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, ) self.routed_scaling_factor = config.routed_scaling_factor - self.fuse_rmsnorm_quant = ( - ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION and self.quant_dtype is not None - ) def forward( self, diff --git a/atom/utils/envs.py b/atom/utils/envs.py index ea543b5f7..3531dd633 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -39,20 +39,30 @@ "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION", "0" ) == "1", + # Master switch for RMSNorm + quantization fusion (all models) + "ATOM_ENABLE_RMSNORM_QUANT_FUSION": lambda: os.getenv( + "ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1" + ) + == "1", + # Deprecated: use ATOM_ENABLE_RMSNORM_QUANT_FUSION instead (fallback to master switch) "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION": lambda: os.getenv( - "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION", "1" + "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION", + os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1"), ) == "1", "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION": lambda: os.getenv( - "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION", "1" + "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION", + os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1"), ) == "1", "ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION": lambda: os.getenv( "ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION", "1" ) == "1", + # Deprecated: use ATOM_ENABLE_RMSNORM_QUANT_FUSION instead "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT": lambda: os.getenv( - "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT", "1" + "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT", + os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1"), ) == "1", "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT": lambda: os.getenv( From a19ec78e339e20c9e044ce69ff1c2608183f92d4 Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Sun, 15 Mar 2026 07:23:33 +0000 Subject: [PATCH 4/5] fix: DualRMSNorm references existing norms instead of creating own weights DualRMSNorm previously created weight1/weight2 parameters that didn't match checkpoint keys (qk_layernorm.weight1 vs q_a_layernorm.weight). This caused weights to stay at initial 1.0 values, producing garbage output (GSM8K 0.0). Fix: DualRMSNorm now takes existing RMSNorm modules as constructor arguments and uses their .weight and .eps directly. No duplicate parameters, checkpoint loading works correctly. --- atom/model_ops/layernorm.py | 23 +++++++++++------------ atom/models/deepseek_v2.py | 5 ++--- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 6c2501255..b144d5028 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -552,23 +552,22 @@ class DualRMSNorm(nn.Module): Uses a single AITER kernel call to normalize both tensors and quantize the first, reducing kernel launch overhead vs two separate RMSNorm calls. Typically used in MLA attention for the q_a + kv_a layernorms. + + Does NOT own weight parameters — references existing RMSNorm modules' + weights so that checkpoint loading works correctly. """ def __init__( self, - dim1: int, - dim2: int, - eps: float = 1e-6, + norm1: RMSNorm, + norm2: RMSNorm, quant_config: Optional[QuantizationConfig] = None, transpose_scale: bool = False, shuffle: bool = False, ) -> None: super().__init__() - self.dim1 = dim1 - self.dim2 = dim2 - self.eps = eps - self.weight1 = nn.Parameter(torch.ones(dim1)) - self.weight2 = nn.Parameter(torch.ones(dim2)) + self.norm1 = norm1 + self.norm2 = norm2 self.transpose_scale = transpose_scale self.shuffle = shuffle @@ -590,11 +589,11 @@ def forward( """ (x1_quant, x1_scale), _, x2_normed, _ = fuse_rmsnorm_group_quant( x1, - self.weight1, - self.eps, + self.norm1.weight, + self.norm1.eps, x2=x2, - x2_weight=self.weight2, - x2_epsilon=self.eps, + x2_weight=self.norm2.weight, + x2_epsilon=self.norm2.eps, dtype_quant=self.params_dtype, shuffle=self.shuffle, group_size=128, diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 36f062c72..2f0cd47e4 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1216,9 +1216,8 @@ def __init__( self.fuse_qknorm_quant = True # DualRMSNorm for non-triton-GEMM path (fused dual norm + quant) self.qk_layernorm = DualRMSNorm( - self.q_lora_rank, - self.kv_lora_rank, - eps=config.rms_norm_eps, + self.q_a_layernorm, + self.kv_a_layernorm, quant_config=quant_config, transpose_scale=True, shuffle=False, From b9f6f6321e2c3c6c7d03f3930b7b3247c2d5bd17 Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Sun, 15 Mar 2026 15:35:32 +0000 Subject: [PATCH 5/5] refactor: enable DualRMSNorm for FP4 non-triton path, update docs - Remove use_triton_gemm() guard for DualRMSNorm: FP4 models now use fused QK-norm + quant via DualRMSNorm regardless of GEMM backend - Update comments/docs to match: clarify input_layernorm still requires triton GEMM while QK-norm DualRMSNorm does not - Add DualRMSNorm to model_ops_guide (source table, normalization section, fused kernel chains table) - Add ATOM_ENABLE_RMSNORM_QUANT_FUSION master switch to env var docs, mark old per-model vars as deprecated - Fix DualRMSNorm docstring terminology consistency Verified: DeepSeek BF16 GSM8K 0.957, DeepSeek MXFP4 GSM8K 0.948 --- atom/model_ops/layernorm.py | 5 +++-- atom/models/deepseek_v2.py | 17 +++++++---------- atom/utils/envs.py | 3 ++- docs/environment_variables.md | 7 ++++--- docs/model_ops_guide.md | 21 +++++++++++++++++++-- 5 files changed, 35 insertions(+), 18 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index b144d5028..a115623cf 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -547,11 +547,12 @@ def forward( class DualRMSNorm(nn.Module): - """Fused dual RMSNorm + quantization for two inputs (e.g. q_c + kv_c). + """Fused dual RMSNorm + quantization for two inputs. Uses a single AITER kernel call to normalize both tensors and quantize the first, reducing kernel launch overhead vs two separate RMSNorm calls. - Typically used in MLA attention for the q_a + kv_a layernorms. + In MLA attention, normalizes q_c via q_a_layernorm and kv_c via + kv_a_layernorm in one fused call. Does NOT own weight parameters — references existing RMSNorm modules' weights so that checkpoint loading works correctly. diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 2f0cd47e4..9f3f550f3 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1204,17 +1204,14 @@ def __init__( prefix=prefix, ) - # When ATOM_ENABLE_DS_QKNORM_QUANT_FUSION is turned on, self.fuse_qknorm_quant is turned on only if FP8 or (use_triton_gemm() and FP4), self.prefix = prefix self.quant_dtype = None self.fuse_qknorm_quant = False if quant_config is not None and ENABLE_DS_QKNORM_QUANT_FUSION: - if layer_quant_dtype == dtypes.fp8 or ( - layer_quant_dtype == dtypes.fp4x2 and use_triton_gemm() - ): + if layer_quant_dtype in (dtypes.fp8, dtypes.fp4x2): self.quant_dtype = layer_quant_dtype self.fuse_qknorm_quant = True - # DualRMSNorm for non-triton-GEMM path (fused dual norm + quant) + # DualRMSNorm: fused dual norm + quant for both FP8 and FP4 paths self.qk_layernorm = DualRMSNorm( self.q_a_layernorm, self.kv_a_layernorm, @@ -1338,11 +1335,11 @@ def __init__( topk_indices_buffer=topk_indices_buffer, ) - # When ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on self.fuse_input_norm_quant is turned on only if use_triton_gemm and (FP8 or FP4), - # Because AR_RMS and RMS_Quant cannot co-exist for input_layernorm, this block of codes ensures 3 things when ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on: + # input_layernorm fusion requires use_triton_gemm() (unlike QK-norm DualRMSNorm which works without it). + # Because AR_RMS and RMS_Quant cannot co-exist for input_layernorm, this block ensures: # 1. RMS_Quant fusion is only used for input_layernorm - # 2. The reduce_results variable is re-enabled for feed forward layers (MOE and MLP), because AR_RMS is now disabled in the beginning of the next layer - # 3. AR_RMS is turned off for input_layernorm but still enabled for post_attention_layernorm if ENABLE_ALLREDUCE_RMSNORM_FUSION is turned on + # 2. reduce_results is re-enabled for feed forward layers (MOE/MLP), since AR_RMS is disabled at the start of the next layer + # 3. AR_RMS is turned off for input_layernorm but still enabled for post_attention_layernorm self.quant_dtype = ( None if quant_config is None @@ -1364,7 +1361,7 @@ def __init__( else: if layer_idx == 0: logger.info( - "Info: Because ATOM_USE_TRITON_GEMM is not turned on in DeepSeek-R1, ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned off automatically" + "Info: ATOM_USE_TRITON_GEMM is off, input_layernorm RMSNorm+quant fusion disabled (QK-norm DualRMSNorm fusion still active)" ) if ( diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 3531dd633..c8acd66c7 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -44,12 +44,13 @@ "ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1" ) == "1", - # Deprecated: use ATOM_ENABLE_RMSNORM_QUANT_FUSION instead (fallback to master switch) + # Deprecated: falls back to ATOM_ENABLE_RMSNORM_QUANT_FUSION when unset "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION": lambda: os.getenv( "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION", os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1"), ) == "1", + # Deprecated: falls back to ATOM_ENABLE_RMSNORM_QUANT_FUSION when unset "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION": lambda: os.getenv( "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION", os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1"), diff --git a/docs/environment_variables.md b/docs/environment_variables.md index ecd9e9bbb..6647b0bdc 100644 --- a/docs/environment_variables.md +++ b/docs/environment_variables.md @@ -50,12 +50,13 @@ This document describes the environment variables used in the ATOM project. |----------|------|---------|-------------| | **ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION** | bool | 1 (true) | If set to `1`, fuse allreduce with RMSNorm in tensor parallel mode. | -### DeepSeek-style +### RMSNorm + Quantization Fusion | Variable | Type | Default | Description | |----------|------|---------|-------------| -| **ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION** | bool | 1 (true) | If set to `1`, fuse RMSNorm with quantization. | -| **ATOM_ENABLE_DS_QKNORM_QUANT_FUSION** | bool | 1 (true) | If set to `1`, fuse QK norm with quantization in MLA attention module. | +| **ATOM_ENABLE_RMSNORM_QUANT_FUSION** | bool | 1 (true) | Master switch for all RMSNorm + quantization fusion paths (all models). | +| **ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION** | bool | (master switch) | *Deprecated.* Override for DeepSeek input layernorm fusion. Falls back to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset. | +| **ATOM_ENABLE_DS_QKNORM_QUANT_FUSION** | bool | (master switch) | *Deprecated.* Override for DeepSeek QK-norm fusion. Falls back to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset. | ### Qwen3-MoE style diff --git a/docs/model_ops_guide.md b/docs/model_ops_guide.md index 969082719..044f04d12 100644 --- a/docs/model_ops_guide.md +++ b/docs/model_ops_guide.md @@ -18,6 +18,7 @@ ATOM (AiTer Optimized Model) wraps AITER kernels with model-level abstractions f | `MLAAttention` | `attention_mla.py` | `mla_decode_fwd`, `mla_prefill_fwd`, `concat_and_cache_mla`, `fused_qk_rope_concat_and_cache_mla` | Multi-head latent attention | | `FusedMoE` | `moe.py` | `aiter.fused_moe.fused_moe`, `asm_moe` | Mixture of experts | | `RMSNorm` | `layernorm.py` | `rmsnorm2d_fwd`, `rmsnorm2d_fwd_with_add`, `fused_add_rmsnorm_pad` | RMS normalization | +| `DualRMSNorm` | `layernorm.py` | `fuse_rmsnorm_group_quant` | Fused dual RMSNorm + quant (MLA q/kv norms) | | `LayerNorm` | `layernorm.py` | `layernorm2d_fwd`, `layernorm2d_fwd_with_add` | Layer normalization | | `SiluAndMul` | `activation.py` | `aiter.silu_and_mul` | SiLU gated activation | | `VocabParallelEmbedding` | `embed_head.py` | `F.embedding` + TP all-reduce | Vocab embedding | @@ -353,7 +354,22 @@ RMSNorm( ) ``` -### 5.2 `LayerNorm` (`layernorm.py`) +### 5.2 `DualRMSNorm` (`layernorm.py`) + +`DualRMSNorm` fuses normalization of two tensors (e.g. q_c and kv_c in MLA) and quantization of the first into a single `fuse_rmsnorm_group_quant` kernel call. It references existing `RMSNorm` modules' weights rather than creating its own, so checkpoint loading works correctly. + +```python +DualRMSNorm( + norm1: RMSNorm, # e.g. q_a_layernorm + norm2: RMSNorm, # e.g. kv_a_layernorm + quant_config: Optional[QuantizationConfig] = None, + transpose_scale: bool = False, + shuffle: bool = False, +) +# forward(x1, x2) -> ((x1_quant, x1_scale), x2_normed) +``` + +### 5.3 `LayerNorm` (`layernorm.py`) `LayerNorm` wraps `layernorm2d_fwd` and `layernorm2d_fwd_with_add` (with bias support): @@ -482,6 +498,7 @@ ATOM uses fused kernels to reduce memory traffic by combining multiple operation |---|---|---|---| | RMSNorm + FP8 quant | RMSNorm, per-tensor FP8 static quant | `RMSNorm(fused_quant=True)` + `x_scale` | `fused_rms_fp8_per_tensor_static_quant` | | RMSNorm + MXFP4 quant | RMSNorm, per-1x32 MXFP4 quant | `RMSNorm(fused_quant=True)` + `QuantType.per_1x32` | `fused_rms_mxfp4_quant` | +| Dual RMSNorm + quant | Normalize q_c + kv_c, quantize q_c | `DualRMSNorm` (MLA, FP8/FP4) | `fuse_rmsnorm_group_quant` | | RMSNorm + add + pad | Residual add, RMSNorm, output padding | `RMSNorm(x_pad_to_multiple>0)` | `fused_add_rmsnorm_pad` | | AllReduce + RMSNorm | TP all-reduce, RMSNorm | `RMSNorm(fused_allreduce=True)` | `tensor_model_parallel_fused_allreduce_rmsnorm` | | SiLU + mul + FP8 quant | SiLU activation, multiply, FP8 quant | `SiluAndMul(fused_quant=True)` + `x_scale` | `fused_silu_mul_fp8_per_tensor_static_quant` | @@ -504,7 +521,7 @@ ATOM uses fused kernels to reduce memory traffic by combining multiple operation |---|---| | `linear.py` | `LinearBase`, `ColumnParallelLinear`, `RowParallelLinear`, `QKVParallelLinear`, `MergedColumnParallelLinear`, `ReplicatedLinear`, `MergedReplicatedLinear` | | `activation.py` | `SiluAndMul` with fused FP8/MXFP4 quantization | -| `layernorm.py` | `RMSNorm`, `LayerNorm` with fused allreduce/quant/pad variants | +| `layernorm.py` | `RMSNorm`, `DualRMSNorm`, `LayerNorm` with fused allreduce/quant/pad variants; group-quant dispatch (`fuse_rmsnorm_group_quant`) | | `base_attention.py` | Top-level `Attention` dispatcher with custom op registration | | `attention_mha.py` | MHA implementation: prefill (flash), decode (ASM/Triton paged attention) | | `attention_mla.py` | `MLAAttention`, `MLAModules` -- DeepSeek MLA with compressed KV |