Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions modeling_qwen2_parscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from .configuration_qwen2_parscale import Qwen2ParScaleConfig
from typing import Any, Dict, List, Optional, Tuple, Union


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
Expand Down Expand Up @@ -138,6 +137,7 @@ def __init__(self, prefix_k, prefix_v) -> None:
self.value_cache: List[torch.Tensor] = prefix_v
self.parscale_n = prefix_k[0].size(0)
self.n_prefix_tokens = prefix_k[0].size(2)

def update(
self,
key_states: torch.Tensor,
Expand Down Expand Up @@ -191,6 +191,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
Expand All @@ -203,30 +204,44 @@ def forward(
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
if self.training:
assert not use_cache, "use_cache must be False during training"
assert past_key_value is None, "past_key_value must be None during training"

if past_key_value is not None and use_cache:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

if self.config.parscale_n > 1:
else:
if self.config.parscale_n > 1:
B = key_states.size(0)
prefix_k = repeat(self.prefix_k, 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.config.parscale_n)
prefix_v = repeat(self.prefix_v, 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.config.parscale_n)

key_states = torch.cat([prefix_k, key_states], dim=2)
value_states = torch.cat([prefix_v, value_states], dim=2)

if self.config.parscale_n > 1:
# Expand attention mask to contain the prefix tokens
n_virtual_tokens = self.config.parscale_n_tokens

if attention_mask is not None:
attention_mask = torch.cat([
torch.zeros((attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], self.config.parscale_n_tokens), dtype=attention_mask.dtype, device=attention_mask.device),
attention_mask
], dim=3)

if query_states.size(2) != 1:
query_states = torch.cat([torch.zeros([query_states.size(0), query_states.size(1), n_virtual_tokens, query_states.size(3)], dtype=query_states.dtype, device=query_states.device), query_states], dim=2)
if attention_mask is not None:
if attention_mask.dim() == 2:
# [B, T] -> [B, T + n_virtual_tokens]
B, T = attention_mask.shape
# Create prefix mask: all ones for virtual tokens
prefix_mask = torch.ones(B, n_virtual_tokens, dtype=attention_mask.dtype, device=attention_mask.device)
# Concatenate prefix mask and original mask
attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
elif attention_mask.dim() == 4:
attention_mask = torch.cat([
torch.zeros((attention_mask.shape[0], attention_mask.shape[1], self.config.parscale_n_tokens, attention_mask.shape[3]), dtype=attention_mask.dtype, device=attention_mask.device),
attention_mask
], dim=2)

if query_states.size(2) != 1:
query_states = torch.cat([torch.zeros([query_states.size(0), query_states.size(1), n_virtual_tokens, query_states.size(3)], dtype=query_states.dtype, device=query_states.device), query_states], dim=2)

sliding_window = None
if (
self.config.use_sliding_window
Expand Down Expand Up @@ -613,13 +628,13 @@ def forward(

# The trained prefix is saved in layer.self_attn.prefix_k / layer.self_attn.prefix_v
# We extract them to construct ParscaleCache.
if past_key_values is None or past_key_values.get_seq_length() == 0:
if use_cache and (past_key_values is None or past_key_values.get_seq_length() == 0):
past_key_values = ParscaleCache([layer.self_attn.prefix_k for layer in self.layers], [layer.self_attn.prefix_v for layer in self.layers])

if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if cache_position is None:
if use_cache and cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
Expand Down