From cb3d2ac6f9dbea94ef96a0c392cb2fee4117aee8 Mon Sep 17 00:00:00 2001 From: The gemma Authors Date: Thu, 28 May 2026 17:49:39 -0700 Subject: [PATCH] Code update PiperOrigin-RevId: 923073660 --- gemma/gm/ckpts/_checkpoint.py | 15 +++++++++++++-- gemma/gm/nn/gemma4/_transformer.py | 19 +++++++++++++++---- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/gemma/gm/ckpts/_checkpoint.py b/gemma/gm/ckpts/_checkpoint.py index 9449b9a6..d70de8d6 100644 --- a/gemma/gm/ckpts/_checkpoint.py +++ b/gemma/gm/ckpts/_checkpoint.py @@ -271,8 +271,19 @@ def load_params( # To supports different checkpoint structures, the original params have to # be remapped into the checkpoint structure. output_with_skip = metadata.make_tree_for_params(params) - restore_fn = functools.partial(ckpt.restore, path) - output = _partial_restore(restore_fn, output_with_skip) + def restore_fn(tree): + return ckpt.restore(path, tree) + + # Restore EVERYTHING from checkpoint using metadata.tree as target + restored_tree = _partial_restore(restore_fn, metadata.tree) + + def update_tree(target, source): + if isinstance(target, dict) and isinstance(source, dict): + return {k: update_tree(target.get(k), source[k]) if k in source else v for k, v in target.items()} + return source + + # Copy restored values into output_with_skip + output = update_tree(output_with_skip, restored_tree) # TODO(epot): Better API. Currently this do not quantize the weights, but # just refactor the params to the QAT structure. diff --git a/gemma/gm/nn/gemma4/_transformer.py b/gemma/gm/nn/gemma4/_transformer.py index ad7927f2..1bbeec28 100644 --- a/gemma/gm/nn/gemma4/_transformer.py +++ b/gemma/gm/nn/gemma4/_transformer.py @@ -546,7 +546,8 @@ def _encode_vision(self, vision_input: PreprocessedVisionInput): n_images = len(vision_input.soft_token_counts) patches = vision_input.patches positions_xy = vision_input.positions_xy - max_patches = patches.shape[1] // n_images + B = patches.shape[0] + max_patches = patches.shape[1] // (n_images // B) patches = jnp.reshape(patches, (n_images, max_patches, patches.shape[2])) positions_xy = jnp.reshape( @@ -567,9 +568,19 @@ def _encode_vision(self, vision_input: PreprocessedVisionInput): real_tokens = embeddings[i][:expected_count] per_image_tokens.append(real_tokens) - all_tokens = jnp.concatenate(per_image_tokens, axis=0) - all_tokens = self.embedder.encode_vision(all_tokens[None, None, :, :]) - all_tokens = all_tokens[:, 0, :, :] + # Group per_image_tokens by batch element B to preserve the batch dimension + B = patches.shape[0] + n_images_per_example = n_images // B + batched_tokens = [] + for b in range(B): + idx_start = b * n_images_per_example + idx_end = idx_start + n_images_per_example + example_tokens = jnp.concatenate(per_image_tokens[idx_start:idx_end], axis=0) + batched_tokens.append(example_tokens) + all_tokens = jnp.stack(batched_tokens, axis=0) # Shape [B, total_tokens_per_example, dim] + + # Project vision embeddings, preserving the batch dimension B + all_tokens = self.embedder.encode_vision(all_tokens) return all_tokens def _encode_audio(self, audio, audio_lengths, audio_soft_token_counts):