diff --git a/gemma/gm/nn/_transformer.py b/gemma/gm/nn/_transformer.py index 6248e940..2f99e903 100644 --- a/gemma/gm/nn/_transformer.py +++ b/gemma/gm/nn/_transformer.py @@ -250,8 +250,10 @@ def __call__( # pytype: disable=signature-mismatch if return_last_only: last_input_token_idx = jnp.sum(inputs.inputs_mask, axis=-1) - 1 - # TODO(epot): Use `jnp.take_along_axis` - x = x[jnp.arange(len(x)), last_input_token_idx, ...] + x = jnp.take_along_axis( + x, last_input_token_idx[:, None, None], axis=1 + ) + x = jnp.squeeze(x, axis=1) elif images is not None: # Remove the MM extra tokens inserted. # During fine-tuning, the prompt is always masked, and the model cannot diff --git a/gemma/gm/nn/gemma3n/_transformer.py b/gemma/gm/nn/gemma3n/_transformer.py index 35b41aaa..b64b59e3 100644 --- a/gemma/gm/nn/gemma3n/_transformer.py +++ b/gemma/gm/nn/gemma3n/_transformer.py @@ -320,8 +320,10 @@ def __call__( # pytype: disable=signature-mismatch if return_last_only: last_input_token_idx = jnp.sum(inputs.inputs_mask, axis=-1) - 1 - # TODO(epot): Use `jnp.take_along_axis` - x = x[jnp.arange(len(x)), last_input_token_idx, ...] + x = jnp.take_along_axis( + x, last_input_token_idx[:, None, None], axis=1 + ) + x = jnp.squeeze(x, axis=1) elif images is not None: # Remove the MM extra tokens inserted. # During fine-tuning, the prompt is always masked, and the model cannot