Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

MODEL_NAME = "Maykeye/TinyLLama-v0"
MAX_SEQ = 256
static_data_inside_the_layer = True

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype="float32")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, legacy=False)
Expand Down Expand Up @@ -139,10 +140,26 @@ def make_fixed_inputs(prompt: str):
example_hidden = torch.randn(B, S, D)

with SuppressWarning(UserWarning, ".*"):
cm = tico.convert(
qlayer,
(example_hidden,),
)
if static_data_inside_the_layer is True:
cm = tico.convert(
qlayer,
(example_hidden,),
)
else:
qattn = qlayer.wrapped
attn_mask = qattn._slice_causal(S, "cpu").squeeze(0)
dtype = example_hidden.dtype
pos_embeds = (
qattn.rope_cos_template.cpu().to(dtype),
qattn.rope_sin_template.cpu().to(dtype),
)

cm = tico.convert(
qlayer,
(example_hidden,),
kwargs={"attention_mask": attn_mask, "position_embeddings": pos_embeds},
)

# Note that the model is not fully quantized.
cm.save(save_path)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ def forward(
L = hidden_states.size(1)
attention_mask = self._slice_causal(L, hidden_states.device)
attention_mask = attention_mask.squeeze(0)
attention_mask = self._fq(
attention_mask, self.obs_causal_mask
) # let it be quantized immediately
attention_mask = self._fq(
attention_mask, self.obs_causal_mask
) # let it be quantized immediately

if position_embeddings is None:
position_embeddings = (
Expand All @@ -206,11 +206,11 @@ def forward(
dtype=hidden_states.dtype, device=hidden_states.device
),
)
cos, sin = position_embeddings
position_embeddings = (
self._fq(cos, self.obs_cos),
self._fq(sin, self.obs_sin),
)
cos, sin = position_embeddings
position_embeddings = (
self._fq(cos, self.obs_cos),
self._fq(sin, self.obs_sin),
)

attn_out = self.self_attn(
hidden_states=hidden_states,
Expand Down