diff --git a/modeling_qwen2_parscale.py b/modeling_qwen2_parscale.py index 5bc82f5..e81f0c2 100755 --- a/modeling_qwen2_parscale.py +++ b/modeling_qwen2_parscale.py @@ -36,7 +36,6 @@ from .configuration_qwen2_parscale import Qwen2ParScaleConfig from typing import Any, Dict, List, Optional, Tuple, Union - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf" @@ -138,6 +137,7 @@ def __init__(self, prefix_k, prefix_v) -> None: self.value_cache: List[torch.Tensor] = prefix_v self.parscale_n = prefix_k[0].size(0) self.n_prefix_tokens = prefix_k[0].size(2) + def update( self, key_states: torch.Tensor, @@ -191,6 +191,7 @@ def forward( attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -203,30 +204,44 @@ cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if self.training: + assert not use_cache, "use_cache must be False during training" + assert past_key_value is None, "past_key_value must be None during training" + + if past_key_value is not None and use_cache: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if self.config.parscale_n > 1: + else: + if self.config.parscale_n > 1: + B = key_states.size(0) + prefix_k = repeat(self.prefix_k, 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.config.parscale_n) + prefix_v = repeat(self.prefix_v, 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.config.parscale_n) + + key_states = torch.cat([prefix_k, key_states], dim=2) + value_states = torch.cat([prefix_v, value_states], dim=2) + if self.config.parscale_n > 1: # Expand attention mask to contain the prefix tokens n_virtual_tokens = self.config.parscale_n_tokens if attention_mask is not None: - attention_mask = torch.cat([ - torch.zeros((attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], self.config.parscale_n_tokens), dtype=attention_mask.dtype, device=attention_mask.device), - attention_mask - ], dim=3) - - if query_states.size(2) != 1: - query_states = torch.cat([torch.zeros([query_states.size(0), query_states.size(1), n_virtual_tokens, query_states.size(3)], dtype=query_states.dtype, device=query_states.device), query_states], dim=2) - if attention_mask is not None: + if attention_mask.dim() == 2: + # [B, T] -> [B, T + n_virtual_tokens] + B, T = attention_mask.shape + # Create prefix mask: all ones for virtual tokens + prefix_mask = torch.ones(B, n_virtual_tokens, dtype=attention_mask.dtype, device=attention_mask.device) + # Concatenate prefix mask and original mask + attention_mask = torch.cat([prefix_mask, attention_mask], dim=1) + elif attention_mask.dim() == 4: attention_mask = torch.cat([ torch.zeros((attention_mask.shape[0], attention_mask.shape[1], self.config.parscale_n_tokens, attention_mask.shape[3]), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask ], dim=2) + if query_states.size(2) != 1: + query_states = torch.cat([torch.zeros([query_states.size(0), query_states.size(1), n_virtual_tokens, query_states.size(3)], dtype=query_states.dtype, device=query_states.device), query_states], dim=2) + sliding_window = None if ( self.config.use_sliding_window @@ -613,13 +628,13 @@ def forward( # The trained prefix is saved in layer.self_attn.prefix_k / layer.self_attn.prefix_v # We extract them to construct ParscaleCache. 
- if past_key_values is None or past_key_values.get_seq_length() == 0: + if use_cache and (past_key_values is None or past_key_values.get_seq_length() == 0): past_key_values = ParscaleCache([layer.self_attn.prefix_k for layer in self.layers], [layer.self_attn.prefix_v for layer in self.layers]) if use_cache and past_key_values is None: past_key_values = DynamicCache() - if cache_position is None: + if use_cache and cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device