diff --git a/gemma/gm/nn/vision/_image.py b/gemma/gm/nn/vision/_image.py index ce4182e0..a357c14e 100644 --- a/gemma/gm/nn/vision/_image.py +++ b/gemma/gm/nn/vision/_image.py @@ -23,7 +23,10 @@ from kauldron import typing import numpy as np from PIL import Image -import tensorflow as tf +import warnings +# Removed TensorFlow dependency: images are now resized with PIL + NumPy (no JPEG round-trip). +# Note: inputs are expected to be image arrays (H,W,C) in uint8 or floats. +# NOTE: `warnings` is imported for a fallback warning on non-standard inputs, but this patch never emits one. _IMAGE_MEAN = (127.5,) * 3 _IMAGE_STD = (127.5,) * 3 @@ -69,17 +72,28 @@ def pre_process_image( Returns: The pre-processed image. """ - # all inputs are expected to have been jpeg compressed. - # TODO(eyvinec): we should remove tf dependency. - image = jnp.asarray( - tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3) - ) - image = jax.image.resize( - image, - shape=(image_height, image_width, 3), - method="bilinear", - antialias=True, - ) + # Accept numpy / jax arrays or PIL images. Convert to uint8 ndarray for PIL. + arr = np.asarray(image) + + # If floats in [0, 1], convert to 0-255 uint8 + if np.issubdtype(arr.dtype, np.floating): + if arr.max() <= 1.0: + arr = (arr * 255.0).round().astype(np.uint8) + else: + arr = np.clip(arr, 0, 255).round().astype(np.uint8) + else: + arr = arr.astype(np.uint8) + + # PIL expects shape (W, H) ordering for resize tuple; Image.fromarray handles H,W,C. + pil = Image.fromarray(arr) + # Use bilinear resizing; PIL's LANCZOS is a high-quality downsample filter but + # bilinear better matches previous `jax.image.resize(..., method='bilinear')`.
+ pil = pil.resize((image_width, image_height), resample=Image.BILINEAR) + + # Back to numpy -> jax + resized = np.asarray(pil).astype(np.float32) + image = jnp.asarray(resized) + image = normalize_images(image) image = jnp.clip(image, -1, 1) return image diff --git a/gemma/peft/_einsum_utils.py b/gemma/peft/_einsum_utils.py index 847acb9e..359c4531 100644 --- a/gemma/peft/_einsum_utils.py +++ b/gemma/peft/_einsum_utils.py @@ -28,7 +28,7 @@ def get_lora_einsum_str_and_shapes( ) -> tuple[str, _Shape, _Shape]: """Extract the LoRA decomposition from the original einsum parameters. - This function reqrites a einsum string `inputs,weights->outputs` into + This function rewrites an einsum string `inputs,weights->outputs` into `inputs,a,b->outputs`. Args: @@ -66,7 +66,7 @@ def get_lora_einsum_str_and_shapes( lora_einsum_str = f'{inputs},{a_str},{b_str}->{outputs}' - # This assume there's no elipsis in the weights. + # This assumes there's no ellipsis in the weights. weights_str_to_dim = dict(zip(weights, weights_shape)) weights_str_to_dim[rank_dim] = rank a_shape = tuple(weights_str_to_dim[c] for c in a_str)