From 9daf5e03e69ea7c0643870a7687d0ee88eba3fe7 Mon Sep 17 00:00:00 2001 From: AshNicolus Date: Fri, 17 Oct 2025 22:31:35 +0530 Subject: [PATCH 1/2] Remove TensorFlow dependency from multimodal/image.py Replace tf.image decoding/resizing with Pillow + NumPy + JAX. - Convert float inputs (0..1 or 0..255) to uint8 for processing. - Resize with PIL.BILINEAR, convert back to JAX array. - Normalize and clip outputs to [-1, 1]. - Resolve TODO(eyvinec) to remove TF dependency. --- gemma/multimodal/image.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/gemma/multimodal/image.py b/gemma/multimodal/image.py index 91f1ff4a..ca823f36 100644 --- a/gemma/multimodal/image.py +++ b/gemma/multimodal/image.py @@ -23,7 +23,10 @@ from kauldron import typing import numpy as np from PIL import Image -import tensorflow as tf +import warnings +# Removed TensorFlow dependency: use PIL + NumPy + JAX for decoding/resizing. +# Note: inputs are expected to be image arrays (H,W,C) in uint8 or floats. +# Other dtypes are cast to uint8 as-is; no warning is emitted for them yet. _IMAGE_MEAN = (127.5,) * 3 _IMAGE_STD = (127.5,) * 3 @@ -69,17 +72,28 @@ def pre_process_image( Returns: The pre-processed image. """ - # all inputs are expected to have been jpeg compressed. - # TODO(eyvinec): we should remove tf dependency. - image = jnp.asarray( - tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3) - ) - image = jax.image.resize( - image, - shape=(image_height, image_width, 3), - method="bilinear", - antialias=True, - ) + # Accept numpy / jax arrays or PIL images. Convert to uint8 ndarray for PIL. 
+ arr = np.asarray(image) + + # If floats in [0, 1], convert to 0-255 uint8 + if np.issubdtype(arr.dtype, np.floating): + if arr.max() <= 1.0: + arr = (arr * 255.0).round().astype(np.uint8) + else: + arr = np.clip(arr, 0, 255).round().astype(np.uint8) + else: + arr = arr.astype(np.uint8) + + # PIL expects shape (W, H) ordering for resize tuple; Image.fromarray handles H,W,C. + pil = Image.fromarray(arr) + # Use bilinear resizing; PIL's LANCZOS is a high-quality downsample filter but + # bilinear better matches previous `jax.image.resize(..., method='bilinear')`. + pil = pil.resize((image_width, image_height), resample=Image.BILINEAR) + + # Back to numpy -> jax + resized = np.asarray(pil).astype(np.float32) + image = jnp.asarray(resized) + image = normalize_images(image) image = jnp.clip(image, -1, 1) return image From 828f9348a0cce2009e413af06c0cd600262e6f66 Mon Sep 17 00:00:00 2001 From: AshNicolus Date: Sat, 23 May 2026 02:06:19 +0530 Subject: [PATCH 2/2] Fix typos in LoRA einsum utils docstring --- gemma/peft/_einsum_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gemma/peft/_einsum_utils.py b/gemma/peft/_einsum_utils.py index 847acb9e..359c4531 100644 --- a/gemma/peft/_einsum_utils.py +++ b/gemma/peft/_einsum_utils.py @@ -28,7 +28,7 @@ def get_lora_einsum_str_and_shapes( ) -> tuple[str, _Shape, _Shape]: """Extract the LoRA decomposition from the original einsum parameters. - This function reqrites a einsum string `inputs,weights->outputs` into + This function rewrites an einsum string `inputs,weights->outputs` into `inputs,a,b->outputs`. Args: @@ -66,7 +66,7 @@ def get_lora_einsum_str_and_shapes( lora_einsum_str = f'{inputs},{a_str},{b_str}->{outputs}' - # This assume there's no elipsis in the weights. + # This assumes there's no ellipsis in the weights. weights_str_to_dim = dict(zip(weights, weights_shape)) weights_str_to_dim[rank_dim] = rank a_shape = tuple(weights_str_to_dim[c] for c in a_str)