1 change: 1 addition & 0 deletions angelslim/compressor/quant/core/config.py
@@ -66,6 +66,7 @@ def __init__(self, config, global_config=None):
kv_cache_quant_method = quantization_args.quant_method.get("kv_cache", None)
self.cpu_convert = quantization_args.cpu_convert
self.save_name = quantization_args.save_name
self.quant_talker = getattr(quantization_args, "quant_talker", False)

if global_config:
self.max_seq_length = global_config.max_seq_length
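
The `getattr` fallback keeps configs written before this flag existed loading cleanly; a minimal illustration with stand-in args objects:

```python
from types import SimpleNamespace

# Stand-ins for parsed quantization_args, with and without the new flag.
old_args = SimpleNamespace(cpu_convert=False, save_name="model-fp8")
new_args = SimpleNamespace(cpu_convert=False, save_name="model-fp8", quant_talker=True)

assert getattr(old_args, "quant_talker", False) is False  # silent default
assert getattr(new_args, "quant_talker", False) is True   # honored when set
```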
2 changes: 2 additions & 0 deletions angelslim/data/dataloader.py
@@ -45,6 +45,7 @@ def create_data_loader(
model_name: str = None,
quantization_config: str = None,
is_sft_data: bool = False,
dtype=None,
) -> DataLoader:
"""
Create appropriate DataLoader based on data source
@@ -114,6 +115,7 @@ def create_data_loader(
data_source=data_source,
is_hf_dataset=not os.path.isfile(data_source),
use_audio_in_video=use_audio_in_video,
dtype=dtype,
)
elif data_type == "AudioDataset":
dataset = AudioDataset(
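
The new `dtype` keyword threads the model's parameter dtype down to the dataset. A hypothetical call — apart from the keyword names visible in this diff, the argument values are assumptions:

```python
import torch

# Hypothetical usage; the data_source/data_type values are assumptions,
# while is_sft_data and dtype are the keywords shown in this diff.
loader = create_data_loader(
    data_source="dataset/omni_fake_data/fake_data.json",
    data_type="OmniDataset",
    is_sft_data=False,
    dtype=torch.bfloat16,  # align calibration batches with the model dtype
)
```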
5 changes: 4 additions & 1 deletion angelslim/data/omni_dataset.py
@@ -35,10 +35,12 @@ def __init__(
data_source: Union[str, Dict] = None,
is_hf_dataset: bool = False,
use_audio_in_video: bool = False,
dtype=None,
):
super().__init__(processor, device, max_length)
self.is_hf_dataset = is_hf_dataset
self.use_audio_in_video = use_audio_in_video
self.dtype = dtype

self._load_file_based_dataset(data_source, num_samples)

@@ -112,10 +114,11 @@ def _process_and_append(self, messages: List[Dict]):
inputs = self.processor(
text=text,
images=images,
audios=audios,
audio=audios,
videos=videos,
padding=True,
return_tensors="pt",
use_audio_in_video=self.use_audio_in_video,
)
        inputs = inputs.to(self.device)
        if self.dtype is not None:  # guard: BatchFeature.to(None) raises
            inputs = inputs.to(self.dtype)
self.data.append(inputs)
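
Casting here matters because processors emit float32 features by default, while Omni checkpoints are commonly loaded in bfloat16; a CPU-runnable illustration:

```python
import torch

# Processor outputs typically default to float32...
pixel_values = torch.rand(1, 3, 224, 224)
assert pixel_values.dtype == torch.float32

# ...so casting the batch to the model dtype up front avoids
# "expected BFloat16 but found Float" errors inside the forward pass.
pixel_values = pixel_values.to(torch.bfloat16)
assert pixel_values.dtype == torch.bfloat16
```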
2 changes: 2 additions & 0 deletions angelslim/engine.py
@@ -160,6 +160,7 @@ def prepare_data(
model_name=None,
quantization_config=None,
is_sft_data=False,
dtype=None,
) -> Optional[Any]:
"""Prepare compression dataset"""
if custom_dataloader is not None:
@@ -187,6 +188,7 @@ def prepare_data(
model_name=model_name,
quantization_config=quantization_config,
is_sft_data=is_sft_data,
dtype=dtype,
)
self.max_seq_length = max_length

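
Callers can now thread the loaded model's dtype into data preparation. A hypothetical end-to-end sketch: the `Engine` name, `prepare_model` step, and `data_path` keyword are assumptions; `slim_model.model.dtype` and `prepare_data` come from this PR:

```python
# Hypothetical usage sketch; see tools/run.py in this PR for the real call site.
engine = Engine()
engine.prepare_model(model_name="Qwen3Omni", model_path="path/to/model")  # assumed API
dataloader = engine.prepare_data(
    data_path="dataset/omni_fake_data/fake_data.json",  # assumed keyword
    dtype=engine.slim_model.model.dtype,  # keep calibration batches in model precision
)
```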
124 changes: 107 additions & 17 deletions angelslim/models/omni/qwen3_omni.py
@@ -19,10 +19,17 @@
AutoTokenizer,
Qwen3OmniMoeForConditionalGeneration,
)
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
Qwen3OmniMoeTalkerTextExperts,
Qwen3OmniMoeTalkerTextTopKRouter,
Qwen3OmniMoeThinkerTextExperts,
Qwen3OmniMoeThinkerTextTopKRouter,
)

from ...compressor.quant.core import PTQVLMSaveVllmHF
from ...utils import find_layers, print_info
from ...utils import find_layers, find_parent_layer_and_sub_name, print_info
from ..base_model import BaseLLMModel
from ..llm.qwen import QwenMoeExpertsWithLinear
from ..model_factory import SlimModelFactory


@@ -38,7 +45,92 @@ def __init__(
deploy_backend=deploy_backend,
)
self.modal_type = "Omni"
self.block_name = ["thinker.model.layers", "talker.model.layers"]
self.thinker_block_name = "thinker.model.layers"
self.talker_block_name = "talker.model.layers"
self.observer_layer_classes = [
torch.nn.Linear,
Qwen3OmniMoeThinkerTextTopKRouter,
Qwen3OmniMoeTalkerTextTopKRouter,
]
self.observed_names = [
"k_proj",
"v_proj",
"q_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]

    def replace_moe(self):
        # Swap fused MoE expert modules for Linear-based equivalents so that
        # the nn.Linear-hooking observers can quantize the expert weights.
        expert_classes = (
            Qwen3OmniMoeThinkerTextExperts,
            Qwen3OmniMoeTalkerTextExperts,
        )
        for submodule in (self.model.thinker, self.model.talker):
            for name, module in submodule.named_modules():
                if isinstance(module, expert_classes) and not isinstance(
                    module, QwenMoeExpertsWithLinear
                ):
                    print_info(name)
                    parent_layer, sub_name = find_parent_layer_and_sub_name(
                        submodule, name
                    )
                    setattr(parent_layer, sub_name, QwenMoeExpertsWithLinear(module))
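
`QwenMoeExpertsWithLinear` itself is not shown in this diff; conceptually it unpacks fused expert weights into per-expert `torch.nn.Linear` modules so that `find_layers` and the observers can see them. An illustrative sketch only, under an assumed `(num_experts, out_features, in_features)` weight layout — the real class may differ:

```python
import torch

class ExpertsWithLinear(torch.nn.Module):
    """Illustrative stand-in for QwenMoeExpertsWithLinear: expose each
    expert's packed weight as an nn.Linear that observers can hook."""

    def __init__(self, packed_weight: torch.Tensor):
        # Assumed layout: (num_experts, out_features, in_features).
        super().__init__()
        self.experts = torch.nn.ModuleList()
        for w in packed_weight:
            out_features, in_features = w.shape
            linear = torch.nn.Linear(in_features, out_features, bias=False)
            linear.weight = torch.nn.Parameter(w.detach().clone())
            self.experts.append(linear)

    def forward(self, hidden_states: torch.Tensor, expert_idx: int) -> torch.Tensor:
        # Routing is omitted; the real module dispatches tokens per expert.
        return self.experts[expert_idx](hidden_states)
```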

def _patch_inputs_embeds_generate_device(self, module):
if module is None or getattr(module, "_angelslim_generate_device_patch", False):
return

original_generate = module.generate
skip_keys = {"past_key_values", "encoder_outputs"}

def move_to_target_device(value, target_device):
if isinstance(value, torch.Tensor):
if value.device.type == "meta" or value.device == target_device:
return value
return value.to(target_device)
if isinstance(value, tuple):
return tuple(move_to_target_device(item, target_device) for item in value)
if isinstance(value, list):
return [move_to_target_device(item, target_device) for item in value]
if isinstance(value, dict):
return {
key: item if key in skip_keys else move_to_target_device(item, target_device)
for key, item in value.items()
}
return value

def generate_on_module_device(*args, **kwargs):
inputs_embeds = kwargs.get("inputs_embeds")
if inputs_embeds is not None:
target_device = getattr(module, "device", inputs_embeds.device)
if target_device.type == "meta":
target_device = inputs_embeds.device

kwargs = {
key: value if key in skip_keys else move_to_target_device(value, target_device)
for key, value in kwargs.items()
}

return original_generate(*args, **kwargs)

module.generate = generate_on_module_device
module._angelslim_generate_device_patch = True

def _patch_omni_generate_devices(self):
talker = getattr(self.model, "talker", None)
self._patch_inputs_embeds_generate_device(talker)
self._patch_inputs_embeds_generate_device(getattr(talker, "code_predictor", None))
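
With `device_map` sharding, the thinker can hand `inputs_embeds` to a talker that lives on another GPU; the patch moves every kwarg except cached state onto the talker's device before `generate` runs. A CPU-runnable restatement of the skip-key semantics:

```python
import torch

skip_keys = {"past_key_values", "encoder_outputs"}
kwargs = {
    "inputs_embeds": torch.zeros(1, 4, 8),
    "past_key_values": torch.zeros(1),  # cached state is deliberately not moved
}
target_device = torch.device("cpu")
moved = {
    k: v if k in skip_keys else v.to(target_device)
    for k, v in kwargs.items()
}
assert moved["past_key_values"] is kwargs["past_key_values"]
```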

def init_ptq(self, slim_config):
super().init_ptq(slim_config)
self.replace_moe()

def from_pretrained(
self,
@@ -63,6 +155,7 @@ def from_pretrained(
device_map=device_map,
attn_implementation=attn_implementation,
)
self._patch_omni_generate_devices()

# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
@@ -74,24 +167,21 @@ def from_pretrained(
model_path, trust_remote_code=trust_remote_code
)

def get_observer_layers(self):
names = [
"k_proj",
"v_proj",
"q_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
]
def _get_quant_block_names(self):
block_names = [self.thinker_block_name]
if getattr(self.quant_config, "quant_talker", True):
block_names.append(self.talker_block_name)
return block_names

def get_observer_layers(self):
observer_layers_dict = {}
layers_dict = find_layers(self.model, layers=self.observer_layer_classes)
block_names = self._get_quant_block_names()

ignore_layers = self.skip_layer_names()
for name, module in layers_dict.items():
block_condition = any(name.startswith(block) for block in self.block_name)
if block_condition and name.split(".")[-1] in names:
block_condition = any(name.startswith(block) for block in block_names)
if block_condition and name.split(".")[-1] in self.observed_names:
observer_layers_dict[name] = module
else:
ignore_layers.append(name)
Expand All @@ -106,10 +196,11 @@ def get_observer_layers(self):

def get_kvcache_observer_layers_names(self, observe_names):
names = ["self_attn.k_proj", "self_attn.v_proj"]
block_names = self._get_quant_block_names()
return [
k
for k in observe_names
if any(k.startswith(block) for block in self.block_name)
if any(k.startswith(block) for block in block_names)
            and ".".join(k.split(".")[-2:]) in names
]

@@ -129,10 +220,9 @@ def model_forward(self, dataloader, **kwargs):
if dataloader is not None:
with torch.no_grad():
for batch in tqdm(dataloader, desc="calibrating...", total=len(dataloader)):
inputs = {k: v.to(device) for k, v in batch.items()}
try:
text_ids, audio = self.model.generate(
**inputs, use_audio_in_video=self.use_audio_in_video
**batch, use_audio_in_video=self.use_audio_in_video
)
calibrated_cnt += 1
except ValueError:
2 changes: 2 additions & 0 deletions angelslim/utils/config_parser.py
@@ -200,6 +200,7 @@ class QuantizationConfig:
quant_method: Algorithm used for quantization
modules_to_quantize: List of module types to quantize
ignore_layers: List of layer names to skip
quant_talker: Whether to quantize the Qwen3-Omni talker LLM module
"""

name: str = field(default="fp8_dynamic")
@@ -222,6 +223,7 @@ class QuantizationConfig:
ignore_layers: List[str] = field(default_factory=list)
quant_analyse: bool = field(default=False)
quant_vit: bool = field(default=False)
quant_talker: bool = field(default=False)
# DAQ-specific fields
base_model_path: Optional[str] = field(default=None)
base_is_fp8: bool = field(default=False)
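
A minimal construction showing the new field alongside existing ones (only fields visible in this diff are set; the remaining fields are assumed to keep their defaults):

```python
from angelslim.utils.config_parser import QuantizationConfig

config = QuantizationConfig(
    name="fp8_static",
    quant_vit=False,
    quant_talker=False,  # quantize only the thinker by default
)
assert config.quant_talker is False
```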
1 change: 1 addition & 0 deletions configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml
@@ -23,3 +23,4 @@ compression:
quant_method:
weight: "per-tensor"
activation: "per-tensor"
quant_talker: false # Whether to quantize Qwen3-Omni talker LLM module
1 change: 1 addition & 0 deletions configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml
@@ -23,6 +23,7 @@ compression:
quant_method:
weight: "per-tensor"
activation: "per-tensor"
quant_talker: false # Whether to quantize Qwen3-Omni talker LLM module

# Dataset for calibration
dataset:
Binary file added dataset/omni_fake_data/audios/0.wav
6 changes: 3 additions & 3 deletions dataset/omni_fake_data/fake_data.json
@@ -1,3 +1,3 @@
{"messages": [{"role": "user", "content": "What happens after the text disappears from the screen?"}], "video_path": "./videos/0.mp4"}
{"messages": [{"role": "user", "content": "How many food item is shown in the bar graph?"}], "image_path": "./images/0.png"}
{"messages": [{"role": "user", "content": "Why is the speech described as rich in frequency content?"}], "audio_path": "./audios/0.png"}
{"messages": [{"role": "user", "content": "描述这个视频的内容。"}], "video_path": "./videos/0.mp4"}
{"messages": [{"role": "user", "content": "请描述这张图片的内容。"}], "image_path": "./images/0.png"}
{"messages": [{"role": "user", "content": "请将这段语音转写成文字。"}], "audio_path": "./audios/0.wav"}
14 changes: 10 additions & 4 deletions docs/source/models/qwen3_omni/qwen3_omni_quant.md
@@ -5,7 +5,7 @@ The Qwen3-Omni model can be compressed with **FP8 (static, dynamic)** quantization

## FP8 Quantization (W8A8)

Qwen3-Omni FP8 quantization uses **per-tensor granularity** and supports both dynamic and static quantization of the thinker and talker LLM modules.
Qwen3-Omni FP8 quantization uses **per-tensor granularity** and supports both dynamic and static quantization of the thinker and talker LLM modules; whether the talker is quantized is controlled via `quantization.quant_talker`.

### Configuration Parameters

@@ -22,6 +22,7 @@ FP8 quantization config files are available at `configs/qwen3_omni/fp8_static` and `configs/qwen3_omni/fp8_dynamic`
- `quantization.name`: quantization algorithm; choose `fp8_static` (static quantization) or `fp8_dynamic` (dynamic quantization) as needed.
- `quantization.bits`: number of quantization bits; always `8` for FP8 quantization.
- `quantization.quant_method`: weight quantization granularity; fixed to `per-tensor` for FP8 quantization.
- `quantization.quant_talker`: whether to quantize the Qwen3-Omni talker LLM module; defaults to `false`, in which case only the thinker is quantized.

#### dataset Configuration
- `name`: dataset type; fixed to `OmniDataset`.
@@ -51,9 +52,14 @@ python3 tools/run.py -c configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml

## Model Deployment

The vLLM framework supports deploying Qwen3-Omni FP8 (per-tensor) quantized models; the official deployment method is recommended:
### vLLM
See <https://github.com/QwenLM/Qwen3-Omni?tab=readme-ov-file#vllm-usage>
The vLLM framework supports deploying the Qwen3-Omni FP8 (per-tensor) quantized model with only the thinker served; the following command can be used as a reference:

```shell
vllm serve $model_path --port 8021 -tp 4 --gpu-memory-utilization 0.8
```

To deploy the full Qwen3-Omni model, refer to the vllm-omni deployment guide: <https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/examples/online_serving/qwen3_omni/>

### transformers
See <https://github.com/QwenLM/Qwen3-Omni?tab=readme-ov-file#transformers-usage>
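
For a quick local check, a minimal loading sketch with transformers (the checkpoint path is a placeholder; see the linked README for the full pipeline including the processor and talker):

```python
from transformers import Qwen3OmniMoeForConditionalGeneration

# Placeholder path to the FP8 checkpoint produced by tools/run.py.
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
    "path/to/qwen3-omni-fp8",
    torch_dtype="auto",
    device_map="auto",
)
```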
2 changes: 2 additions & 0 deletions tools/run.py
@@ -113,6 +113,7 @@ def multi_nodes_run(config):
inference_settings=dataset_config.inference_settings,
use_audio_in_video=model_config.use_audio_in_video,
is_sft_data=dataset_config.is_sft_data,
dtype=slim_engine.slim_model.model.dtype,
)

# Step 6: Initialize compressor
@@ -364,6 +365,7 @@ def run(config):
model_name=model_config.name,
quantization_config=compress_config.quantization,
is_sft_data=dataset_config.is_sft_data,
dtype=slim_engine.slim_model.model.dtype,
)

# Step 5: Initialize compressor