diff --git a/README.md b/README.md
index 6c2d41a4..4de8dd8c 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress

 ## 📣Latest News
+- [26/04/23] We now support FP8-Static quantization for **Hy3-preview** (MoE A20B).
 - [26/03/25] We have released **DAQ**, a quantization algorithm that preserves the knowledge acquired during post-training, where parameter updates are relatively small. [[Paper]](https://arxiv.org/abs/2603.22324) | [[Docs]](docs/source/features/quantization/daq.md)
 - [26/02/09] We have released HY-1.8B-2Bit, a 2-bit on-device large language model. [[Huggingface]](https://huggingface.co/AngelSlim/HY-1.8B-2Bit)
 - [26/01/13] We have released v0.3, supporting the training and deployment of Eagle3 for all-scale LLM/VLM/Audio models, as detailed in the [guidance documentation](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/eagle/index.html). We also released **Sherry**, a hardware-efficient 1.25-bit quantization algorithm. [[Paper]](https://arxiv.org/abs/2601.07892) | [[Code]](https://github.com/Tencent/AngelSlim/tree/sherry/Sherry)🔥🔥🔥
@@ -253,6 +254,12 @@
 python3 tools/run.py -c configs/qwen3/fp8_static/qwen3-1_7b_fp8_static.yaml
 
 This example produces quantized model weights by performing PTQ calibration on a model loaded from HuggingFace.
 
+For **Hy3-preview** (MoE A20B) FP8-Static quantization:
+
+```shell
+python tools/run.py -c configs/hunyuan/fp8_static/hunyuanv3_a20b_fp8_static_c8.yaml
+```
+
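+The quantized checkpoint written to `save_path` can then be served for inference. The following is a minimal sketch, assuming a vLLM build that supports FP8 compressed-tensors checkpoints and an FP8 KV cache (the output path below is hypothetical):
+
+```python
+from vllm import LLM, SamplingParams
+
+# Hypothetical output directory produced by the quantization run above
+llm = LLM(
+    model="./output/hunyuanv3_a20b_fp8_static_c8",
+    kv_cache_dtype="fp8",  # use the static KV cache scales saved with the checkpoint
+    trust_remote_code=True,
+)
+outputs = llm.generate(["Hello, world!"], SamplingParams(max_tokens=32))
+print(outputs[0].outputs[0].text)
+```
+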
 Code-based Start
diff --git a/README_cn.md b/README_cn.md
index ae1fc891..e775c4f9 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -22,6 +22,7 @@

 ## 📣Latest News
+- [26/04/23] We now support FP8-Static quantization for the **Hy3-preview** (MoE A20B) model.
 - [26/03/25] We released the quantization algorithm DAQ, which preserves the quantized model's capabilities when post-training parameter updates are relatively small. [[Paper]](https://arxiv.org/abs/2603.22324) | [[Docs]](docs/source/features/quantization/daq.md)
 - [26/02/09] We released HY-1.8B-2Bit, a 2-bit on-device large language model, available on [[Huggingface]](https://huggingface.co/AngelSlim/HY-1.8B-2Bit).
 - [26/01/13] We released v0.3, supporting speculative-decoding training and deployment for all modalities; docs: [Eagle3 for LLM/VLM/Audio](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/eagle/index.html). We also released **Sherry**, a new hardware-efficient 1.25-bit ternary quantization algorithm. [[Paper]](https://arxiv.org/abs/2601.07892) | [[Code]](https://github.com/Tencent/AngelSlim/tree/sherry/Sherry)🔥🔥🔥
@@ -252,6 +253,12 @@
 bash scripts/speculative/train_eagle3_online.sh
 
 This example loads the `HuggingFace` model, performs PTQ calibration, and finally produces the quantized model weights.
 
+To apply FP8-Static quantization to **Hy3-preview** (MoE A20B):
+
+```shell
+python tools/run.py -c configs/hunyuan/fp8_static/hunyuanv3_a20b_fp8_static_c8.yaml
+```
+
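+As a rough sketch of the code-based flow described in the next section (the `Engine` method names are taken from `angelslim/engine.py`; treat the exact arguments as assumptions rather than a definitive API):
+
+```python
+from angelslim.engine import Engine
+
+slim_engine = Engine()
+# Argument names below are assumptions; see angelslim/engine.py for the exact API
+slim_engine.prepare_model(model_name="HYV3MoE", model_path="tencent/Hy3-preview")
+slim_engine.prepare_data(data_path="./dataset/sharegpt_gpt4/sharegpt_gpt4_256.jsonl")
+slim_engine.prepare_compressor("PTQ", default_method="fp8_static")
+slim_engine.run()
+slim_engine.save("./output")
+```
+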
 2. Code-based Start
diff --git a/angelslim/compressor/quant/core/config.py b/angelslim/compressor/quant/core/config.py
index 872cba57..1ae16296 100644
--- a/angelslim/compressor/quant/core/config.py
+++ b/angelslim/compressor/quant/core/config.py
@@ -31,7 +31,10 @@
     "per-group": AbsMaxGroupWiseWeightObserver,
 }
 
-KVCACHE_OBSERVERS_CLASS = {"per-channel": AbsmaxPerchannelObserver}
+KVCACHE_OBSERVERS_CLASS = {
+    "per-channel": AbsmaxPerchannelObserver,
+    "per-tensor": AbsmaxPertensorObserver,
+}
 
 
 class QuantConfig:
@@ -60,6 +63,7 @@ def __init__(self, config, global_config=None):
         self.quant_helpers = quantization_args.quant_helpers
         act_quant_method = quantization_args.quant_method.get("activation", None)
         weight_quant_method = quantization_args.quant_method["weight"]
+        kv_cache_quant_method = quantization_args.quant_method.get("kv_cache", None)
         self.cpu_convert = quantization_args.cpu_convert
         self.save_name = quantization_args.save_name
@@ -77,7 +81,11 @@ def __init__(self, config, global_config=None):
                 ACT_OBSERVERS_CLASS[act_quant_method] if "static" in is_dynamic else None
             )
             self.weight_observer = WEIGHT_OBSERVERS_CLASS[weight_quant_method]
-            self.kv_cache_observer = None
+            self.kv_cache_observer = (
+                KVCACHE_OBSERVERS_CLASS[kv_cache_quant_method]
+                if kv_cache_quant_method is not None
+                else None
+            )
 
             if "w4a8" in self.quant_algo:
                 group_size = (
@@ -98,6 +106,8 @@ def __init__(self, config, global_config=None):
             if act_quant_method is not None:
                 self.quant_algo_info["a"] = f"fp8_{act_quant_method}-{is_dynamic}"
+            if kv_cache_quant_method is not None:
+                self.quant_algo_info["c"] = f"fp8_{kv_cache_quant_method}"
             self.low_memory = config.quantization.low_memory
             self.quant_analyse = config.quantization.quant_analyse
             self.quant_vit = config.quantization.quant_vit
@@ -117,13 +127,19 @@ def __init__(self, config, global_config=None):
                 ACT_OBSERVERS_CLASS[act_quant_method] if "static" in is_dynamic else None
             )
             self.weight_observer = WEIGHT_OBSERVERS_CLASS[weight_quant_method]
-            self.kv_cache_observer = None
+            self.kv_cache_observer = (
+                KVCACHE_OBSERVERS_CLASS[kv_cache_quant_method]
+                if kv_cache_quant_method is not None
+                else None
+            )
             self.quant_algo_info = {
                 "w": f"int8_{weight_quant_method}",
                 "ignore_layers": quantization_args.ignore_layers,
             }
             if act_quant_method is not None:
                 self.quant_algo_info["a"] = f"int8_{act_quant_method}-{is_dynamic}"
+            if kv_cache_quant_method is not None:
+                self.quant_algo_info["c"] = f"int8_{kv_cache_quant_method}"
             self.low_memory = config.quantization.low_memory
             self.quant_analyse = config.quantization.quant_analyse
         elif "int4_awq" in self.quant_algo:
diff --git a/angelslim/compressor/quant/core/hook.py b/angelslim/compressor/quant/core/hook.py
index 893b34e0..ecea808f 100644
--- a/angelslim/compressor/quant/core/hook.py
+++ b/angelslim/compressor/quant/core/hook.py
@@ -51,7 +51,9 @@ def apply_hook(self):
                 sub_layer,
                 act_observer,
                 weight_observer,
-                kv_cache_observer if name in self.kv_names else None,
+                # kv_cache_observer is now handled by monkey patching at the
+                # attention level, so pass None here
+                None,
                 self.quant_model.quant_algo_dict,
                 **extra_kwargs
             )
@@ -59,6 +61,14 @@ def apply_hook(self):
             self.observer_dict[sub_layer] = observer
             self._forward_hook_list.append(forward_hook_handle)
 
+        # Apply KV cache observers using monkey patching (attention-level observation)
+        if kv_cache_observer is not None and hasattr(self.quant_model, "apply_kvcache_observers"):
+            quant_bits = self.quant_model.quant_algo_dict.get("c_quant_bits", 8)
+            self.quant_model.apply_kvcache_observers(
+                kv_cache_observer_class=kv_cache_observer,
+                quant_bits=quant_bits,
+            )
+
     def apply_smooth_hook(self, smooth_mapping_layers, smooth_observer):
         for smooth_layer, _ in smooth_mapping_layers.values():
             observer = PTQObserver(
@@ -86,6 +96,9 @@ def remove_hook(self):
         for hook in self._forward_hook_list:
             hook.remove()
         self._forward_hook_list = []
+        # Remove KV cache observer patches if available
+        if hasattr(self.quant_model, "remove_kvcache_observers"):
+            self.quant_model.remove_kvcache_observers()
 
     def post_process(self):
         maxval = get_fp_maxval(bits=8)
@@ -109,5 +122,8 @@ def post_process(self):
                 self.quant_model.act_scales_dict[name] / maxval.type(act_dtype)
             )
         if self.quant_model.quant_algo_dict["c_quant_algo"] == "fp8":
-            for k, v in self.quant_model.kv_cache_scales_dict.items():
-                self.quant_model.kv_cache_scales_dict[k] = v / maxval.type(v.dtype)
+            # Process KV cache scales from attention-level observers
+            if hasattr(self.quant_model, "get_kvcache_scales"):
+                kv_scales = self.quant_model.get_kvcache_scales()
+                for k, v in kv_scales.items():
+                    self.quant_model.kv_cache_scales_dict[k] = v / maxval.type(v.dtype)
diff --git a/angelslim/compressor/quant/core/save.py b/angelslim/compressor/quant/core/save.py
index 3fab500e..53f5e446 100644
--- a/angelslim/compressor/quant/core/save.py
+++ b/angelslim/compressor/quant/core/save.py
@@ -249,6 +249,17 @@ def save(self, save_path):
             raise ValueError(f"{self.quant_model.quant_config.quant_algo} not supported")
 
         quantization_config = {"quant_method": save_name, ignore_field: ignored_layers}
+        # Set kv_cache_scheme if kv_cache quantization is enabled (e.g. "fp8_per-tensor" -> strategy "tensor")
+        c_quant_algo = self.quant_model.quant_config.quant_algo_info.get("c", None)
+        if c_quant_algo is not None:
+            kv_cache_scheme = {
+                "num_bits": 8,
+                "strategy": re.search(r"per-([a-zA-Z]+)", c_quant_algo).group(1),
+                "type": "float",
+            }
+        else:
+            kv_cache_scheme = None
+
         if save_name == "compressed-tensors":
             quantization_config.update(
                 {
@@ -260,13 +271,15 @@ def save(self, save_path):
                             "targets": ["Linear"],
                         }
                     },
-                    "kv_cache_scheme": None,
+                    "kv_cache_scheme": kv_cache_scheme,
                     "format": quant_format,
                     "quantization_status": "compressed",
                 }
             )
         else:
             quantization_config["activation_scheme"] = "dynamic" if is_dynamic else "static"
+            if kv_cache_scheme is not None:
+                quantization_config["kv_cache_scheme"] = "static"
 
         if (
             hasattr(self.quant_model.quant_config, "transform_config")
@@ -287,6 +300,25 @@ def save(self, save_path):
                 json.dump(trtllm_config, f, indent=4)
         self.quant_model.tokenizer.save_pretrained(save_path)
 
+        # Save KV cache scales if available
+        if (
+            hasattr(self.quant_model, "kv_cache_scales_dict")
+            and self.quant_model.kv_cache_scales_dict
+        ):
+            kv_scales_path = os.path.join(save_path, "kv_cache_scales.safetensors")
+            kv_scales_dict = {}
+            kv_scale_map = {}
+            for name, scale in self.quant_model.kv_cache_scales_dict.items():
+                kv_scales_dict[name] = scale
+                kv_scale_map[name] = "kv_cache_scales.safetensors"
+            safe_save(kv_scales_dict, kv_scales_path)
+            print_info("Save KV cache scales to: {}".format(kv_scales_path))
+            new_model_index_file = os.path.join(save_path, "model.safetensors.index.json")
+            with open(new_model_index_file, "r") as f:
+                new_model_index = json.load(f)
+            new_model_index["weight_map"].update(kv_scale_map)
+            with open(new_model_index_file, "w") as f:
+                json.dump(new_model_index, f, indent=2)
 
 
 class PTQOnlyScaleSave(PTQSaveBase):
diff --git a/angelslim/compressor/quant/ptq.py b/angelslim/compressor/quant/ptq.py
index 3044d97d..4de5f20e 100644
--- a/angelslim/compressor/quant/ptq.py
+++ b/angelslim/compressor/quant/ptq.py
@@ -205,7 +205,58 @@ def save(self, save_path: str):
         save_func = self.quant_model.get_save_func()(self.quant_model)
         save_func.save(save_path)
 
+    def get_meta_weights_info(self, model):
+        """Collect the names of all parameters still on the meta device."""
+        meta_params = []
+        for name, param in model.named_parameters():
+            if param.device.type == "meta":
+                meta_params.append({"name": name})
+        return meta_params
+
+    def set_meta_weights_info(self, model):
+        """Materialize all meta-device parameters with weights loaded from the
+        original checkpoint, using the safetensors index to locate each shard."""
+        orign_w_dict = {}
+        with open(
+            os.path.join(self.absolute_model_path, "model.safetensors.index.json"), "r"
+        ) as f:
+            model_index = json.load(f)
+        for name, param in model.named_parameters():
+            if param.device.type == "meta":
+                orign_w_file = os.path.join(
+                    self.absolute_model_path,
+                    model_index["weight_map"][name],
+                )
+                # Cache each shard so it is read from disk only once
+                if orign_w_file in orign_w_dict.keys():
+                    orign_w = orign_w_dict[orign_w_file]
+                else:
+                    orign_w = load_file(orign_w_file, device="cpu")
+                    orign_w_dict[orign_w_file] = orign_w
+
+                empty_tensor = torch.empty(
+                    param.data.shape, dtype=param.data.dtype, device="cpu"
+                )
+                new_param = torch.nn.Parameter(empty_tensor)
+                new_param.data = orign_w[name]
+                parts = name.split(".")
+                current_module = model
+
+                # Walk to the module that owns the parameter
+                for part in parts[:-1]:
+                    current_module = getattr(current_module, part)
+
+                # Install the new parameter
+                setattr(current_module, parts[-1], new_param)
+
+        del orign_w_dict
+
     def _convert(self):
+        self.set_meta_weights_info(self.quant_model.model)
+        print_info(f"Meta weight:{self.get_meta_weights_info(self.quant_model.model)}")
+
         # 1. get act, weight and kv-cache scale
         for name, sub_layer in self.ptq_hook.quant_layers_dict.items():
             if (
diff --git a/angelslim/engine.py b/angelslim/engine.py
index 6b21e647..757d8741 100644
--- a/angelslim/engine.py
+++ b/angelslim/engine.py
@@ -121,7 +121,6 @@ def prepare_model(
             low_cpu_mem_usage=low_cpu_mem_usage,
             use_cache=use_cache,
             using_multi_nodes=using_multi_nodes,
-            attn_implementation=attn_implementation,
         )
         self.model_path = model_path
     elif self.series in ["Omni"]:
diff --git a/angelslim/models/llm/__init__.py b/angelslim/models/llm/__init__.py
index e8735382..6a6a7a1a 100644
--- a/angelslim/models/llm/__init__.py
+++ b/angelslim/models/llm/__init__.py
@@ -16,6 +16,7 @@
 from .glm import GLM  # noqa: F401
 from .hunyuan_dense import HunyuanDense  # noqa: F401
 from .hunyuan_moe import HunyuanMoE  # noqa: F401
+from .hunyuan_v3_moe import HYV3MoE  # noqa: F401
 from .kimi_k2 import KimiK2  # noqa: F401
 from .llama import Llama  # noqa: F401
 from .qwen import Qwen  # noqa: F401
diff --git a/angelslim/models/llm/hunyuan_v3_moe.py b/angelslim/models/llm/hunyuan_v3_moe.py
new file mode 100644
index 00000000..d8919643
--- /dev/null
+++ b/angelslim/models/llm/hunyuan_v3_moe.py
@@ -0,0 +1,385 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
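+
+"""AngelSlim adapter for the Hy3-preview (MoE A20B) model.
+
+Registers :class:`HYV3MoE` with the model factory, re-exposes the fused
+HYV3Experts weights as per-expert ``nn.Linear`` modules so the standard PTQ
+pipeline can observe them, and monkey-patches the attention forward pass to
+collect post-RoPE KV-cache scales.
+"""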
+
+import re
+
+import torch
+import torch.nn as nn
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.models.hy_v3.modeling_hy_v3 import (
+    ALL_ATTENTION_FUNCTIONS,
+    HYV3Experts,
+    apply_rotary_pos_emb,
+    eager_attention_forward,
+)
+
+from ...compressor.quant.core import PTQSaveVllmHF
+from ...utils.utils import find_layers, find_parent_layer_and_sub_name
+from ..base_model import BaseLLMModel
+from ..model_factory import SlimModelFactory
+
+
+class HYV3ExpertsWithLinear(HYV3Experts):
+    """Wrapper around HYV3Experts that exposes per-expert weights as nn.Linear modules.
+
+    HYV3Experts stores all expert weights as 3-D nn.Parameter tensors, which are
+    invisible to AngelSlim's find_layers() and PTQ hook (both only recognise
+    nn.Linear). This wrapper splits those tensors into individual nn.Linear
+    modules at construction time so that the standard quantisation pipeline can
+    observe and quantise them.
+
+    Weight shape mapping
+    --------------------
+    gate_up_proj : [num_experts, 2*intermediate_dim, hidden_dim]
+        gate_up_proj[i] -> chunk(2, dim=0)
+            gate_proj[i].weight : [intermediate_dim, hidden_dim]
+            up_proj[i].weight   : [intermediate_dim, hidden_dim]
+    down_proj : [num_experts, hidden_dim, intermediate_dim]
+        down_proj[i] -> down_proj[i].weight : [hidden_dim, intermediate_dim]
+    """
+
+    def __init__(self, experts_layer):
+        # Bypass HYV3Experts.__init__ to avoid allocating large empty Parameter
+        # tensors that we would immediately overwrite. HYV3Experts does not
+        # store self.config, so we copy the required scalar attributes directly.
+        nn.Module.__init__(self)
+        self.num_experts = experts_layer.num_experts
+        self.hidden_dim = experts_layer.hidden_dim
+        self.intermediate_dim = experts_layer.intermediate_dim
+        self.act_fn = experts_layer.act_fn
+
+        for expert_idx in range(self.num_experts):
+            expert = nn.ModuleDict(
+                {
+                    "gate_proj": nn.Linear(self.hidden_dim, self.intermediate_dim, bias=False),
+                    "up_proj": nn.Linear(self.hidden_dim, self.intermediate_dim, bias=False),
+                    "down_proj": nn.Linear(self.intermediate_dim, self.hidden_dim, bias=False),
+                }
+            )
+            # gate_up_proj[i]: [2*intermediate_dim, hidden_dim]
+            # chunk on dim=0 -> [intermediate_dim, hidden_dim] each
+            expert["gate_proj"].weight.data, expert["up_proj"].weight.data = (
+                experts_layer.gate_up_proj[expert_idx].chunk(2, dim=0)
+            )
+            # down_proj[i]: [hidden_dim, intermediate_dim]
+            expert["down_proj"].weight.data = experts_layer.down_proj[expert_idx]
+            setattr(self, f"{expert_idx}", expert)
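+
+    # Worked example of the split above (hypothetical sizes): with hidden_dim=4
+    # and intermediate_dim=3, gate_up_proj[i] has shape [6, 4]; .chunk(2, dim=0)
+    # yields two [3, 4] tensors, which become expert i's gate_proj.weight and
+    # up_proj.weight respectively.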
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        final_hidden_states = torch.zeros_like(hidden_states)
+        with torch.no_grad():
+            expert_mask = torch.nn.functional.one_hot(
+                top_k_index, num_classes=self.num_experts
+            )
+            expert_mask = expert_mask.permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+        for expert_idx in expert_hit:
+            expert_idx = expert_idx[0]
+            if expert_idx == self.num_experts:
+                continue
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
+            expert_layer = getattr(self, f"{expert_idx}")
+            gate = expert_layer["gate_proj"](current_state)
+            up = expert_layer["up_proj"](current_state)
+            current_hidden_states = self.act_fn(gate) * up
+            current_hidden_states = expert_layer["down_proj"](current_hidden_states)
+            current_hidden_states = (
+                current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+            )
+            final_hidden_states.index_add_(
+                0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
+            )
+
+        return final_hidden_states
+
+
+@SlimModelFactory.register
+class HYV3MoE(BaseLLMModel):
+    def __init__(
+        self,
+        model=None,
+        deploy_backend="vllm",
+    ):
+        super().__init__(
+            model=model,
+            deploy_backend=deploy_backend,
+        )
+        self.block_name = "model.layers"
+        # Store original forward methods for restoration
+        self._original_attn_forwards = {}
+        # Store KV cache observers:
+        # {attn_layer_name: {"key_observer": ..., "value_observer": ...}}
+        self.kv_cache_observers = {}
+
+    def from_pretrained(
+        self,
+        model_path,
+        torch_dtype="auto",
+        device_map="auto",
+        trust_remote_code=True,
+        low_cpu_mem_usage=True,
+        use_cache=False,
+        using_multi_nodes=False,
+    ):
+        # Force eager attention so the patched forward in
+        # _patch_attention_forward() observes the standard attention path,
+        # and load in bfloat16 regardless of the requested torch_dtype.
+        attn_implementation = "eager"
+        torch_dtype = torch.bfloat16
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype,
+            device_map=device_map,
+            trust_remote_code=trust_remote_code,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            use_cache=use_cache,
+        )
+
+        # Load tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=trust_remote_code
+        )
+
+    def replace_moe(self):
+        """Replace HYV3Experts instances with HYV3ExpertsWithLinear.
+
+        This must be called before init_ptq() so that find_layers() can discover
+        the per-expert nn.Linear modules and register them with the PTQ hook.
+        """
+        for name, module in self.model.named_modules():
+            if isinstance(module, HYV3Experts) and not isinstance(module, HYV3ExpertsWithLinear):
+                parent_layer, sub_name = find_parent_layer_and_sub_name(self.model, name)
+                moe_linear = HYV3ExpertsWithLinear(module)
+                del module
+                setattr(parent_layer, sub_name, moe_linear)
+
+    def init_ptq(self, slim_config):
+        self.replace_moe()
+        super().init_ptq(slim_config)
+
+    def get_observer_layers(self):
+        names = [
+            "self_attn.q_proj",
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+            "self_attn.o_proj",
+            "mlp.gate_proj",
+            "mlp.up_proj",
+            "mlp.down_proj",
+            "shared_mlp.gate_proj",
+            "shared_mlp.up_proj",
+            "shared_mlp.down_proj",
+        ]
+        expert_pattern = [
+            r"model\.layers\.\d+\.mlp\.experts\.\d+\.gate_proj",
+            r"model\.layers\.\d+\.mlp\.experts\.\d+\.up_proj",
+            r"model\.layers\.\d+\.mlp\.experts\.\d+\.down_proj",
+        ]
+
+        obs_layers = [nn.Linear]
+        observer_layers_dict = find_layers(self.model, layers=obs_layers)
+
+        compiled_patterns = [re.compile(pattern) for pattern in expert_pattern]
+
+        observer_layers_dict = {
+            k: v
+            for k, v in observer_layers_dict.items()
+            if k.startswith(self.block_name)
+            and (
+                any(name in k for name in names)
+                or any(pattern.search(k) for pattern in compiled_patterns)
+            )
+        }
+
+        if self.quant_config.custom_observe_layers_names != "default":
+            for custom_observe_name in self.quant_config.custom_observe_layers_names:
+                # Iterate over a copy of the keys; popping while iterating the
+                # dict itself would raise a RuntimeError
+                for default_name in list(observer_layers_dict.keys()):
+                    if custom_observe_name not in default_name:
+                        observer_layers_dict.pop(default_name)
+        return observer_layers_dict
+
+    def get_parent_dict(self, observer_layers_dict):
+        parent_mapping = {r"experts\.\d+": "experts"}
+        parent_dict = {}
+        for layer_name in observer_layers_dict.keys():
+            parent_name = layer_name
+            for k, v in parent_mapping.items():
+                parent_name = re.sub(k, v, layer_name)
+                if parent_name != layer_name:
+                    parent_dict[layer_name] = parent_name
+        return parent_dict
+
+    def get_kvcache_observer_layers_names(self, observe_names):
+        """Return an empty list since we use attention-level patching for KV
+        cache."""
+        # Return an empty list to disable the default k_proj/v_proj output
+        # observation; apply_kvcache_observers() handles the post-RoPE
+        # key/value states instead.
+        return []
+
+    def get_attention_layers(self):
+        """Get all attention layers in the model."""
+        attention_layers = {}
+        for name, module in self.model.named_modules():
+            if name.endswith(".self_attn") and hasattr(module, "forward"):
+                # Verify it has k_proj and v_proj attributes
+                if hasattr(module, "k_proj") and hasattr(module, "v_proj"):
+                    attention_layers[name] = module
+        return attention_layers
+
+    def apply_kvcache_observers(self, kv_cache_observer_class, quant_bits=8):
+        """
+        Apply KV cache observers to attention layers using monkey patching.
+        This observes key_states and value_states AFTER RoPE is applied.
+
+        Args:
+            kv_cache_observer_class: The observer class to use
+                (e.g., AbsmaxPertensorObserver)
+            quant_bits: Quantization bits for the observer
+        """
+        from ...compressor.quant.observers import AbsmaxPertensorObserver
+
+        if kv_cache_observer_class is None:
+            kv_cache_observer_class = AbsmaxPertensorObserver
+
+        attention_layers = self.get_attention_layers()
+
+        for attn_name, attn_module in attention_layers.items():
+            # Create observers for key and value states
+            key_observer = kv_cache_observer_class(
+                layer=attn_module.k_proj,
+                quant_bits=quant_bits,
+            )
+            value_observer = kv_cache_observer_class(
+                layer=attn_module.v_proj,
+                quant_bits=quant_bits,
+            )
+
+            # Store observers
+            self.kv_cache_observers[attn_name] = {
+                "key_observer": key_observer,
+                "value_observer": value_observer,
+            }
+
+            # Save original forward
+            self._original_attn_forwards[attn_name] = attn_module.forward
+
+            # Create patched forward
+            self._patch_attention_forward(attn_module, attn_name)
+
+    def _patch_attention_forward(self, attn_module, attn_name):
+        """
+        Patch the attention module's forward method to observe KV cache after RoPE.
+
+        Adapted to the new transformers ``HYV3Attention.forward`` signature, where
+        rotary embeddings are pre-computed and passed in as ``position_embeddings``
+        (a ``(cos, sin)`` tuple), ``q_norm``/``k_norm`` are applied unconditionally
+        on the pre-transpose view, and attention dispatch goes through
+        ``ALL_ATTENTION_FUNCTIONS``.
+ """ + key_observer = self.kv_cache_observers[attn_name]["key_observer"] + value_observer = self.kv_cache_observers[attn_name]["value_observer"] + + def patched_forward( + hidden_states, + position_embeddings, + attention_mask, + past_key_values=None, + cache_position=None, + **kwargs, + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, attn_module.head_dim) + + query_states = attn_module.q_proj(hidden_states).view(hidden_shape) + key_states = attn_module.k_proj(hidden_states).view(hidden_shape) + value_states = attn_module.v_proj(hidden_states).view(hidden_shape) + + query_states = attn_module.q_norm(query_states) + key_states = attn_module.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # === OBSERVE KV CACHE AFTER RoPE === + key_observer(key_states) + value_observer(value_states) + # === END OBSERVE === + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, attn_module.layer_idx, cache_kwargs + ) + + attention_interface = eager_attention_forward + if attn_module.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + attn_module.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + attn_module, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not attn_module.training else attn_module.attention_dropout, + scaling=attn_module.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_module.o_proj(attn_output) + return attn_output, attn_weights + + # Replace the forward method + attn_module.forward = patched_forward + + def remove_kvcache_observers(self): + """Remove patched forward methods and restore original ones.""" + for attn_name, original_forward in self._original_attn_forwards.items(): + # Find the attention module and restore its forward + parts = attn_name.split(".") + module = self.model + for part in parts: + module = getattr(module, part) + module.forward = original_forward + + self._original_attn_forwards.clear() + + def get_kvcache_scales(self): + """ + Get KV cache scales from observers. + Returns dict with format: {"layer_name.k_cache.scale": scale, + "layer_name.v_cache.scale": scale} + """ + kv_scales = {} + for attn_name, observers in self.kv_cache_observers.items(): + key_scale = observers["key_observer"].scales() + value_scale = observers["value_observer"].scales() + kv_scales[f"{attn_name}.k_cache.scale"] = key_scale + kv_scales[f"{attn_name}.v_cache.scale"] = value_scale + return kv_scales + + def get_save_func(self): + if self.deploy_backend in ["vllm", "huggingface"]: + return PTQSaveVllmHF + else: + raise NotImplementedError( + f"deploy_backend {self.deploy_backend} is not supported for saving." 
diff --git a/configs/hunyuan/fp8_static/hunyuanv3_a20b_fp8_static_c8.yaml b/configs/hunyuan/fp8_static/hunyuanv3_a20b_fp8_static_c8.yaml
new file mode 100644
index 00000000..34217472
--- /dev/null
+++ b/configs/hunyuan/fp8_static/hunyuanv3_a20b_fp8_static_c8.yaml
@@ -0,0 +1,37 @@
+# Global configuration of pipeline
+global:
+  save_path: ./output
+
+# Simplified Configuration for LLM compression
+model:
+  name: HYV3MoE
+  model_path: tencent/Hy3-preview
+  trust_remote_code: true
+  low_cpu_mem_usage: true
+  use_cache: false
+  torch_dtype: auto
+  device_map: auto
+
+# Compression configuration
+compression:
+  name: PTQ
+  quantization:
+    name: fp8_static
+    bits: 8
+    quant_method:
+      weight: "per-tensor"
+      activation: "per-tensor"
+      kv_cache: "per-tensor"
+    ignore_layers:  # Skip quantization for these layers
+      - "lm_head"
+      - "model.embed_tokens"
+    cpu_convert: true
+    save_name: "fp8"
+
+# Dataset for calibration
+dataset:
+  name: TextDataset
+  data_path: ./dataset/sharegpt_gpt4/sharegpt_gpt4_256.jsonl
+  max_seq_length: 2048
+  num_samples: 512
+  batch_size: 1