Skip to content

Fix jaxtyping shape mismatch for multimodal inputs in Gemma3nTransformer#624

Open
YADAV1825 wants to merge 1 commit into
google-deepmind:mainfrom
YADAV1825:fix-gemma3n-vision-shapes
Open

Fix jaxtyping shape mismatch for multimodal inputs in Gemma3nTransformer#624
YADAV1825 wants to merge 1 commit into
google-deepmind:mainfrom
YADAV1825:fix-gemma3n-vision-shapes

Conversation

@YADAV1825
Copy link
Copy Markdown

Resolves #620

Bug Description:
When running Gemma3n_E2B with multimodal inputs, the gm.text.Sampler properly expands the sequence length to account for the generated image placeholder tokens. However, the jaxtyping annotations in Gemma3nTransformer.call and _encode_and_get_inputs strictly enforced L and L_no_mm for the positions and attention_mask arguments. This caused a TypeCheckError during JAX tracing because the expanded tensors (e.g., length 512) did not match the raw token length (e.g., length 253).

The Fix:

Signature Update: Changed the type hints in call and _encode_and_get_inputs to use L_with_mm for positions and attention_mask to safely permit the expanded shapes generated by the sampler, while maintaining backward compatibility for text-only inputs.

Defensive Fallback: Added internal shape-checking inside _encode_and_get_inputs. If a user bypasses the sampler and manually passes raw unexpanded positions or attention_mask alongside an image, the function catches the shape mismatch and dynamically rebuilds them using _pos_utils.build_positions_from_mask to prevent downstream execution failures.

Testing:

Verified the fix locally by bypassing checkpoint auth and directly passing a dummy image through the sampler and model initialization to test the tensor shapes end-to-end.

from gemma import gm
import jax
import jax.numpy as jnp
from PIL import Image
import numpy as np

def test_gemma_3n_vision_final():
    print("1. Initializing tokenizer and model...")
    tokenizer = gm.text.Gemma3nTokenizer()
    model = gm.nn.Gemma3n_E2B()

    print("2. Creating dummy image to trigger multimodal token expansion...")
    # Create a blank image to test the tensor shapes (height=427, width=640, channels=3)
    img = Image.fromarray(np.zeros((427, 640, 3), dtype=np.uint8))

    print("3. Bypassing checkpoint auth by generating dummy inputs and parameters...")
    # Temporarily initialize sampler without params just to format the inputs correctly
    temp_sampler = gm.text.Sampler(model=model, params={}, tokenizer=tokenizer)
    prompt = "<image>\nAnswer the following question in a single word or short phrase based on the image."
    
    # Extract the properly shaped arrays using the internal sampler method
    inputs = temp_sampler._get_inputs(
        prompt=prompt, 
        images=[img], 
        add_bos=True, 
        has_batch_dim=False, 
        sharding=None
    )
    
    # Initialize Flax model parameters from scratch (creates the missing vision_encoder weights)
    rng = jax.random.PRNGKey(0)
    variables = model.init(
        rng, 
        tokens=inputs.text, 
        images=inputs.images, 
        positions=inputs.positions,
        attention_mask=inputs.attention_mask
    )
    params = variables['params']

    print("4. Setting up the actual sampler with dummy weights...")
    sampler = gm.text.Sampler(model=model, params=params, tokenizer=tokenizer)

    print("5. Running the multimodal forward pass...")
    try:
        # Sample a single token to verify the forward pass works end-to-end without shape errors
        output = sampler.sample(prompt, images=[img], max_new_tokens=1)
        print("\n--- Output ---")
        print(output)
        print("\n MULTIMODAL TEST PASSED! The TypeCheckError is gone and the model executes with images.")
    except Exception as e:
        print("\n MULTIMODAL TEST FAILED. An error occurred:")
        raise e

if __name__ == "__main__":
    test_gemma_3n_vision_final()

Issue Reproduction:

image

Resolved Execution:

1. Initializing tokenizer and model...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1776425962.582935   55376 port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1776425964.610381   55376 port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2. Creating dummy image to trigger multimodal token expansion...
3. Bypassing checkpoint auth by generating dummy inputs and parameters...
W0000 00:00:1776425967.223325   55376 google_auth_provider.cc:196] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "NOT_FOUND: Error executing an HTTP request: HTTP response code 404 with body '{"error":"invalid_request","error_description":"Service account not enabled on this instance"}'".
/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/jax/_src/ops/scatter.py:103: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion=standard. In future JAX releases this will result in an error.
  warnings.warn(
E0417 11:45:54.825170   55858 slow_operation_alarm.cc:73] 
********************************
[Compiling module jit___call for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
E0417 11:48:19.502198   55376 slow_operation_alarm.cc:140] The operation took 4m24.677172721s

********************************
[Compiling module jit___call for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
4. Setting up the actual sampler with dummy weights...
5. Running the multimodal forward pass...

--- Output ---


 MULTIMODAL TEST PASSED! The TypeCheckError is gone and the model executes with images.

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Using Gemma3N_E2B with a prompt with an image fails with TypeCheckError

1 participant