From 27d12332b8d32ba08887e2a8c1ce3ca52885a18b Mon Sep 17 00:00:00 2001
From: seongwoo
Date: Mon, 9 Mar 2026 13:22:25 +0900
Subject: [PATCH] [quantization] Quantize cache

This commit adds missing observers for kv cache.

TICO-DCO-1.0-Signed-off-by: seongwoo
---
 .../wrapq/wrappers/llama/quant_attn_decode.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn_decode.py b/tico/quantization/wrapq/wrappers/llama/quant_attn_decode.py
index bac51fe3..01d606a6 100644
--- a/tico/quantization/wrapq/wrappers/llama/quant_attn_decode.py
+++ b/tico/quantization/wrapq/wrappers/llama/quant_attn_decode.py
@@ -156,6 +156,14 @@ def __init__(
         self.obs_attn_weights = mk("attn_weights")
         self.obs_attn_out_h = mk("attn_out_h")
 
+        # New kv delta
+        self.obs_new_k = mk("new_k")  # (B, n_kv, 1, H)
+        self.obs_new_v = mk("new_v")  # (B, n_kv, 1, H)
+
+        # Total KV after concat (used for matmul/attn)
+        self.obs_k_total = mk("k_total")  # (B, max_seq, H)
+        self.obs_v_total = mk("v_total")  # (B, max_seq, H)
+
     def _rot(self, t: torch.Tensor, o_x1, o_x2, o_cat):
         # t: (..., head_dim)
         x1, x2 = torch.chunk(t, 2, dim=-1)
@@ -253,6 +261,8 @@ def forward(
             v_i_past = past_v[:, kv_i, :, :]
             k_i = torch.cat([k_i_past, k_i_new], dim=1)
             v_i = torch.cat([v_i_past, v_i_new], dim=1)
+            k_i = self._fq(k_i, self.obs_k_total)
+            v_i = self._fq(v_i, self.obs_v_total)
 
             for rep_i in range(self.kv_rep):
                 q_idx = kv_i * self.kv_rep + rep_i
@@ -301,6 +311,8 @@ def forward(
         # new kv delta: (B, n_kv, 1, H)
         new_k = torch.stack(new_k_parts, dim=1)
         new_v = torch.stack(new_v_parts, dim=1)
+        new_k = self._fq(new_k, self.obs_new_k)
+        new_v = self._fq(new_v, self.obs_new_v)
         new_key_value = (new_k, new_v)
 
         if use_cache:
@@ -332,6 +344,10 @@ def _all_observers(self):
             self.obs_attn_out,
             self.obs_attn_weights,
             self.obs_attn_out_h,
+            self.obs_new_k,
+            self.obs_new_v,
+            self.obs_k_total,
+            self.obs_v_total,
         )
         for m in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
             yield from m._all_observers()