Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions gemma/gm/nn/vision/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from kauldron import typing
import numpy as np
from PIL import Image
import tensorflow as tf
import warnings
# Removed TensorFlow dependency: use PIL + NumPy + JAX for decoding/resizing.
# Note: inputs are expected to be image arrays (H,W,C) in uint8 or floats.
# Keep a small fallback warning if the input array isn't a standard type.

_IMAGE_MEAN = (127.5,) * 3
_IMAGE_STD = (127.5,) * 3
Expand Down Expand Up @@ -69,17 +72,28 @@ def pre_process_image(
Returns:
The pre-processed image.
"""
# all inputs are expected to have been jpeg compressed.
# TODO(eyvinec): we should remove tf dependency.
image = jnp.asarray(
tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3)
)
image = jax.image.resize(
image,
shape=(image_height, image_width, 3),
method="bilinear",
antialias=True,
)
# Accept numpy / jax arrays or PIL images. Convert to uint8 ndarray for PIL.
arr = np.asarray(image)

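# Float input: if the maximum is <= 1.0, treat values as [0, 1]-scaled and map
# them to 0-255 uint8; otherwise clip to [0, 255] before casting to uint8.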
if np.issubdtype(arr.dtype, np.floating):
if arr.max() <= 1.0:
arr = (arr * 255.0).round().astype(np.uint8)
else:
arr = np.clip(arr, 0, 255).round().astype(np.uint8)
else:
arr = arr.astype(np.uint8)

# Image.fromarray takes an (H, W, C) array, but Image.resize takes its target
# size as a (width, height) tuple, so the two dimensions are swapped below.
pil = Image.fromarray(arr)
# Use bilinear resizing; PIL's LANCZOS is a high-quality downsample filter but
# bilinear better matches previous `jax.image.resize(..., method='bilinear')`.
pil = pil.resize((image_width, image_height), resample=Image.BILINEAR)

# Back to numpy -> jax
resized = np.asarray(pil).astype(np.float32)
image = jnp.asarray(resized)

image = normalize_images(image)
image = jnp.clip(image, -1, 1)
return image
Expand Down
4 changes: 2 additions & 2 deletions gemma/peft/_einsum_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_lora_einsum_str_and_shapes(
) -> tuple[str, _Shape, _Shape]:
"""Extract the LoRA decomposition from the original einsum parameters.

This function reqrites a einsum string `inputs,weights->outputs` into
This function rewrites an einsum string `inputs,weights->outputs` into
`inputs,a,b->outputs`.

Args:
Expand Down Expand Up @@ -66,7 +66,7 @@ def get_lora_einsum_str_and_shapes(

lora_einsum_str = f'{inputs},{a_str},{b_str}->{outputs}'

# This assume there's no elipsis in the weights.
# This assumes there's no ellipsis in the weights.
weights_str_to_dim = dict(zip(weights, weights_shape))
weights_str_to_dim[rank_dim] = rank
a_shape = tuple(weights_str_to_dim[c] for c in a_str)
Expand Down