From e176e8abd14ccdef31fc6d3c735a8db0771b8fa1 Mon Sep 17 00:00:00 2001 From: The gemma Authors Date: Fri, 22 May 2026 02:55:54 -0700 Subject: [PATCH] Optionally allow models to keep the last cache item at prefill time. PiperOrigin-RevId: 919561155 --- gemma/gm/text/_prefill.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/gemma/gm/text/_prefill.py b/gemma/gm/text/_prefill.py index d8cabd2c..46432e68 100644 --- a/gemma/gm/text/_prefill.py +++ b/gemma/gm/text/_prefill.py @@ -217,9 +217,12 @@ def prefill( # A cleaner implementation could be to have a per-batch cache index, to # remove padding. But I leave this to my future self (or to future Gemini). - new_used_cache_length = ( - prev_turns.used_cache_length + input.length_with_mm - 1 - ) + if hasattr(model, 'keep_last_prefill_kv') and model.keep_last_prefill_kv: + new_used_cache_length = prev_turns.used_cache_length + input.length_with_mm + else: + new_used_cache_length = ( + prev_turns.used_cache_length + input.length_with_mm - 1 + ) cache = cache.set_end_index(new_used_cache_length) # TODO(epot): The first token was predicted, so could use this, but would