Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion tico/quantization/evaluation/script/llm_tasks_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def evaluate_llm_on_tasks(
) -> dict[str, Any]:
if hasattr(model, "wrapped"):
model = model.wrapped

# use accelerated version for evaluation
for module in model.modules():
if hasattr(module, "use_cuda_accelerated_version_for_evaluation"):
module.use_cuda_accelerated_version_for_evaluation = True

model_to_evaluate = HFLM(
model,
"causal",
Expand All @@ -38,7 +44,14 @@ def evaluate_llm_on_tasks(
truncation=True,
)
tasks_list: list[str] = tasks.split(",")
return evaluator.simple_evaluate(model_to_evaluate, tasks=tasks_list)
result = evaluator.simple_evaluate(model_to_evaluate, tasks=tasks_list)

# cancel usage of accelerated version for evaluation
for module in model.modules():
if hasattr(module, "use_cuda_accelerated_version_for_evaluation"):
module.use_cuda_accelerated_version_for_evaluation = False

return result


def main():
Expand Down
74 changes: 74 additions & 0 deletions tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def __init__(
mask.triu_(1)
self.register_buffer("causal_mask_template", mask, persistent=False)

self.use_cuda_accelerated_version_for_evaluation = False

def _rot(self, t, o_x1, o_x2, o_cat):
x1, x2 = torch.chunk(t, 2, dim=-1)
x1 = self._fq(x1, o_x1)
Expand Down Expand Up @@ -171,6 +173,72 @@ def _apply_rope(

return t_rot

def cuda_accelerated_forward_for_evaluation(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
):
    """Full-sequence (prefill) attention pass used only during evaluation.

    Computes attention for the whole sequence in one batched pass so that
    evaluation can run fast on CUDA, while still routing every intermediate
    tensor through the same fake-quantization observers (``self._fq``) as the
    regular forward. KV-cache is not supported by this path (the caller
    asserts ``past_key_value is None``).

    Args:
        hidden_states: Input activations of shape (B, S, hidden_size).
        position_embeddings: RoPE ``(cos, sin)`` tables; must not be None.
        attention_mask: Additive attention mask; must not be None.

    Returns:
        Tuple of ``(out, attn_weights)`` where ``out`` has shape
        (B, S, n_heads * head_dim) after the output projection.
    """
    assert position_embeddings is not None
    assert attention_mask is not None

    hidden = self._fq(hidden_states, self.obs_hidden)
    B, S, _ = hidden.shape
    H = self.head_dim

    q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_h, S, H)
    k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
    v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)

    # RoPE tables
    cos, sin = position_embeddings

    # Repeat kv heads to match query heads (GQA)
    k = k.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, S, H)
    v = v.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, S, H)

    # BUG FIX: the key rotation previously reused the *query* observers
    # (obs_q_*). Keys must pass through their own observers (obs_k_*) so
    # that key quantization statistics are collected and applied correctly.
    k = self._apply_rope(
        k,
        cos,
        sin,
        self.obs_k_x1,
        self.obs_k_x2,
        self.obs_k_cat,
        self.obs_k_cos,
        self.obs_k_sin,
        self.obs_k_rot,
    )
    q = self._apply_rope(
        q,
        cos,
        sin,
        self.obs_q_x1,
        self.obs_q_x2,
        self.obs_q_cat,
        self.obs_q_cos,
        self.obs_q_sin,
        self.obs_q_rot,
    )

    # Attention logits: q @ k^T
    # NOTE(review): there is no 1/sqrt(head_dim) scaling here — presumably
    # it is folded into the projections or handled by the quantization
    # scheme; confirm against the step-by-step forward.
    logits = self._fq(q @ k.transpose(-2, -1), self.obs_logits)
    logits = self._fq(logits + attention_mask, self.obs_mask_add)

    # Softmax in float32 for numerical stability, then cast back
    attn_weights = torch.softmax(logits, -1, dtype=torch.float32).to(q.dtype)
    attn_weights = self._fq(attn_weights, self.obs_softmax)

    # Attention output, merged back to (B, S, n_h * H)
    attn_out = (
        self._fq(attn_weights @ v, self.obs_attn_out)
        .transpose(1, 2)
        .reshape(B, S, -1)
    )
    # Final projection
    out = self.o_proj(attn_out)

    return out, attn_weights

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -181,6 +249,12 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
):
if self.use_cuda_accelerated_version_for_evaluation is True:
assert past_key_value is None # cache is not supported currently
return self.cuda_accelerated_forward_for_evaluation(
hidden_states, position_embeddings, attention_mask
)

hidden = self._fq(hidden_states, self.obs_hidden)
B, S, _ = hidden.shape
H = self.head_dim
Expand Down