diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 7f556fcc5..a115623cf 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -25,6 +25,12 @@ from aiter import ( QuantType, ) +from atom.model_ops.utils import MXFP4_QUANT_BLOCK_SIZE + +try: + from aiter import dtypes as _aiter_dtypes +except ImportError: + _aiter_dtypes = None def silu(input: Tensor, inplace: bool = False) -> Tensor: @@ -169,6 +175,216 @@ def mxfp4_rms_quant_fuse( return x_quant, x_scale, residual_out +# --------------------------------------------------------------------------- +# Group-quant fused RMSNorm kernels (moved from deepseek_v2.py) +# --------------------------------------------------------------------------- + + +def _fuse_rmsnorm_fp4_quant_fake( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + shuffle: bool = True, + scale_shuffle_padding: bool = True, + output_unquantized_inp1: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + m, n1 = x1.shape + n2 = x2.shape[1] if x2 is not None else 0 + + out1_quantized = torch.empty((m, n1 // 2), dtype=torch.uint8, device=x1.device) + scale_n_valid = (n1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + scale_m = ((m + 255) // 256) * 256 + scale_n = ((scale_n_valid + 7) // 8) * 8 + out1_bs = torch.empty((scale_m, scale_n), dtype=torch.uint8, device=x1.device) + + out2 = None + if x2 is not None: + out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) + out_res1 = None + if res1 is not None: + out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) + out1_unquantized = None + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + +def _fused_rms_fp8_group_quant_fake( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + dtype_quant: "torch.dtype | None" = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if dtype_quant is None and _aiter_dtypes is not None: + dtype_quant = _aiter_dtypes.fp8 + m, n1 = x1.shape + out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=x1.device) + out1_bs = torch.empty( + (m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=x1.device + ) + if transpose_scale: + out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape) + out1_unquantized = None + if output_unquantized_inp1: + out1_unquantized = torch.empty_like(x1) + out2 = None + if x2 is not None: + _, n2 = x2.shape + out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) + out_res1 = None + if res1 is not None: + out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + +@torch_compile_guard(gen_fake=_fuse_rmsnorm_fp4_quant_fake) +def _fuse_rmsnorm_fp4_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + shuffle: bool = True, + scale_shuffle_padding: bool = True, + output_unquantized_inp1: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant + + m = x1.shape[0] + shuffle_bool = shuffle and (m >= MXFP4_QUANT_BLOCK_SIZE) + + (out1_quantized, out1_bs), _out1_unquantized, out2, out_res1 = ( + fused_rms_mxfp4_quant( + x1=x1, + x1_weight=x1_weight, + x1_epsilon=x1_epsilon, + x2=x2, + x2_weight=x2_weight, + x2_epsilon=0.0 if x2_epsilon is None else x2_epsilon, + res1=res1, + shuffle=shuffle_bool, + scale_shuffle_padding=scale_shuffle_padding, + output_unquantized_inp1=output_unquantized_inp1, + ) + ) + out1_unquantized = None + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + +@torch_compile_guard(gen_fake=_fused_rms_fp8_group_quant_fake) +def _fused_rms_fp8_group_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + dtype_quant: "torch.dtype | None" = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + if dtype_quant is None and _aiter_dtypes is not None: + dtype_quant = _aiter_dtypes.fp8 + + (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = ( + fused_rms_fp8_group_quant( + x1, + x1_weight, + x1_epsilon, + x2, + x2_weight, + x2_epsilon, + group_size, + dtype_quant, + res1, + output_unquantized_inp1, + transpose_scale, + ) + ) + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + +def fuse_rmsnorm_group_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + dtype_quant: "torch.dtype | None" = None, + shuffle: bool = True, + scale_shuffle_padding: bool = False, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = False, +): + """Dispatch fused RMSNorm + group quantization to the correct kernel. + + Supports FP8 per-group and MXFP4 quantization. Optionally normalizes a + second tensor (x2) and fuses residual-add (res1) in the same kernel call. + + Returns: (out1_quantized, out1_scale), out1_unquantized, out2_normed, residual_out + """ + if _aiter_dtypes is None: + raise RuntimeError("aiter.dtypes not available") + + if dtype_quant is None: + dtype_quant = _aiter_dtypes.fp8 + + if dtype_quant == _aiter_dtypes.fp4x2: + out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = ( + _fuse_rmsnorm_fp4_quant( + x1, + x1_weight, + x1_epsilon, + x2, + x2_weight, + x2_epsilon, + res1, + shuffle, + scale_shuffle_padding, + output_unquantized_inp1, + ) + ) + elif dtype_quant == _aiter_dtypes.fp8: + out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = ( + _fused_rms_fp8_group_quant( + x1, + x1_weight, + x1_epsilon, + x2, + x2_weight, + x2_epsilon, + res1, + dtype_quant, + group_size, + output_unquantized_inp1, + transpose_scale, + ) + ) + else: + raise ValueError( + f"No fused rmsnorm quant kernel available for quant dtype: {dtype_quant}." + ) + return (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 + + class RMSNorm(nn.Module): def __init__( self, @@ -178,6 +394,8 @@ def __init__( fused_allreduce: bool = False, fused_quant: bool = False, quant_config: Optional[QuantizationConfig] = None, + transpose_scale: bool = False, + shuffle: bool = True, ) -> None: super().__init__() self.dim = dim @@ -187,6 +405,8 @@ def __init__( self.fused_allreduce = fused_allreduce self.use_fused_quant = fused_quant self.tp_size = get_tensor_model_parallel_world_size() + self.transpose_scale = transpose_scale + self.shuffle = shuffle layer_quant_config = ( LayerQuantConfig() @@ -198,6 +418,27 @@ def __init__( self.quant_type = quant_type self.params_dtype = params_dtype + # Determine the fused quant path based on quant_config: + # - "group": FP8 per-group (per_1x128/per_Token) or MXFP4 group (fp4x2) + # - "per_tensor": FP8 per-tensor static (requires x_scale at forward time) + # - "simple_mxfp4": MXFP4 simple (per_1x32 without group quant) + # - None: no quantization fusion + self._quant_path = None + if fused_quant and _aiter_dtypes is not None: + if params_dtype == _aiter_dtypes.fp8 and quant_type in ( + QuantType.per_1x128, + QuantType.per_Token, + ): + self._quant_path = "group" + elif ( + params_dtype == _aiter_dtypes.fp8 and quant_type == QuantType.per_Tensor + ): + self._quant_path = "per_tensor" + elif params_dtype == _aiter_dtypes.fp4x2: + self._quant_path = "group" + elif quant_type == QuantType.per_1x32: + self._quant_path = "simple_mxfp4" + @mark_trace(prefix="rmsnorm", torch_compile=True) def forward( self, @@ -228,66 +469,138 @@ def forward( self.eps, ) return x, residual - else: - if x_scale is not None and self.use_fused_quant: - from aiter.ops.triton.fused_fp8_quant import ( - fused_rms_fp8_per_tensor_static_quant, + + # --- Fused quant paths (dispatched by _quant_path) --- + if self._quant_path == "group": + # FP8 per-group or MXFP4 group quantization + (x_quant, x_scale), _, _, res_out = fuse_rmsnorm_group_quant( + x, + self.weight, + self.eps, + res1=residual, + dtype_quant=self.params_dtype, + shuffle=self.shuffle, + scale_shuffle_padding=self.shuffle, + group_size=128, + transpose_scale=self.transpose_scale, + ) + if residual is None: + return (x_quant, x_scale) + else: + return (x_quant, x_scale), res_out + + if self._quant_path == "per_tensor" and x_scale is not None: + from aiter.ops.triton.fused_fp8_quant import ( + fused_rms_fp8_per_tensor_static_quant, + ) + import aiter as rocm_aiter + + rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8 + if residual is None: + x, _, _, _ = fused_rms_fp8_per_tensor_static_quant( + x, + self.weight, + self.eps, + x_scale, + None, + None, + self.eps, + dtype_quant=rocm_aiter_fp8_dtype, + res1=None, ) - import aiter as rocm_aiter - - rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8 - - # static FP8 quantization - if residual is None: - x, _, _, _ = fused_rms_fp8_per_tensor_static_quant( - x, - self.weight, - self.eps, - x_scale, - None, - None, - self.eps, - dtype_quant=rocm_aiter_fp8_dtype, - res1=None, - ) - return (x, x_scale) - else: - x, _, _, residual = fused_rms_fp8_per_tensor_static_quant( - x, - self.weight, - self.eps, - x_scale, - None, - None, - self.eps, - dtype_quant=rocm_aiter_fp8_dtype, - res1=residual, - ) - return (x, x_scale), residual - elif self.use_fused_quant and ( - x_scale is None and self.quant_type.value == QuantType.per_1x32.value - ): - if residual is None: - x, x_scale, _ = mxfp4_rms_quant_fuse( - x, self.weight, self.eps, shuffle=True - ) - return x, x_scale - else: - x, x_scale, residual = mxfp4_rms_quant_fuse( - x, self.weight, self.eps, shuffle=True, res1=residual - ) - return (x, x_scale), residual + return (x, x_scale) + else: + x, _, _, residual = fused_rms_fp8_per_tensor_static_quant( + x, + self.weight, + self.eps, + x_scale, + None, + None, + self.eps, + dtype_quant=rocm_aiter_fp8_dtype, + res1=residual, + ) + return (x, x_scale), residual + + if self._quant_path == "simple_mxfp4": + if residual is None: + x, x_scale, _ = mxfp4_rms_quant_fuse( + x, self.weight, self.eps, shuffle=self.shuffle + ) + return x, x_scale else: - if residual is None: - # return rmsnorm2d_fwd(x, self.weight, self.eps).view(ori_shape) - x = rmsnorm2d_fwd_(x, self.weight, self.eps, self.dim) - return x - else: - # return self.add_rms_forward(x, residual) - x, residual = rmsnorm2d_fwd_with_add_( - x, self.weight, residual, self.eps, self.dim - ) - return x, residual + x, x_scale, residual = mxfp4_rms_quant_fuse( + x, self.weight, self.eps, shuffle=self.shuffle, res1=residual + ) + return (x, x_scale), residual + + # --- Plain RMSNorm (no fusion) --- + if residual is None: + x = rmsnorm2d_fwd_(x, self.weight, self.eps, self.dim) + return x + else: + x, residual = rmsnorm2d_fwd_with_add_( + x, self.weight, residual, self.eps, self.dim + ) + return x, residual + + +class DualRMSNorm(nn.Module): + """Fused dual RMSNorm + quantization for two inputs. + + Uses a single AITER kernel call to normalize both tensors and quantize + the first, reducing kernel launch overhead vs two separate RMSNorm calls. + In MLA attention, normalizes q_c via q_a_layernorm and kv_c via + kv_a_layernorm in one fused call. + + Does NOT own weight parameters — references existing RMSNorm modules' + weights so that checkpoint loading works correctly. + """ + + def __init__( + self, + norm1: RMSNorm, + norm2: RMSNorm, + quant_config: Optional[QuantizationConfig] = None, + transpose_scale: bool = False, + shuffle: bool = False, + ) -> None: + super().__init__() + self.norm1 = norm1 + self.norm2 = norm2 + self.transpose_scale = transpose_scale + self.shuffle = shuffle + + layer_quant_config = ( + LayerQuantConfig() + if quant_config is None + else quant_config.global_quant_config + ) + self.params_dtype = layer_quant_config["quant_dtype"] + + def forward( + self, + x1: torch.Tensor, + x2: torch.Tensor, + ) -> tuple: + """Normalize x1 and x2, quantize x1. + + Returns: (x1_quant, x1_scale), x2_normed + """ + (x1_quant, x1_scale), _, x2_normed, _ = fuse_rmsnorm_group_quant( + x1, + self.norm1.weight, + self.norm1.eps, + x2=x2, + x2_weight=self.norm2.weight, + x2_epsilon=self.norm2.eps, + dtype_quant=self.params_dtype, + shuffle=self.shuffle, + group_size=128, + transpose_scale=self.transpose_scale, + ) + return (x1_quant, x1_scale), x2_normed class RMSNormGated(nn.Module): diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index d1da9f052..9f3f550f3 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -43,11 +43,9 @@ from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits from aiter.ops.triton.fused_fp8_quant import ( fused_reduce_rms_fp8_group_quant, - fused_rms_fp8_group_quant, ) from aiter.ops.triton.fused_mxfp4_quant import ( fused_reduce_rms_mxfp4_quant, - fused_rms_mxfp4_quant, ) from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits from aiter.rotary_embedding import get_rope @@ -61,7 +59,7 @@ from atom.model_ops.attention_mla import MLAModules, is_rocm_aiter_fp4bmm_enabled from atom.model_ops.base_attention import Attention from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding -from atom.model_ops.layernorm import LayerNorm, RMSNorm +from atom.model_ops.layernorm import DualRMSNorm, LayerNorm, RMSNorm # noqa: F401 from atom.model_ops.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, @@ -118,220 +116,6 @@ ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION -def _fuse_rmsnorm_fp4_quant_fake( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - shuffle: bool = True, - scale_shuffle_padding: bool = True, - output_unquantized_inp1: bool = False, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - m, n1 = x1.shape - n2 = x2.shape[1] if x2 is not None else 0 - - out1_quantized = torch.empty((m, n1 // 2), dtype=torch.uint8, device=x1.device) - - scale_n_valid = (n1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE - - scale_m = ((m + 255) // 256) * 256 - scale_n = ((scale_n_valid + 7) // 8) * 8 - - out1_bs = torch.empty((scale_m, scale_n), dtype=torch.uint8, device=x1.device) - - out2 = None - if x2 is not None: - out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) - - out_res1 = None - if res1 is not None: - out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) - - out1_unquantized = None - return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 - - -def _fused_rms_fp8_group_quant_fake( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - dtype_quant: torch.dtype = dtypes.fp8, - group_size: int = 128, - output_unquantized_inp1: bool = False, - transpose_scale: bool = False, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - m, n1 = x1.shape - out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=x1.device) - out1_bs = torch.empty( - (m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=x1.device - ) - if transpose_scale: - out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape) - out1_unquantized = None - if output_unquantized_inp1: - out1_unquantized = torch.empty_like(x1) - out2 = None - if x2 is not None: - _, n2 = x2.shape - out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) - out_res1 = None - if res1 is not None: - out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) - return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 - - -@torch_compile_guard(gen_fake=_fuse_rmsnorm_fp4_quant_fake) -def _fuse_rmsnorm_fp4_quant( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - shuffle: bool = True, - scale_shuffle_padding: bool = True, - output_unquantized_inp1: bool = False, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - m = x1.shape[0] - - shuffle_bool = shuffle and (m >= MXFP4_QUANT_BLOCK_SIZE) - - (out1_quantized, out1_bs), _out1_unquantized, out2, out_res1 = ( - fused_rms_mxfp4_quant( - x1=x1, - x1_weight=x1_weight, - x1_epsilon=x1_epsilon, - x2=x2, - x2_weight=x2_weight, - x2_epsilon=0.0 if x2_epsilon is None else x2_epsilon, - res1=res1, - shuffle=shuffle_bool, - scale_shuffle_padding=scale_shuffle_padding, - output_unquantized_inp1=output_unquantized_inp1, - ) - ) - - out1_unquantized = None - return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 - - -@torch_compile_guard(gen_fake=_fused_rms_fp8_group_quant_fake) -def _fused_rms_fp8_group_quant( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - dtype_quant: torch.dtype = dtypes.fp8, - group_size: int = 128, - output_unquantized_inp1: bool = False, - transpose_scale: bool = False, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = ( - fused_rms_fp8_group_quant( - x1, - x1_weight, - x1_epsilon, - x2, - x2_weight, - x2_epsilon, - group_size, - dtype_quant, - res1, - output_unquantized_inp1, - transpose_scale, - ) - ) - return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 - - -def _fuse_rmsnorm_quant( - x1: torch.Tensor, - x1_weight: torch.Tensor, - x1_epsilon: float, - x2: Optional[torch.Tensor] = None, - x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: Optional[float] = None, - res1: Optional[torch.Tensor] = None, - dtype_quant: torch.dtype = dtypes.fp8, - shuffle: bool = True, - scale_shuffle_padding: bool = False, - group_size: int = 128, - output_unquantized_inp1: bool = False, - transpose_scale: bool = False, -): - if dtype_quant == dtypes.fp4x2: - out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = ( - _fuse_rmsnorm_fp4_quant( - x1, - x1_weight, - x1_epsilon, - x2, - x2_weight, - x2_epsilon, - res1, - shuffle, - scale_shuffle_padding, - output_unquantized_inp1, - ) - ) - elif dtype_quant == dtypes.fp8: - out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = ( - _fused_rms_fp8_group_quant( - x1, - x1_weight, - x1_epsilon, - x2, - x2_weight, - x2_epsilon, - res1, - dtype_quant, - group_size, - output_unquantized_inp1, - transpose_scale, - ) - ) - else: - raise ValueError( - f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}." - ) - return (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 - - def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4_fake( hidden_states_quant: torch.Tensor, weight_qkv_a_proj: torch.Tensor, @@ -1420,16 +1204,21 @@ def __init__( prefix=prefix, ) - # When ATOM_ENABLE_DS_QKNORM_QUANT_FUSION is turned on, self.fuse_qknorm_quant is turned on only if FP8 or (use_triton_gemm() and FP4), self.prefix = prefix self.quant_dtype = None self.fuse_qknorm_quant = False if quant_config is not None and ENABLE_DS_QKNORM_QUANT_FUSION: - if layer_quant_dtype == dtypes.fp8 or ( - layer_quant_dtype == dtypes.fp4x2 and use_triton_gemm() - ): + if layer_quant_dtype in (dtypes.fp8, dtypes.fp4x2): self.quant_dtype = layer_quant_dtype self.fuse_qknorm_quant = True + # DualRMSNorm: fused dual norm + quant for both FP8 and FP4 paths + self.qk_layernorm = DualRMSNorm( + self.q_a_layernorm, + self.kv_a_layernorm, + quant_config=quant_config, + transpose_scale=True, + shuffle=False, + ) def forward( self, @@ -1477,24 +1266,8 @@ def forward( if self.fuse_qknorm_quant: ( (hidden_states_or_q_c, hidden_states_or_q_c_scale), - _, kv_c_normed, - _, - ) = _fuse_rmsnorm_quant( - q_c, - self.q_a_layernorm.weight, - self.q_a_layernorm.eps, - kv_c, - self.kv_a_layernorm.weight, - self.kv_a_layernorm.eps, - None, - dtype_quant=self.quant_dtype, - shuffle=False, - scale_shuffle_padding=False, - group_size=128, - output_unquantized_inp1=False, - transpose_scale=True, - ) + ) = self.qk_layernorm(q_c, kv_c) else: hidden_states_or_q_c = self.q_a_layernorm(q_c) else: @@ -1562,11 +1335,11 @@ def __init__( topk_indices_buffer=topk_indices_buffer, ) - # When ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on self.fuse_input_norm_quant is turned on only if use_triton_gemm and (FP8 or FP4), - # Because AR_RMS and RMS_Quant cannot co-exist for input_layernorm, this block of codes ensures 3 things when ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on: + # input_layernorm fusion requires use_triton_gemm() (unlike QK-norm DualRMSNorm which works without it). + # Because AR_RMS and RMS_Quant cannot co-exist for input_layernorm, this block ensures: # 1. RMS_Quant fusion is only used for input_layernorm - # 2. The reduce_results variable is re-enabled for feed forward layers (MOE and MLP), because AR_RMS is now disabled in the beginning of the next layer - # 3. AR_RMS is turned off for input_layernorm but still enabled for post_attention_layernorm if ENABLE_ALLREDUCE_RMSNORM_FUSION is turned on + # 2. reduce_results is re-enabled for feed forward layers (MOE/MLP), since AR_RMS is disabled at the start of the next layer + # 3. AR_RMS is turned off for input_layernorm but still enabled for post_attention_layernorm self.quant_dtype = ( None if quant_config is None @@ -1588,7 +1361,7 @@ def __init__( else: if layer_idx == 0: logger.info( - "Info: Because ATOM_USE_TRITON_GEMM is not turned on in DeepSeek-R1, ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned off automatically" + "Info: ATOM_USE_TRITON_GEMM is off, input_layernorm RMSNorm+quant fusion disabled (QK-norm DualRMSNorm fusion still active)" ) if ( @@ -1617,6 +1390,10 @@ def __init__( fused_allreduce=self.fuse_ar_input_norm and self.layer_idx > 0 and not is_mtp_block, + fused_quant=self.fuse_input_norm_quant, + quant_config=quant_config, + transpose_scale=True, + shuffle=True, ) self.post_attention_layernorm = RMSNorm( config.hidden_size, @@ -1624,9 +1401,6 @@ def __init__( fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, ) self.routed_scaling_factor = config.routed_scaling_factor - self.fuse_rmsnorm_quant = ( - ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION and self.quant_dtype is not None - ) def forward( self, @@ -1634,57 +1408,14 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> torch.Tensor: - # Self Attention - if self.fuse_input_norm_quant: - assert self.quant_dtype is not None - weight = self.input_layernorm.weight - eps = self.input_layernorm.eps - if residual is None: - residual = hidden_states - (hidden_states_quant, hidden_states_quant_scale), _, _, _ = ( - _fuse_rmsnorm_quant( - hidden_states, - weight, - eps, - None, - None, - None, - None, - dtype_quant=self.quant_dtype, - shuffle=True, - scale_shuffle_padding=True, - group_size=128, - output_unquantized_inp1=False, - transpose_scale=True, - ) - ) - else: - (hidden_states_quant, hidden_states_quant_scale), _, _, residual = ( - _fuse_rmsnorm_quant( - hidden_states, - weight, - eps, - None, - None, - None, - residual, - dtype_quant=self.quant_dtype, - shuffle=True, - scale_shuffle_padding=True, - group_size=128, - output_unquantized_inp1=False, - transpose_scale=True, - ) - ) - - hidden_states = (hidden_states_quant, hidden_states_quant_scale) - + # Self Attention — unified through RMSNorm.forward() + # fused_quant path returns (quant, scale) or ((quant, scale), residual) + # plain/allreduce path returns tensor or (tensor, residual) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, diff --git a/atom/utils/envs.py b/atom/utils/envs.py index ea543b5f7..c8acd66c7 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -39,20 +39,31 @@ "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION", "0" ) == "1", + # Master switch for RMSNorm + quantization fusion (all models) + "ATOM_ENABLE_RMSNORM_QUANT_FUSION": lambda: os.getenv( + "ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1" + ) + == "1", + # Deprecated: falls back to ATOM_ENABLE_RMSNORM_QUANT_FUSION when unset "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION": lambda: os.getenv( - "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION", "1" + "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION", + os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1"), ) == "1", + # Deprecated: falls back to ATOM_ENABLE_RMSNORM_QUANT_FUSION when unset "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION": lambda: os.getenv( - "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION", "1" + "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION", + os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1"), ) == "1", "ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION": lambda: os.getenv( "ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION", "1" ) == "1", + # Deprecated: use ATOM_ENABLE_RMSNORM_QUANT_FUSION instead "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT": lambda: os.getenv( - "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT", "1" + "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT", + os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1"), ) == "1", "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT": lambda: os.getenv( diff --git a/docs/environment_variables.md b/docs/environment_variables.md index ecd9e9bbb..6647b0bdc 100644 --- a/docs/environment_variables.md +++ b/docs/environment_variables.md @@ -50,12 +50,13 @@ This document describes the environment variables used in the ATOM project. |----------|------|---------|-------------| | **ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION** | bool | 1 (true) | If set to `1`, fuse allreduce with RMSNorm in tensor parallel mode. | -### DeepSeek-style +### RMSNorm + Quantization Fusion | Variable | Type | Default | Description | |----------|------|---------|-------------| -| **ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION** | bool | 1 (true) | If set to `1`, fuse RMSNorm with quantization. | -| **ATOM_ENABLE_DS_QKNORM_QUANT_FUSION** | bool | 1 (true) | If set to `1`, fuse QK norm with quantization in MLA attention module. | +| **ATOM_ENABLE_RMSNORM_QUANT_FUSION** | bool | 1 (true) | Master switch for all RMSNorm + quantization fusion paths (all models). | +| **ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION** | bool | (master switch) | *Deprecated.* Override for DeepSeek input layernorm fusion. Falls back to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset. | +| **ATOM_ENABLE_DS_QKNORM_QUANT_FUSION** | bool | (master switch) | *Deprecated.* Override for DeepSeek QK-norm fusion. Falls back to `ATOM_ENABLE_RMSNORM_QUANT_FUSION` when unset. | ### Qwen3-MoE style diff --git a/docs/model_ops_guide.md b/docs/model_ops_guide.md index 969082719..044f04d12 100644 --- a/docs/model_ops_guide.md +++ b/docs/model_ops_guide.md @@ -18,6 +18,7 @@ ATOM (AiTer Optimized Model) wraps AITER kernels with model-level abstractions f | `MLAAttention` | `attention_mla.py` | `mla_decode_fwd`, `mla_prefill_fwd`, `concat_and_cache_mla`, `fused_qk_rope_concat_and_cache_mla` | Multi-head latent attention | | `FusedMoE` | `moe.py` | `aiter.fused_moe.fused_moe`, `asm_moe` | Mixture of experts | | `RMSNorm` | `layernorm.py` | `rmsnorm2d_fwd`, `rmsnorm2d_fwd_with_add`, `fused_add_rmsnorm_pad` | RMS normalization | +| `DualRMSNorm` | `layernorm.py` | `fuse_rmsnorm_group_quant` | Fused dual RMSNorm + quant (MLA q/kv norms) | | `LayerNorm` | `layernorm.py` | `layernorm2d_fwd`, `layernorm2d_fwd_with_add` | Layer normalization | | `SiluAndMul` | `activation.py` | `aiter.silu_and_mul` | SiLU gated activation | | `VocabParallelEmbedding` | `embed_head.py` | `F.embedding` + TP all-reduce | Vocab embedding | @@ -353,7 +354,22 @@ RMSNorm( ) ``` -### 5.2 `LayerNorm` (`layernorm.py`) +### 5.2 `DualRMSNorm` (`layernorm.py`) + +`DualRMSNorm` fuses normalization of two tensors (e.g. q_c and kv_c in MLA) and quantization of the first into a single `fuse_rmsnorm_group_quant` kernel call. It references existing `RMSNorm` modules' weights rather than creating its own, so checkpoint loading works correctly. + +```python +DualRMSNorm( + norm1: RMSNorm, # e.g. q_a_layernorm + norm2: RMSNorm, # e.g. kv_a_layernorm + quant_config: Optional[QuantizationConfig] = None, + transpose_scale: bool = False, + shuffle: bool = False, +) +# forward(x1, x2) -> ((x1_quant, x1_scale), x2_normed) +``` + +### 5.3 `LayerNorm` (`layernorm.py`) `LayerNorm` wraps `layernorm2d_fwd` and `layernorm2d_fwd_with_add` (with bias support): @@ -482,6 +498,7 @@ ATOM uses fused kernels to reduce memory traffic by combining multiple operation |---|---|---|---| | RMSNorm + FP8 quant | RMSNorm, per-tensor FP8 static quant | `RMSNorm(fused_quant=True)` + `x_scale` | `fused_rms_fp8_per_tensor_static_quant` | | RMSNorm + MXFP4 quant | RMSNorm, per-1x32 MXFP4 quant | `RMSNorm(fused_quant=True)` + `QuantType.per_1x32` | `fused_rms_mxfp4_quant` | +| Dual RMSNorm + quant | Normalize q_c + kv_c, quantize q_c | `DualRMSNorm` (MLA, FP8/FP4) | `fuse_rmsnorm_group_quant` | | RMSNorm + add + pad | Residual add, RMSNorm, output padding | `RMSNorm(x_pad_to_multiple>0)` | `fused_add_rmsnorm_pad` | | AllReduce + RMSNorm | TP all-reduce, RMSNorm | `RMSNorm(fused_allreduce=True)` | `tensor_model_parallel_fused_allreduce_rmsnorm` | | SiLU + mul + FP8 quant | SiLU activation, multiply, FP8 quant | `SiluAndMul(fused_quant=True)` + `x_scale` | `fused_silu_mul_fp8_per_tensor_static_quant` | @@ -504,7 +521,7 @@ ATOM uses fused kernels to reduce memory traffic by combining multiple operation |---|---| | `linear.py` | `LinearBase`, `ColumnParallelLinear`, `RowParallelLinear`, `QKVParallelLinear`, `MergedColumnParallelLinear`, `ReplicatedLinear`, `MergedReplicatedLinear` | | `activation.py` | `SiluAndMul` with fused FP8/MXFP4 quantization | -| `layernorm.py` | `RMSNorm`, `LayerNorm` with fused allreduce/quant/pad variants | +| `layernorm.py` | `RMSNorm`, `DualRMSNorm`, `LayerNorm` with fused allreduce/quant/pad variants; group-quant dispatch (`fuse_rmsnorm_group_quant`) | | `base_attention.py` | Top-level `Attention` dispatcher with custom op registration | | `attention_mha.py` | MHA implementation: prefill (flash), decode (ASM/Triton paged attention) | | `attention_mla.py` | `MLAAttention`, `MLAModules` -- DeepSeek MLA with compressed KV |