Confirming prefix reparameterization implementation in ParScale #10
Description
Hi, thanks for sharing ParScale! I have a question about the prefix reparameterization implementation.
I’m currently experimenting with ParScale and implemented the reparameterization as described in your supplementary materials. Could you help me confirm whether my implementation aligns with your intended design?
Here’s the core part of my code:
```python
# inside the attention module's __init__
p_prime_dim = getattr(config, "parscale_p_prime_dim", 128)
mlp_hidden_dim = getattr(config, "parscale_mlp_hidden_dim", 512)
prefix_out_dim = config.num_key_value_heads * config.parscale_n_tokens * self.head_dim * 2
self.p_prime = nn.Parameter(torch.randn(config.parscale_n, p_prime_dim))
self.p_mlp = nn.Sequential(
    nn.Linear(p_prime_dim, mlp_hidden_dim),
    nn.Tanh(),
    nn.Linear(mlp_hidden_dim, prefix_out_dim),
)

def generate_prefix(self):
    P_all = self.p_mlp(self.p_prime)
    P_all = P_all.view(
        self.config.parscale_n,
        self.config.num_key_value_heads,
        self.config.parscale_n_tokens,
        self.head_dim,
        2,  # k/v
    )
    return P_all[..., 0], P_all[..., 1]  # prefix_k, prefix_v
```

During the forward pass, I apply it like this:
```python
if self.parscale_n >= 1:
    # Replicate the batch dimension n_parscale times (einops.repeat).
    inputs_embeds = repeat(inputs_embeds, "b s h -> (n_parscale b) s h", n_parscale=self.parscale_n)
    if attention_mask is not None:
        attention_mask = repeat(attention_mask, "b s -> (n_parscale b) s", n_parscale=self.parscale_n)
    if position_ids is not None:
        position_ids = repeat(position_ids, "b s -> (n_parscale b) s", n_parscale=self.parscale_n)

if past_key_values is None or past_key_values.get_seq_length() == 0:
    # Build the per-layer k/v prefixes once, at the start of decoding.
    prefix_k_list, prefix_v_list = [], []
    for layer in self.layers:
        pk, pv = layer.self_attn.generate_prefix()
        prefix_k_list.append(pk)
        prefix_v_list.append(pv)
    past_key_values = ParscaleCache(prefix_k_list, prefix_v_list)
```

Does this match the intended reparameterization strategy from your paper, especially how `p_prime` and `p_mlp` are used to generate the key/value prefixes?
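For reference, here is a self-contained shape check I ran on the reparameterization path. All config values here (`parscale_n=4`, `n_tokens=48`, `num_key_value_heads=2`, `head_dim=64`, and the 128/512 MLP dims) are placeholders I picked for the test, not values taken from your paper:

```python
import torch
import torch.nn as nn

# Hypothetical config values, chosen only for this shape check.
parscale_n, n_tokens = 4, 48
num_key_value_heads, head_dim = 2, 64
p_prime_dim, mlp_hidden_dim = 128, 512
prefix_out_dim = num_key_value_heads * n_tokens * head_dim * 2

# Reparameterization: small learned p_prime expanded by an MLP.
p_prime = nn.Parameter(torch.randn(parscale_n, p_prime_dim))
p_mlp = nn.Sequential(
    nn.Linear(p_prime_dim, mlp_hidden_dim),
    nn.Tanh(),
    nn.Linear(mlp_hidden_dim, prefix_out_dim),
)

P_all = p_mlp(p_prime).view(parscale_n, num_key_value_heads, n_tokens, head_dim, 2)
prefix_k, prefix_v = P_all[..., 0], P_all[..., 1]
print(prefix_k.shape)  # torch.Size([4, 2, 48, 64])
```

The view works out because `prefix_out_dim = num_key_value_heads * n_tokens * head_dim * 2` exactly matches the product of the last four view dimensions.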
Also, as a follow-up:
In your experiments, did you find that removing the reparameterization (i.e., directly learning the prefix weights without the MLP) made convergence harder?
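For context, this is the ablation I have in mind: learning `prefix_k`/`prefix_v` directly as parameters, with no `p_prime`/`p_mlp`. The `DirectPrefix` name, the shapes, and the 0.02 init scale are my own choices for illustration, not from your code:

```python
import torch
import torch.nn as nn

class DirectPrefix(nn.Module):
    """Hypothetical non-reparameterized baseline: prefixes learned directly."""

    def __init__(self, parscale_n, num_key_value_heads, n_tokens, head_dim):
        super().__init__()
        shape = (parscale_n, num_key_value_heads, n_tokens, head_dim)
        # No MLP: the key/value prefixes are themselves the parameters.
        self.prefix_k = nn.Parameter(torch.randn(shape) * 0.02)
        self.prefix_v = nn.Parameter(torch.randn(shape) * 0.02)

    def generate_prefix(self):
        return self.prefix_k, self.prefix_v

module = DirectPrefix(parscale_n=4, num_key_value_heads=2, n_tokens=48, head_dim=64)
pk, pv = module.generate_prefix()
print(pk.shape)  # torch.Size([4, 2, 48, 64])
```

Same interface as `generate_prefix()` above, so it would drop into the forward pass unchanged; my question is whether you observed this variant converging more slowly or less stably than the MLP version.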
Thanks a lot for your time and for this great work! 🙏