Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tico/quantization/wrapq/wrappers/llama/quant_attn_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ def __init__(
self.obs_attn_weights = mk("attn_weights")
self.obs_attn_out_h = mk("attn_out_h")

# New kv delta
self.obs_new_k = mk("new_k") # (B, n_kv, 1, H)
self.obs_new_v = mk("new_v") # (B, n_kv, 1, H)

# Total KV after concat (used for matmul/attn)
self.obs_k_total = mk("k_total") # (B, max_seq, H)
self.obs_v_total = mk("v_total") # (B, max_seq, H)

def _rot(self, t: torch.Tensor, o_x1, o_x2, o_cat):
# t: (..., head_dim)
x1, x2 = torch.chunk(t, 2, dim=-1)
Expand Down Expand Up @@ -253,6 +261,8 @@ def forward(
v_i_past = past_v[:, kv_i, :, :]
k_i = torch.cat([k_i_past, k_i_new], dim=1)
v_i = torch.cat([v_i_past, v_i_new], dim=1)
k_i = self._fq(k_i, self.obs_k_total)
v_i = self._fq(v_i, self.obs_v_total)

for rep_i in range(self.kv_rep):
q_idx = kv_i * self.kv_rep + rep_i
Expand Down Expand Up @@ -301,6 +311,8 @@ def forward(
# new kv delta: (B, n_kv, 1, H)
new_k = torch.stack(new_k_parts, dim=1)
new_v = torch.stack(new_v_parts, dim=1)
new_k = self._fq(new_k, self.obs_new_k)
new_v = self._fq(new_v, self.obs_new_v)
new_key_value = (new_k, new_v)

if use_cache:
Expand Down Expand Up @@ -332,6 +344,10 @@ def _all_observers(self):
self.obs_attn_out,
self.obs_attn_weights,
self.obs_attn_out_h,
self.obs_new_k,
self.obs_new_v,
self.obs_k_total,
self.obs_v_total,
)
for m in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
yield from m._all_observers()