7 changes: 7 additions & 0 deletions README.md
@@ -22,6 +22,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compression
</p>

## 📣Latest News
- [26/04/23] We now support FP8-Static quantization for **Hy3-preview** (MoE A20B).
- [26/03/25] We have released **DAQ**, a quantization algorithm that preserves the model's acquired knowledge when parameter updates during post-training are relatively small. [[Paper]](https://arxiv.org/abs/2603.22324) | [[Docs]](docs/source/features/quantization/daq.md)
- [26/02/09] We have released HY-1.8B-2Bit, a 2-bit on-device large language model. [[Huggingface]](https://huggingface.co/AngelSlim/HY-1.8B-2Bit)
- [26/01/13] We have released v0.3, which supports the training and deployment of Eagle3 for all-scale LLMs/VLMs/Audio models, as detailed in the [guidance documentation](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/eagle/index.html). We also released **Sherry**, a hardware-efficient 1.25-bit quantization algorithm. [[Paper]](https://arxiv.org/abs/2601.07892) | [[Code]](https://github.com/Tencent/AngelSlim/tree/sherry/Sherry)🔥🔥🔥
@@ -253,6 +254,12 @@ python3 tools/run.py -c configs/qwen3/fp8_static/qwen3-1_7b_fp8_static.yaml

This example produces quantized model weights by performing PTQ calibration on a model loaded from HuggingFace.

For **Hy3-preview** (MoE A20B) FP8-Static quantization:

```shell
python tools/run.py -c configs/hunyuan/fp8_static/hunyuanv3_a20b_fp8_static_c8.yaml
```

<details>
<summary>Code-based Start</summary>

7 changes: 7 additions & 0 deletions README_cn.md
@@ -22,6 +22,7 @@
</p>

## 📣Latest News
- [26/04/23] We now support FP8-Static quantization for the **Hy3-preview** (MoE A20B) model.
- [26/03/25] We released the DAQ quantization algorithm, which preserves the quantized model's capability when post-training parameter updates are relatively small. [[Paper]](https://arxiv.org/abs/2603.22324) | [[Docs]](docs/source/features/quantization/daq.md)
- [26/02/09] We released HY-1.8B-2Bit, a 2-bit on-device large language model. [[Huggingface]](https://huggingface.co/AngelSlim/HY-1.8B-2Bit)
- [26/01/13] We released v0.3, which supports speculative-decoding training and deployment across all modalities; docs: [Eagle3 for LLM/VLM/Audio](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/eagle/index.html). We also released **Sherry**, a new hardware-efficient 1.25-bit ternary quantization algorithm. [[Paper]](https://arxiv.org/abs/2601.07892) | [[Code]](https://github.com/Tencent/AngelSlim/tree/sherry/Sherry)🔥🔥🔥
@@ -252,6 +253,12 @@ bash scripts/speculative/train_eagle3_online.sh

This example loads a `HuggingFace` model, performs PTQ calibration, and produces the quantized model weights.

For **Hy3-preview** (MoE A20B) FP8-Static quantization:

```shell
python tools/run.py -c configs/hunyuan/fp8_static/hunyuanv3_a20b_fp8_static_c8.yaml
```

<details>
<summary>2. Code-based Start</summary>

22 changes: 19 additions & 3 deletions angelslim/compressor/quant/core/config.py
@@ -31,7 +31,10 @@
    "per-group": AbsMaxGroupWiseWeightObserver,
}

KVCACHE_OBSERVERS_CLASS = {"per-channel": AbsmaxPerchannelObserver}
KVCACHE_OBSERVERS_CLASS = {
    "per-channel": AbsmaxPerchannelObserver,
    "per-tensor": AbsmaxPertensorObserver,
}


class QuantConfig:
@@ -60,6 +63,7 @@ def __init__(self, config, global_config=None):
        self.quant_helpers = quantization_args.quant_helpers
        act_quant_method = quantization_args.quant_method.get("activation", None)
        weight_quant_method = quantization_args.quant_method["weight"]
        kv_cache_quant_method = quantization_args.quant_method.get("kv_cache", None)
        self.cpu_convert = quantization_args.cpu_convert
        self.save_name = quantization_args.save_name

@@ -77,7 +81,11 @@ def __init__(self, config, global_config=None):
                ACT_OBSERVERS_CLASS[act_quant_method] if "static" in is_dynamic else None
            )
            self.weight_observer = WEIGHT_OBSERVERS_CLASS[weight_quant_method]
            self.kv_cache_observer = None
            self.kv_cache_observer = (
                KVCACHE_OBSERVERS_CLASS[kv_cache_quant_method]
                if kv_cache_quant_method is not None
                else None
            )

            if "w4a8" in self.quant_algo:
                group_size = (
@@ -98,6 +106,8 @@ def __init__(self, config, global_config=None):

            if act_quant_method is not None:
                self.quant_algo_info["a"] = f"fp8_{act_quant_method}-{is_dynamic}"
            if kv_cache_quant_method is not None:
                self.quant_algo_info["c"] = f"fp8_{kv_cache_quant_method}"
            self.low_memory = config.quantization.low_memory
            self.quant_analyse = config.quantization.quant_analyse
            self.quant_vit = config.quantization.quant_vit
@@ -117,13 +127,19 @@ def __init__(self, config, global_config=None):
                ACT_OBSERVERS_CLASS[act_quant_method] if "static" in is_dynamic else None
            )
            self.weight_observer = WEIGHT_OBSERVERS_CLASS[weight_quant_method]
            self.kv_cache_observer = None
            self.kv_cache_observer = (
                KVCACHE_OBSERVERS_CLASS[kv_cache_quant_method]
                if kv_cache_quant_method is not None
                else None
            )
            self.quant_algo_info = {
                "w": f"int8_{weight_quant_method}",
                "ignore_layers": quantization_args.ignore_layers,
            }
            if act_quant_method is not None:
                self.quant_algo_info["a"] = f"int8_{act_quant_method}-{is_dynamic}"
            if kv_cache_quant_method is not None:
                self.quant_algo_info["c"] = f"int8_{kv_cache_quant_method}"
            self.low_memory = config.quantization.low_memory
            self.quant_analyse = config.quantization.quant_analyse
        elif "int4_awq" in self.quant_algo:
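To make the new dispatch concrete, here is a minimal, self-contained Python sketch of how a `kv_cache` entry in `quant_method` selects an observer class and extends the algorithm info, mirroring the logic added above. The stub observer classes and the example config values are placeholders for illustration only, not the library's real implementations.

```python
# Minimal sketch of the KV-cache observer dispatch added in config.py.
# The observer classes below are stand-in stubs; AngelSlim's real classes are
# AbsmaxPerchannelObserver and AbsmaxPertensorObserver.


class AbsmaxPerchannelObserver:  # stub for illustration
    pass


class AbsmaxPertensorObserver:  # stub for illustration
    pass


KVCACHE_OBSERVERS_CLASS = {
    "per-channel": AbsmaxPerchannelObserver,
    "per-tensor": AbsmaxPertensorObserver,
}


def resolve_kv_cache(quant_method: dict, algo: str = "fp8"):
    """Return (observer class or None, quant_algo_info["c"] entry or None)."""
    kv_cache_quant_method = quant_method.get("kv_cache", None)
    observer = (
        KVCACHE_OBSERVERS_CLASS[kv_cache_quant_method]
        if kv_cache_quant_method is not None
        else None
    )
    info = f"{algo}_{kv_cache_quant_method}" if kv_cache_quant_method else None
    return observer, info


# Example quant_method dict (values are illustrative, not a real config file).
observer, info = resolve_kv_cache(
    {"weight": "per-tensor", "activation": "per-tensor", "kv_cache": "per-tensor"}
)
assert observer is AbsmaxPertensorObserver and info == "fp8_per-tensor"
```

With this in place, `quant_algo_info["c"]` carries the KV-cache scheme string (e.g. `fp8_per-tensor`), which the save path below reads when emitting a `kv_cache_scheme` entry.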
22 changes: 19 additions & 3 deletions angelslim/compressor/quant/core/hook.py
@@ -51,14 +51,24 @@ def apply_hook(self):
                sub_layer,
                act_observer,
                weight_observer,
                kv_cache_observer if name in self.kv_names else None,
                # kv_cache_observer is now handled by monkey patching at attention level
                # so we pass None here
                None,
                self.quant_model.quant_algo_dict,
                **extra_kwargs
            )
            forward_hook_handle = sub_layer.register_forward_hook(self._forward_hook)
            self.observer_dict[sub_layer] = observer
            self._forward_hook_list.append(forward_hook_handle)

        # Apply KV cache observers using monkey patching (for attention-level observation)
        if kv_cache_observer is not None and hasattr(self.quant_model, "apply_kvcache_observers"):
            quant_bits = self.quant_model.quant_algo_dict.get("c_quant_bits", 8)
            self.quant_model.apply_kvcache_observers(
                kv_cache_observer_class=kv_cache_observer,
                quant_bits=quant_bits,
            )

    def apply_smooth_hook(self, smooth_mapping_layers, smooth_observer):
        for smooth_layer, _ in smooth_mapping_layers.values():
            observer = PTQObserver(
@@ -86,6 +96,9 @@ def remove_hook(self):
        for hook in self._forward_hook_list:
            hook.remove()
        self._forward_hook_list = []
        # Remove KV cache observer patches if available
        if hasattr(self.quant_model, "remove_kvcache_observers"):
            self.quant_model.remove_kvcache_observers()

    def post_process(self):
        maxval = get_fp_maxval(bits=8)
@@ -109,5 +122,8 @@ def post_process(self):
                self.quant_model.act_scales_dict[name] / maxval.type(act_dtype)
            )
        if self.quant_model.quant_algo_dict["c_quant_algo"] == "fp8":
            for k, v in self.quant_model.kv_cache_scales_dict.items():
                self.quant_model.kv_cache_scales_dict[k] = v / maxval.type(v.dtype)
            # Process KV cache scales from attention-level observers
            if hasattr(self.quant_model, "get_kvcache_scales"):
                kv_scales = self.quant_model.get_kvcache_scales()
                for k, v in kv_scales.items():
                    self.quant_model.kv_cache_scales_dict[k] = v / maxval.type(v.dtype)
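The model-level hooks `apply_kvcache_observers`, `remove_kvcache_observers`, and `get_kvcache_scales` that this file now calls are defined on the model wrapper and are not part of this diff. As a rough, hypothetical sketch of the attention-level monkey patching the comments describe (wrapping an attention module's forward to record K/V maxima), one might write something like the following; the class, the `k_proj`/`v_proj` attribute names, and the method bodies are assumptions for illustration, not AngelSlim's actual code.

```python
import torch


class KVCacheObserverPatch:
    """Hypothetical sketch of attention-level K/V observation via monkey patching.

    Not AngelSlim's implementation; it only illustrates the idea referenced by
    the "handled by monkey patching at attention level" comment above. It assumes
    the attention module exposes k_proj/v_proj projections.
    """

    def __init__(self, attn_module, name):
        self.name = name
        self.attn = attn_module
        self.kv_absmax = None
        self._orig_forward = attn_module.forward
        attn_module.forward = self._patched_forward  # monkey patch

    def _patched_forward(self, hidden_states, *args, **kwargs):
        # Record the running abs-max of the K/V projections, then defer to the
        # original forward so model behavior is unchanged.
        with torch.no_grad():
            k = self.attn.k_proj(hidden_states)
            v = self.attn.v_proj(hidden_states)
            absmax = torch.maximum(k.abs().max(), v.abs().max())
            self.kv_absmax = (
                absmax if self.kv_absmax is None else torch.maximum(self.kv_absmax, absmax)
            )
        return self._orig_forward(hidden_states, *args, **kwargs)

    def remove(self):
        # Restore the original forward, mirroring what remove_kvcache_observers() implies.
        self.attn.forward = self._orig_forward
```

In the real code path, `apply_kvcache_observers` would presumably install such a patch per attention layer, and `get_kvcache_scales` would collect the recorded maxima into `kv_cache_scales_dict` for `post_process` to normalize.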
34 changes: 33 additions & 1 deletion angelslim/compressor/quant/core/save.py
@@ -249,6 +249,17 @@ def save(self, save_path):
            raise ValueError(f"{self.quant_model.quant_config.quant_algo} not supported")

        quantization_config = {"quant_method": save_name, ignore_field: ignored_layers}
        # Set kv_cache_scheme if kv_cache quantization is enabled
        c_quant_algo = self.quant_model.quant_config.quant_algo_info.get("c", None)
        if c_quant_algo is not None:
            kv_cache_scheme = {
                "num_bits": 8,
                "strategy": re.search(r"per-([a-zA-Z]+)", c_quant_algo).group(1),
                "type": "float",
            }
        else:
            kv_cache_scheme = None

        if save_name == "compressed-tensors":
            quantization_config.update(
                {
@@ -260,13 +271,15 @@ def save(self, save_path):
                            "targets": ["Linear"],
                        }
                    },
                    "kv_cache_scheme": None,
                    "kv_cache_scheme": kv_cache_scheme,
                    "format": quant_format,
                    "quantization_status": "compressed",
                }
            )
        else:
            quantization_config["activation_scheme"] = "dynamic" if is_dynamic else "static"
            if kv_cache_scheme is not None:
                quantization_config["kv_cache_scheme"] = "static"

        if (
            hasattr(self.quant_model.quant_config, "transform_config")
@@ -287,6 +300,25 @@ def save(self, save_path):
            json.dump(trtllm_config, f, indent=4)

        self.quant_model.tokenizer.save_pretrained(save_path)
        # Save KV cache scales if available
        if (
            hasattr(self.quant_model, "kv_cache_scales_dict")
            and self.quant_model.kv_cache_scales_dict
        ):
            kv_scales_path = os.path.join(save_path, "kv_cache_scales.safetensors")
            kv_scales_dict = {}
            kv_scale_map = {}
            for name, scale in self.quant_model.kv_cache_scales_dict.items():
                kv_scales_dict[name] = scale
                kv_scale_map[name] = "kv_cache_scales.safetensors"
            safe_save(kv_scales_dict, kv_scales_path)
            print_info("Save KV cache scales to: {}".format(kv_scales_path))
            new_model_index_file = os.path.join(save_path, "model.safetensors.index.json")
            with open(new_model_index_file, "r") as f:
                new_model_index = json.load(f)
            new_model_index["weight_map"].update(kv_scale_map)
            with open(os.path.join(save_path, "model.safetensors.index.json"), "w") as f:
                json.dump(new_model_index, f, indent=2)


class PTQOnlyScaleSave(PTQSaveBase):
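As a quick check of the scheme derivation above: for the `"c"` entries produced in config.py (e.g. `fp8_per-tensor`), the regex extracts the granularity that follows `per-`. A small standalone illustration (the helper name is ours, for the example only):

```python
import re


def kv_cache_scheme_from_algo(c_quant_algo):
    """Mirror of the kv_cache_scheme construction in save.py (illustration only)."""
    if c_quant_algo is None:
        return None
    return {
        "num_bits": 8,
        "strategy": re.search(r"per-([a-zA-Z]+)", c_quant_algo).group(1),
        "type": "float",
    }


# "fp8_per-tensor" -> {"num_bits": 8, "strategy": "tensor", "type": "float"}
assert kv_cache_scheme_from_algo("fp8_per-tensor") == {
    "num_bits": 8,
    "strategy": "tensor",
    "type": "float",
}
assert kv_cache_scheme_from_algo(None) is None
```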
51 changes: 51 additions & 0 deletions angelslim/compressor/quant/ptq.py
@@ -205,7 +205,58 @@ def save(self, save_path: str):
        save_func = self.quant_model.get_save_func()(self.quant_model)
        save_func.save(save_path)

    def get_meta_weights_info(self, model):
        """Collect info about all parameters still on the meta device."""
        meta_params = []

        for name, param in model.named_parameters():
            if param.device.type == "meta":
                meta_params.append(
                    {
                        "name": name,
                    }
                )
        return meta_params

    def set_meta_weights_info(self, model):
        """Replace all meta-device parameters with their original weights."""
        orign_w_dict = {}
        for name, param in model.named_parameters():
            if param.device.type == "meta":
                with open(
                    os.path.join(self.absolute_model_path, "model.safetensors.index.json"),
                    "r",
                ) as f:
                    model_index = json.load(f)
                orign_w_file = os.path.join(
                    self.absolute_model_path,
                    model_index["weight_map"][name],
                )
                if orign_w_file in orign_w_dict.keys():
                    orign_w = orign_w_dict[orign_w_file]
                else:
                    orign_w = load_file(orign_w_file, device="cpu")
                    orign_w_dict[orign_w_file] = orign_w

                empty_tensor = torch.empty(param.data.shape, dtype=param.data.dtype, device="cpu")
                new_param = torch.nn.Parameter(empty_tensor)
                new_param.data = orign_w[name]
                parts = name.split(".")
                current_module = model

                # Navigate to the module that contains the parameter
                for part in parts[:-1]:
                    current_module = getattr(current_module, part)

                # Set the new parameter on that module
                setattr(current_module, parts[-1], new_param)

        del orign_w_dict

    def _convert(self):
        self.set_meta_weights_info(self.quant_model.model)
        print_info(f"Meta weight:{self.get_meta_weights_info(self.quant_model.model)}")

        # 1. get act, weight and kv-cache scale
        for name, sub_layer in self.ptq_hook.quant_layers_dict.items():
            if (
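The getattr/setattr walk in `set_meta_weights_info` is the standard way to replace a parameter addressed by a dotted name from `named_parameters()`. A tiny standalone illustration on a toy module (the toy model and names are invented for this example):

```python
import torch
import torch.nn as nn

# Toy module standing in for a real transformer; "0.weight" is the dotted name
# that named_parameters() reports for the Linear layer's weight.
toy = nn.Sequential(nn.Linear(4, 4))
name = "0.weight"

parts = name.split(".")
module = toy
for part in parts[:-1]:
    module = getattr(module, part)  # walk down to the module that owns the parameter

# Replace the parameter in place, as set_meta_weights_info does with the loaded weight.
setattr(module, parts[-1], nn.Parameter(torch.zeros(4, 4)))
assert torch.equal(dict(toy.named_parameters())["0.weight"], torch.zeros(4, 4))
```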
1 change: 0 additions & 1 deletion angelslim/engine.py
@@ -121,7 +121,6 @@ def prepare_model(
                low_cpu_mem_usage=low_cpu_mem_usage,
                use_cache=use_cache,
                using_multi_nodes=using_multi_nodes,
                attn_implementation=attn_implementation,
            )
            self.model_path = model_path
        elif self.series in ["Omni"]:
1 change: 1 addition & 0 deletions angelslim/models/llm/__init__.py
@@ -16,6 +16,7 @@
from .glm import GLM # noqa: F401
from .hunyuan_dense import HunyuanDense # noqa: F401
from .hunyuan_moe import HunyuanMoE # noqa: F401
from .hunyuan_v3_moe import HYV3MoE # noqa: F401
from .kimi_k2 import KimiK2 # noqa: F401
from .llama import Llama # noqa: F401
from .qwen import Qwen # noqa: F401