Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion gemma/gm/nn/gemma4/_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def __call__(
cache: LayerCache | None,
attn_mask: jax.Array,
kv_shared_cache: LayerCache | None = None,
skip_sliding_mask: bool = False,
) -> tuple[LayerCache | None, jax.Array]:
"""Applies multi-head attention to the inputs.

Expand All @@ -255,6 +256,7 @@ def __call__(
cache: KV cache or None.
attn_mask: Attention mask of shape [batch_size, seq_len, cache_size].
kv_shared_cache: Cache for shared KV layers.
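skip_sliding_mask: If True, the sliding-window masking step for
LOCAL_SLIDING attention is skipped (e.g. when the caller already passes a
sliding attention mask as `attn_mask`). Ignored for other attention types.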

Returns:
cache: Updated attention KV cache.
Expand Down Expand Up @@ -335,7 +337,7 @@ def __call__(
logits = jnp.tanh(logits / self.attn_logits_soft_cap)
logits = logits * self.attn_logits_soft_cap

if self.attn_type == AttentionType.LOCAL_SLIDING:
if self.attn_type == AttentionType.LOCAL_SLIDING and not skip_sliding_mask:
if self.sliding_window_size is None:
raise ValueError(
'Sliding_window_size must be set if Local Sliding attention type'
Expand Down Expand Up @@ -596,6 +598,7 @@ def __call__(
attn_mask: jax.Array,
per_layer_input: jax.Array | None = None,
kv_shared_cache: LayerCache | None = None,
skip_sliding_mask: bool = False,
) -> tuple[LayerCache | None, jax.Array]:
"""Applies the block to the inputs.

Expand All @@ -607,6 +610,7 @@ def __call__(
per_layer_input: Per-layer input of shape [batch_size, seq_len,
per_layer_input_dim].
kv_shared_cache: Cache for shared KV layers.
skip_sliding_mask: Forwarded to the attention layer. If True, that layer
skips its own sliding-window masking step.

Returns:
cache: Updated attention KV cache.
Expand All @@ -621,6 +625,7 @@ def __call__(
cache,
attn_mask,
kv_shared_cache,
skip_sliding_mask=skip_sliding_mask,
)

if self.post_attention_norm is not None:
Expand Down
3 changes: 3 additions & 0 deletions gemma/gm/nn/gemma4/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,13 @@ def _apply_attention(
kv_shared_cache = None
# Select the appropriate attention mask for this layer type.
attn_mask = inputs.attention_mask
skip_sliding_mask = False
if (
inputs.sliding_attention_mask is not None
and block.attn_type == _modules.AttentionType.LOCAL_SLIDING
):
attn_mask = inputs.sliding_attention_mask
skip_sliding_mask = True
layer_cache, x = block(
x,
inputs.positions,
Expand All @@ -380,6 +382,7 @@ def _apply_attention(
if self.config.per_layer_input_dim
else None,
kv_shared_cache=kv_shared_cache,
skip_sliding_mask=skip_sliding_mask,
)
new_cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch

Expand Down
Loading