diff --git a/gemma/gm/text/_prefill.py b/gemma/gm/text/_prefill.py index d8cabd2c..46432e68 100644 --- a/gemma/gm/text/_prefill.py +++ b/gemma/gm/text/_prefill.py @@ -217,9 +217,12 @@ def prefill( # A cleaner implementation could be to have a per-batch cache index, to # remove padding. But I leave this to my future self (or to future Gemini). - new_used_cache_length = ( - prev_turns.used_cache_length + input.length_with_mm - 1 - ) + if hasattr(model, 'keep_last_prefill_kv') and model.keep_last_prefill_kv: + new_used_cache_length = prev_turns.used_cache_length + input.length_with_mm + else: + new_used_cache_length = ( + prev_turns.used_cache_length + input.length_with_mm - 1 + ) cache = cache.set_end_index(new_used_cache_length) # TODO(epot): The first token was predicted, so could use this, but would