diff --git a/gemma/gm/nn/gemma4/_modules.py b/gemma/gm/nn/gemma4/_modules.py index 2975f330..cb8cf375 100644 --- a/gemma/gm/nn/gemma4/_modules.py +++ b/gemma/gm/nn/gemma4/_modules.py @@ -246,6 +246,7 @@ def __call__( cache: LayerCache | None, attn_mask: jax.Array, kv_shared_cache: LayerCache | None = None, + skip_sliding_mask: bool = False, ) -> tuple[LayerCache | None, jax.Array]: """Applies multi-head attention to the inputs. @@ -255,6 +256,7 @@ def __call__( cache: KV cache or None. attn_mask: Attention mask of shape [batch_size, seq_len, cache_size]. kv_shared_cache: Cache for shared KV layers. + skip_sliding_mask: If True, skip the sliding mask. Returns: cache: Updated attention KV cache. @@ -335,7 +337,7 @@ def __call__( logits = jnp.tanh(logits / self.attn_logits_soft_cap) logits = logits * self.attn_logits_soft_cap - if self.attn_type == AttentionType.LOCAL_SLIDING: + if self.attn_type == AttentionType.LOCAL_SLIDING and not skip_sliding_mask: if self.sliding_window_size is None: raise ValueError( 'Sliding_window_size must be set if Local Sliding attention type' @@ -596,6 +598,7 @@ def __call__( attn_mask: jax.Array, per_layer_input: jax.Array | None = None, kv_shared_cache: LayerCache | None = None, + skip_sliding_mask: bool = False, ) -> tuple[LayerCache | None, jax.Array]: """Applies the block to the inputs. @@ -607,6 +610,7 @@ def __call__( per_layer_input: Per-layer input of shape [batch_size, seq_len, per_layer_input_dim]. kv_shared_cache: Cache for shared KV layers. + skip_sliding_mask: If True, skip the sliding mask. Returns: cache: Updated attention KV cache. @@ -621,6 +625,7 @@ def __call__( cache, attn_mask, kv_shared_cache, + skip_sliding_mask=skip_sliding_mask, ) if self.post_attention_norm is not None: diff --git a/gemma/gm/nn/gemma4/_transformer.py b/gemma/gm/nn/gemma4/_transformer.py index 95e72181..ad7927f2 100644 --- a/gemma/gm/nn/gemma4/_transformer.py +++ b/gemma/gm/nn/gemma4/_transformer.py @@ -366,11 +366,13 @@ def _apply_attention( kv_shared_cache = None # Select the appropriate attention mask for this layer type. attn_mask = inputs.attention_mask + skip_sliding_mask = False if ( inputs.sliding_attention_mask is not None and block.attn_type == _modules.AttentionType.LOCAL_SLIDING ): attn_mask = inputs.sliding_attention_mask + skip_sliding_mask = True layer_cache, x = block( x, inputs.positions, @@ -380,6 +382,7 @@ def _apply_attention( if self.config.per_layer_input_dim else None, kv_shared_cache=kv_shared_cache, + skip_sliding_mask=skip_sliding_mask, ) new_cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch