diff --git a/angelslim/compressor/quant/core/config.py b/angelslim/compressor/quant/core/config.py
index 1ae16296..08cf00f9 100644
--- a/angelslim/compressor/quant/core/config.py
+++ b/angelslim/compressor/quant/core/config.py
@@ -66,6 +66,7 @@ def __init__(self, config, global_config=None):
         kv_cache_quant_method = quantization_args.quant_method.get("kv_cache", None)
         self.cpu_convert = quantization_args.cpu_convert
         self.save_name = quantization_args.save_name
+        self.quant_talker = getattr(quantization_args, "quant_talker", False)
 
         if global_config:
             self.max_seq_length = global_config.max_seq_length
diff --git a/angelslim/data/dataloader.py b/angelslim/data/dataloader.py
index 72488ce1..2bd4cf93 100644
--- a/angelslim/data/dataloader.py
+++ b/angelslim/data/dataloader.py
@@ -45,6 +45,7 @@ def create_data_loader(
     model_name: str = None,
     quantization_config: str = None,
     is_sft_data: bool = False,
+    dtype=None,
 ) -> DataLoader:
     """
     Create appropriate DataLoader based on data source
@@ -114,6 +115,7 @@ def create_data_loader(
             data_source=data_source,
             is_hf_dataset=not os.path.isfile(data_source),
             use_audio_in_video=use_audio_in_video,
+            dtype=dtype,
         )
     elif data_type == "AudioDataset":
         dataset = AudioDataset(
diff --git a/angelslim/data/omni_dataset.py b/angelslim/data/omni_dataset.py
index 5ab67f64..8d2c2019 100644
--- a/angelslim/data/omni_dataset.py
+++ b/angelslim/data/omni_dataset.py
@@ -35,10 +35,12 @@ def __init__(
         data_source: Union[str, Dict] = None,
         is_hf_dataset: bool = False,
         use_audio_in_video: bool = False,
+        dtype=None,
     ):
         super().__init__(processor, device, max_length)
         self.is_hf_dataset = is_hf_dataset
         self.use_audio_in_video = use_audio_in_video
+        self.dtype = dtype
 
         self._load_file_based_dataset(data_source, num_samples)
@@ -112,10 +114,11 @@ def _process_and_append(self, messages: List[Dict]):
         inputs = self.processor(
             text=text,
             images=images,
-            audios=audios,
+            audio=audios,
             videos=videos,
             padding=True,
             return_tensors="pt",
             use_audio_in_video=self.use_audio_in_video,
         )
+        inputs = inputs.to(self.device).to(self.dtype)
         self.data.append(inputs)
diff --git a/angelslim/engine.py b/angelslim/engine.py
index 8c5344c2..23575f8c 100644
--- a/angelslim/engine.py
+++ b/angelslim/engine.py
@@ -160,6 +160,7 @@ def prepare_data(
         model_name=None,
         quantization_config=None,
         is_sft_data=False,
+        dtype=None,
     ) -> Optional[Any]:
         """Prepare compression dataset"""
         if custom_dataloader is not None:
@@ -187,6 +188,7 @@ def prepare_data(
             model_name=model_name,
             quantization_config=quantization_config,
             is_sft_data=is_sft_data,
+            dtype=dtype,
         )
 
         self.max_seq_length = max_length
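Note on the dtype plumbing above: `OmniDataset` now casts every calibration batch to the model's dtype at load time (`inputs.to(self.device).to(self.dtype)`), so the calibration loop no longer needs a per-step device copy. A minimal sketch (synthetic tensors, not real processor output) of why this is safe for mixed tensor types: transformers' `BatchFeature.to` casts only floating-point tensors when given a dtype, leaving integer `input_ids` intact.

```python
import torch
from transformers import BatchFeature

# The processor emits fp32 features, while the model may run in bf16.
features = BatchFeature({
    "input_ids": torch.tensor([[1, 2, 3]]),       # integer token ids
    "pixel_values": torch.randn(1, 3, 28, 28),    # fp32 visual features
})
features = features.to("cpu").to(torch.bfloat16)  # mirrors .to(device).to(dtype)

assert features["input_ids"].dtype == torch.long         # ids left untouched
assert features["pixel_values"].dtype == torch.bfloat16  # floats follow model dtype
```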
"talker.model.layers"] + self.thinker_block_name = "thinker.model.layers" + self.talker_block_name = "talker.model.layers" + self.observer_layer_classes = [ + torch.nn.Linear, + Qwen3OmniMoeThinkerTextTopKRouter, + Qwen3OmniMoeTalkerTextTopKRouter, + ] + self.observed_names = [ + "k_proj", + "v_proj", + "q_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + + def replace_moe(self): + for name, module in self.model.thinker.named_modules(): + if isinstance(module, Qwen3OmniMoeThinkerTextExperts) and not isinstance( + module, QwenMoeExpertsWithLinear + ): + print(name) + parent_layer, sub_name = find_parent_layer_and_sub_name(self.model.thinker, name) + moe_linear = QwenMoeExpertsWithLinear(module) + del module + setattr(parent_layer, sub_name, moe_linear) + + for name, module in self.model.talker.named_modules(): + if isinstance(module, Qwen3OmniMoeTalkerTextExperts) and not isinstance( + module, QwenMoeExpertsWithLinear + ): + print(name) + parent_layer, sub_name = find_parent_layer_and_sub_name(self.model.talker, name) + moe_linear = QwenMoeExpertsWithLinear(module) + del module + setattr(parent_layer, sub_name, moe_linear) + + def _patch_inputs_embeds_generate_device(self, module): + if module is None or getattr(module, "_angelslim_generate_device_patch", False): + return + + original_generate = module.generate + skip_keys = {"past_key_values", "encoder_outputs"} + + def move_to_target_device(value, target_device): + if isinstance(value, torch.Tensor): + if value.device.type == "meta" or value.device == target_device: + return value + return value.to(target_device) + if isinstance(value, tuple): + return tuple(move_to_target_device(item, target_device) for item in value) + if isinstance(value, list): + return [move_to_target_device(item, target_device) for item in value] + if isinstance(value, dict): + return { + key: item if key in skip_keys else move_to_target_device(item, target_device) + for key, item in value.items() + } + return value + + def generate_on_module_device(*args, **kwargs): + inputs_embeds = kwargs.get("inputs_embeds") + if inputs_embeds is not None: + target_device = getattr(module, "device", inputs_embeds.device) + if target_device.type == "meta": + target_device = inputs_embeds.device + + kwargs = { + key: value if key in skip_keys else move_to_target_device(value, target_device) + for key, value in kwargs.items() + } + + return original_generate(*args, **kwargs) + + module.generate = generate_on_module_device + module._angelslim_generate_device_patch = True + + def _patch_omni_generate_devices(self): + talker = getattr(self.model, "talker", None) + self._patch_inputs_embeds_generate_device(talker) + self._patch_inputs_embeds_generate_device(getattr(talker, "code_predictor", None)) + + def init_ptq(self, slim_config): + super().init_ptq(slim_config) + self.replace_moe() def from_pretrained( self, @@ -63,6 +155,7 @@ def from_pretrained( device_map=device_map, attn_implementation=attn_implementation, ) + self._patch_omni_generate_devices() # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained( @@ -74,24 +167,21 @@ def from_pretrained( model_path, trust_remote_code=trust_remote_code ) - def get_observer_layers(self): - names = [ - "k_proj", - "v_proj", - "q_proj", - "o_proj", - "up_proj", - "gate_proj", - "down_proj", - ] + def _get_quant_block_names(self): + block_names = [self.thinker_block_name] + if getattr(self.quant_config, "quant_talker", True): + block_names.append(self.talker_block_name) + return block_names + def 
diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py
index 31b5acdc..67d881c4 100644
--- a/angelslim/utils/config_parser.py
+++ b/angelslim/utils/config_parser.py
@@ -200,6 +200,7 @@ class QuantizationConfig:
         quant_method: Algorithm used for quantization
         modules_to_quantize: List of module types to quantize
         ignore_layers: List of layer names to skip
+        quant_talker: Whether to quantize Qwen3-Omni talker LLM module
     """
 
     name: str = field(default="fp8_dynamic")
@@ -222,6 +223,7 @@ class QuantizationConfig:
     ignore_layers: List[str] = field(default_factory=list)
     quant_analyse: bool = field(default=False)
     quant_vit: bool = field(default=False)
+    quant_talker: bool = field(default=False)
     # DAQ-specific fields
     base_model_path: Optional[str] = field(default=None)
     base_is_fp8: bool = field(default=False)
diff --git a/configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml b/configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml
index 788e66a2..7d04fa32 100644
--- a/configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml
+++ b/configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml
@@ -23,3 +23,4 @@
       quant_method:
         weight: "per-tensor"
         activation: "per-tensor"
+      quant_talker: false # Whether to quantize Qwen3-Omni talker LLM module
diff --git a/configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml b/configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml
index 983766de..e6e4473b 100644
--- a/configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml
+++ b/configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml
@@ -23,6 +23,7 @@
       quant_method:
         weight: "per-tensor"
         activation: "per-tensor"
+      quant_talker: false # Whether to quantize Qwen3-Omni talker LLM module
 
   # Dataset for calibration
   dataset:
diff --git a/dataset/omni_fake_data/audios/0.wav b/dataset/omni_fake_data/audios/0.wav
new file mode 100644
index 00000000..8f0660a5
Binary files /dev/null and b/dataset/omni_fake_data/audios/0.wav differ
diff --git a/dataset/omni_fake_data/fake_data.json b/dataset/omni_fake_data/fake_data.json
index 885bced1..f7247835 100755
--- a/dataset/omni_fake_data/fake_data.json
+++ b/dataset/omni_fake_data/fake_data.json
@@ -1,3 +1,3 @@
-{"messages": [{"role": "user", "content": "What happens after the text disappears from the screen?"}], "video_path": "./videos/0.mp4"}
-{"messages": [{"role": "user", "content": "How many food item is shown in the bar graph?"}], "image_path": "./images/0.png"}
-{"messages": [{"role": "user", "content": "Why is the speech described as rich in frequency content?"}], "audio_path": "./audios/0.png"}
\ No newline at end of file
+{"messages": [{"role": "user", "content": "描述这个视频的内容。"}], "video_path": "./videos/0.mp4"}
+{"messages": [{"role": "user", "content": "请描述这张图片的内容。"}], "image_path": "./images/0.png"}
+{"messages": [{"role": "user", "content": "请将这段语音转写成文字。"}], "audio_path": "./audios/0.wav"}
\ No newline at end of file
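The `quant_talker` flag introduced here defaults to `false` everywhere it is declared, so only the thinker blocks are quantized unless a config opts in. A stand-in sketch of the gating logic, mirroring `_get_quant_block_names` above with a fake config class (not the real `QuantizationConfig`):

```python
# Stand-in config object; the real flag lives on QuantizationConfig.
class FakeQuantConfig:
    def __init__(self, quant_talker: bool = False):
        self.quant_talker = quant_talker

def get_quant_block_names(cfg) -> list:
    # Thinker layers are always quantized; talker layers only on opt-in.
    block_names = ["thinker.model.layers"]
    if cfg.quant_talker:
        block_names.append("talker.model.layers")
    return block_names

assert get_quant_block_names(FakeQuantConfig()) == ["thinker.model.layers"]
assert get_quant_block_names(FakeQuantConfig(quant_talker=True)) == [
    "thinker.model.layers",
    "talker.model.layers",
]
```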
diff --git a/docs/source/models/qwen3_omni/qwen3_omni_quant.md b/docs/source/models/qwen3_omni/qwen3_omni_quant.md
index e04ac42f..e2983e26 100644
--- a/docs/source/models/qwen3_omni/qwen3_omni_quant.md
+++ b/docs/source/models/qwen3_omni/qwen3_omni_quant.md
@@ -5,7 +5,7 @@ Qwen3-Omni模型可采用**FP8(static、dynamic)** 方式进行模型压缩
 
 ## FP8 量化(W8A8)
 
-Qwen3-Omni的FP8量化采用**per-tensor粒度**,支持thinker与talker的llm模块的动态量化(dynamic)和静态量化(static)两种模式。
+Qwen3-Omni的FP8量化采用**per-tensor粒度**,支持thinker与talker的llm模块的动态量化(dynamic)和静态量化(static)两种模式;可通过`quantization.quant_talker`控制是否量化talker。
 
 ### 配置参数说明
@@ -22,6 +22,7 @@ FP8量化的配置文件可参考路径:`configs/qwen3_omni/fp8_static` 和 `c
 - `quantization.name`:量化算法类型,根据需求选择`fp8_static`(静态量化)或`fp8_dynamic`(动态量化)。
 - `quantization.bits`:量化比特数,FP8量化固定填写`8`。
 - `quantization.quant_method`:权重量化粒度,FP8量化固定为`per-tensor`。
+- `quantization.quant_talker`:是否量化Qwen3-Omni的talker LLM模块,默认`false`;设为`false`时仅量化thinker。
 
 #### dataset配置
 - `name`:数据集类型,固定选择`OmniDataset`。
@@ -51,9 +52,14 @@ python3 tools/run.py -c configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml
 
 ## 模型部署
 
-vLLM框架支持Qwen3-Omni的FP8(per-tensor)量化模型部署,建议使用官方部署方式:
 ### vllm
-参考[https://github.com/QwenLM/Qwen3-Omni?tab=readme-ov-file#vllm-usage](URL)
+vLLM框架支持Qwen3-Omni的FP8(per-tensor)量化模型部署,只部署Thinker,可参考以下命令:
+
+```shell
+vllm serve $model_path --port 8021 -tp 4 --gpu-memory-utilization 0.8
+```
+
+要部署完整的Qwen3-Omni,参考vllm-omni的部署方式:[https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/examples/online_serving/qwen3_omni/](https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/examples/online_serving/qwen3_omni/)
 
 ### transformers
-参考[https://github.com/QwenLM/Qwen3-Omni?tab=readme-ov-file#transformers-usage](URL)
\ No newline at end of file
+参考:[https://github.com/QwenLM/Qwen3-Omni?tab=readme-ov-file#transformers-usage](https://github.com/QwenLM/Qwen3-Omni?tab=readme-ov-file#transformers-usage)
\ No newline at end of file
diff --git a/tools/run.py b/tools/run.py
index 88d84451..f0919de3 100644
--- a/tools/run.py
+++ b/tools/run.py
@@ -113,6 +113,7 @@ def multi_nodes_run(config):
         inference_settings=dataset_config.inference_settings,
         use_audio_in_video=model_config.use_audio_in_video,
         is_sft_data=dataset_config.is_sft_data,
+        dtype=slim_engine.slim_model.model.dtype,
     )
 
     # Step 6: Initialize compressor
@@ -364,6 +365,7 @@ def run(config):
         model_name=model_config.name,
         quantization_config=compress_config.quantization,
         is_sft_data=dataset_config.is_sft_data,
+        dtype=slim_engine.slim_model.model.dtype,
     )
 
     # Step 5: Initialize compressor
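With batches pre-moved and pre-cast by `OmniDataset`, the calibration loop in `model_forward` reduces to unpacking each batch straight into `generate`. A condensed sketch of that loop, simplified from the `model_forward` hunk above (the `ValueError` skip mirrors the original error handling):

```python
import torch
from tqdm import tqdm

@torch.no_grad()
def calibrate(model, dataloader, use_audio_in_video: bool = False) -> int:
    """Run calibration forward passes over pre-cast, pre-moved batches."""
    calibrated_cnt = 0
    for batch in tqdm(dataloader, desc="calibrating...", total=len(dataloader)):
        try:
            # No per-step device copy: OmniDataset already did .to(device).to(dtype).
            model.generate(**batch, use_audio_in_video=use_audio_in_video)
            calibrated_cnt += 1
        except ValueError:
            continue  # skip samples the processor could not align
    return calibrated_cnt
```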