Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions src/pruna/algorithms/moe_kernel_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,32 @@ def get_hyperparameters(self) -> list:
default_value=8,
meta={"desc": "Maximum (log) block size for tiling through intermediate dimension."},
),
OrdinalHyperparameter(
"block_quant_shape_n",
sequence=[32, 64, 128, 256, 512, 1024, 2048, 4096, None],
default_value=None,
meta={
"desc": (
"Block size for quantization along the N dimension when weight_dtype is "
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the N dimension? What is the Block size in this case? This is not a paramter automatically tuned? We can probably clarify the doc here :)

"fp8_w8a8. Must be set together with block_quant_shape_k: either both "
"None or both an integer (mixing None with a value is invalid). "
"Default None: no block-wise quant tiling."
)
},
),
OrdinalHyperparameter(
"block_quant_shape_k",
sequence=[32, 64, 128, 256, 512, 1024, 2048, 4096, None],
default_value=None,
meta={
"desc": (
"Block size for quantization along the K (intermediate) dimension when "
"weight_dtype is fp8_w8a8. Must be set together with block_quant_shape_n: "
"either both None or both an integer (mixing None with a value is invalid). "
"Default None: no block-wise quant tiling."
)
},
),
]

def model_check_fn(self, model: Any) -> bool:
Expand Down Expand Up @@ -178,6 +204,19 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
model_config = getattr(model, "config", None)
if model_config is None:
raise ValueError(f"Model {model.__class__.__name__} has no config.")
# Multimodal MoE (e.g. Qwen3_5MoeForConditionalGeneration): MoE parameters live on text_config.
if getattr(model_config, "num_experts", None) is None:
text_cfg = getattr(model_config, "text_config", None)
if text_cfg is not None and getattr(text_cfg, "num_experts", None) is not None:
model_config = text_cfg
Comment thread
llcnt marked this conversation as resolved.
else:
raise ValueError(
f"Cannot resolve MoE layout for {model.__class__.__name__}: "
"`config.num_experts` is missing and no `config.text_config.num_experts` was "
"found. This tuner expects a supported MoE model (e.g. Mixtral, Qwen-MoE, "
"or multimodal MoE with experts in `text_config`). For custom or future "
"architectures, extend config resolution in MoeKernelTuner._apply."
)

tensor_parallel_size = int(smash_config["tensor_parallel_size"])
if model.__class__.__name__ == "HunyuanImage3ForCausalMM":
Expand All @@ -194,6 +233,17 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16
use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8"
use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16"
block_quant_shape_n = smash_config["block_quant_shape_n"]
block_quant_shape_k = smash_config["block_quant_shape_k"]
if (block_quant_shape_n is None) ^ (block_quant_shape_k is None):
raise ValueError(
"block_quant_shape_n and block_quant_shape_k must both be None (default, "
"per-expert FP8 scales and no quant-block filtering) or both set to an integer; "
"setting only one 'None' is not supported."
)
block_quant_shape = None
if block_quant_shape_n is not None and block_quant_shape_k is not None:
block_quant_shape = [block_quant_shape_n, block_quant_shape_k]

# (iii) Tune the kernel over a range of batch sizes (single GPU per Ray worker).
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
Expand All @@ -206,6 +256,18 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:

ray.init(ignore_reinit_error=True)
search_space = get_configs_compute_bound(smash_config)

# Remove configs incompatible with block quantisation constraints:
# - BLOCK_SIZE_K must be divisible by block_quant_shape_k
# - BLOCK_SIZE_N must be divisible by block_quant_shape_n
if block_quant_shape is not None and use_fp8_w8a8:
search_space = [
cfg
for cfg in search_space
if cfg["BLOCK_SIZE_K"] % block_quant_shape_k == 0
and cfg["BLOCK_SIZE_N"] % block_quant_shape_n == 0
]

pruna_logger.info(f"Start tuning over {len(search_space)} configurations...")

start = time.time()
Expand All @@ -226,7 +288,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
use_fp8_w8a8,
use_int8_w8a16,
search_space,
None,
block_quant_shape,
False,
imported_packages,
0, # fixed seed for reproducibility
Expand Down Expand Up @@ -266,7 +328,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
dtype,
use_fp8_w8a8,
use_int8_w8a16,
None,
block_quant_shape,
smash_config["path_to_huggingface_hub_cache"],
smash_config["path_to_vllm_cache"],
imported_packages,
Expand Down
12 changes: 10 additions & 2 deletions src/pruna/engine/model_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def is_moe_lm(model: Any) -> bool:
"""
Check if the model is a MoE LM.

Currently all MoE LMs are based on Mixtral in transformers.
Detects MoE via ``config.num_experts`` (e.g. Mixtral, Qwen-MoE text-only)
or via nested ``config.text_config.num_experts`` (e.g. multimodal
``*ForConditionalGeneration`` wrappers).

Parameters
----------
Expand All @@ -121,7 +123,13 @@ def is_moe_lm(model: Any) -> bool:
bool
True if the model is a MoE LM, False otherwise.
"""
return hasattr(getattr(model, "config", None), "num_experts")
config = getattr(model, "config", None)
if config is None:
return False
if getattr(config, "num_experts", None) is not None:
return True
text_cfg = getattr(config, "text_config", None)
return text_cfg is not None and getattr(text_cfg, "num_experts", None) is not None


def is_transformers_pipeline_with_causal_lm(model: Any) -> bool:
Expand Down
Loading