From 745f1398338a37b1237ab38575f9c8697ca53b58 Mon Sep 17 00:00:00 2001 From: stanley1208 Date: Mon, 23 Feb 2026 22:28:50 -0800 Subject: [PATCH] refactor: use jnp.take_along_axis for return_last_only indexing Replace manual fancy indexing with jnp.take_along_axis in Transformer and Gemma3nTransformer, as requested by TODO(epot). Cleaner and avoids constructing an arange index array. --- gemma/gm/nn/_transformer.py | 6 ++++-- gemma/gm/nn/gemma3n/_transformer.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/gemma/gm/nn/_transformer.py b/gemma/gm/nn/_transformer.py index 6248e940..2f99e903 100644 --- a/gemma/gm/nn/_transformer.py +++ b/gemma/gm/nn/_transformer.py @@ -250,8 +250,10 @@ def __call__( # pytype: disable=signature-mismatch if return_last_only: last_input_token_idx = jnp.sum(inputs.inputs_mask, axis=-1) - 1 - # TODO(epot): Use `jnp.take_along_axis` - x = x[jnp.arange(len(x)), last_input_token_idx, ...] + x = jnp.take_along_axis( + x, last_input_token_idx[:, None, None], axis=1 + ) + x = jnp.squeeze(x, axis=1) elif images is not None: # Remove the MM extra tokens inserted. # During fine-tuning, the prompt is always masked, and the model cannot diff --git a/gemma/gm/nn/gemma3n/_transformer.py b/gemma/gm/nn/gemma3n/_transformer.py index 35b41aaa..b64b59e3 100644 --- a/gemma/gm/nn/gemma3n/_transformer.py +++ b/gemma/gm/nn/gemma3n/_transformer.py @@ -320,8 +320,10 @@ def __call__( # pytype: disable=signature-mismatch if return_last_only: last_input_token_idx = jnp.sum(inputs.inputs_mask, axis=-1) - 1 - # TODO(epot): Use `jnp.take_along_axis` - x = x[jnp.arange(len(x)), last_input_token_idx, ...] + x = jnp.take_along_axis( + x, last_input_token_idx[:, None, None], axis=1 + ) + x = jnp.squeeze(x, axis=1) elif images is not None: # Remove the MM extra tokens inserted. # During fine-tuning, the prompt is always masked, and the model cannot