From 1a75ef9ad7a223d3e66165d5d345378650c93f02 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Fri, 6 Mar 2026 12:28:28 +0300 Subject: [PATCH] [quantization] Use cuda-accelerated `forward` This PR speeds up quantized model evaluation using the `forward` method accelerated for CUDA. TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../evaluation/script/llm_tasks_eval.py | 15 +++- .../wrappers/llama/quant_attn_prefill.py | 74 +++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/tico/quantization/evaluation/script/llm_tasks_eval.py b/tico/quantization/evaluation/script/llm_tasks_eval.py index a11fa73c..82e27bc4 100644 --- a/tico/quantization/evaluation/script/llm_tasks_eval.py +++ b/tico/quantization/evaluation/script/llm_tasks_eval.py @@ -30,6 +30,12 @@ def evaluate_llm_on_tasks( ) -> dict[str, Any]: if hasattr(model, "wrapped"): model = model.wrapped + + # use accelerated version for evaluation + for module in model.modules(): + if hasattr(module, "use_cuda_accelerated_version_for_evaluation"): + module.use_cuda_accelerated_version_for_evaluation = True + model_to_evaluate = HFLM( model, "causal", @@ -38,7 +44,14 @@ def evaluate_llm_on_tasks( truncation=True, ) tasks_list: list[str] = tasks.split(",") - return evaluator.simple_evaluate(model_to_evaluate, tasks=tasks_list) + result = evaluator.simple_evaluate(model_to_evaluate, tasks=tasks_list) + + # cancel usage of accelerated version for evaluation + for module in model.modules(): + if hasattr(module, "use_cuda_accelerated_version_for_evaluation"): + module.use_cuda_accelerated_version_for_evaluation = False + + return result def main(): diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py index cfc41c58..d6f089be 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py @@ -139,6 +139,8 @@ def __init__( mask.triu_(1)
self.register_buffer("causal_mask_template", mask, persistent=False) + self.use_cuda_accelerated_version_for_evaluation = False + def _rot(self, t, o_x1, o_x2, o_cat): x1, x2 = torch.chunk(t, 2, dim=-1) x1 = self._fq(x1, o_x1) @@ -171,6 +173,72 @@ def _apply_rope( return t_rot + def cuda_accelerated_forward_for_evaluation( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ): + assert position_embeddings is not None + assert attention_mask is not None + + hidden = self._fq(hidden_states, self.obs_hidden) + B, S, _ = hidden.shape + H = self.head_dim + + q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2) # (B, n_h, S, H) + k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2) # (B, n_kv, K, H) + v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2) # (B, n_kv, K, H) + + # Rope tables + cos, sin = position_embeddings + + # Repeat kv heads to match query heads (GQA) + k = k.repeat_interleave(self.kv_rep, dim=1) # (B, n_h, K, H) + v = v.repeat_interleave(self.kv_rep, dim=1) # (B, n_h, K, H) + + k = self._apply_rope( # NOTE(review): reuses the q-side observers (obs_q_*) for k — confirm obs_k_* observers were not intended here + k, + cos, + sin, + self.obs_q_x1, + self.obs_q_x2, + self.obs_q_cat, + self.obs_q_cos, + self.obs_q_sin, + self.obs_q_rot, + ) + q = self._apply_rope( + q, + cos, + sin, + self.obs_q_x1, + self.obs_q_x2, + self.obs_q_cat, + self.obs_q_cos, + self.obs_q_sin, + self.obs_q_rot, + ) + + # Attention logits: q @ k^T + logits = self._fq(q @ k.transpose(-2, -1), self.obs_logits) + logits = self._fq(logits + attention_mask, self.obs_mask_add) + + # Softmax + attn_weights = torch.softmax(logits, -1, dtype=torch.float32).to(q.dtype) + attn_weights = self._fq(attn_weights, self.obs_softmax) + + # Attention output + attn_out = ( + self._fq(attn_weights @ v, self.obs_attn_out) + .transpose(1, 2) + .reshape(B, S, -1) + ) # (B, S, n_h * H) + # Final projection + out = self.o_proj(attn_out) + + return out, attn_weights + def forward( + self, + hidden_states:
torch.Tensor, @@ -181,6 +249,12 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ): + if self.use_cuda_accelerated_version_for_evaluation is True: + assert past_key_value is None # cache is not supported currently + return self.cuda_accelerated_forward_for_evaluation( + hidden_states, position_embeddings, attention_mask + ) + hidden = self._fq(hidden_states, self.obs_hidden) B, S, _ = hidden.shape H = self.head_dim