Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Empty file.
82 changes: 82 additions & 0 deletions inference_models/examples/gemma4/count_backpacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""End-to-end Gemma 4 example using ``inference_models.AutoModel`` (no CLI arguments).

Downloads a public sample image, loads a hosted Gemma 4 checkpoint via Roboflow, and
asks a focused counting question.

Run from the ``inference_models`` package root::

uv run python examples/gemma4/run_gemma4_local.py

Optional: set ``GEMMA4_MODEL_ID`` to override the default Roboflow registry id.
"""

from __future__ import annotations

import io
import os
import sys

import numpy as np
import requests
from PIL import Image

from inference_models import AutoModel
from inference_models.configuration import DEFAULT_DEVICE

# Same image used in repo docs (e.g. workflows benchmarks).
IMAGE_URL = "https://media.roboflow.com/inference/people-walking.jpg"

# Roboflow registry id (must match a registered Gemma 4 package).
DEFAULT_MODEL_ID = "gemma-4-e2b-it"

SYSTEM_PROMPT = (
"You are a precise vision assistant. When asked about people or objects in a scene, "
"base your answer only on what is clearly visible. If you are uncertain, say so. "
"For counting questions, give a single best estimate and briefly note any ambiguity "
"(e.g. partially occluded figures or unclear backpacks)."
)

USER_PROMPT = (
"How many people in this image are clearly wearing a backpack? "
"Answer with a number first, then one short sentence explaining what you counted."
)


def _build_prompt(user: str, system: str) -> str:
return f"{user}<system_prompt>{system}"


def _load_image_rgb(url: str) -> np.ndarray:
    """Download *url* and return its pixels as an RGB ``np.ndarray``.

    Raises ``requests.HTTPError`` on a non-2xx response (via
    ``raise_for_status``) — the example should fail loudly rather than feed a
    broken payload to the model.
    """
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()
    decoded = Image.open(io.BytesIO(resp.content))
    rgb = decoded.convert("RGB")
    return np.array(rgb)


def main() -> None:
    """Run the full example: load the model, fetch the image, ask the question."""
    print(f"Loading hosted model {DEFAULT_MODEL_ID!r} …")
    # Backend pinned to Hugging Face so the registry resolves the Gemma4HF class.
    model = AutoModel.from_pretrained(
        DEFAULT_MODEL_ID,
        device=DEFAULT_DEVICE,
        backend="hugging-face",
    )

    print(f"Fetching image {IMAGE_URL!r} …")
    image_rgb = _load_image_rgb(IMAGE_URL)
    prompt = _build_prompt(USER_PROMPT, SYSTEM_PROMPT)

    print("Running inference …")
    # Greedy decoding (do_sample=False) keeps the counting answer deterministic.
    outputs = model.prompt(
        images=image_rgb,
        prompt=prompt,
        input_color_format="rgb",
        max_new_tokens=256,
        do_sample=False,
    )
    print("---")
    if outputs:
        print(outputs[0])
    else:
        print(outputs)


if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions inference_models/examples/gemma4/model_config.example.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"model_architecture": "gemma-4-e2b-it",
"task_type": "vlm",
"backend_type": "hugging-face"
}
29 changes: 29 additions & 0 deletions inference_models/inference_models/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,35 @@
variable_name="INFERENCE_MODELS_QWEN25_VL_DEFAULT_SKIP_SPECIAL_TOKENS",
default=True,
)
# --- Gemma 4 generation defaults (each overridable via its env variable) ---
# Hard cap on generated tokens per request.
INFERENCE_MODELS_GEMMA4_DEFAULT_MAX_NEW_TOKENS = get_integer_from_env(
    variable_name="INFERENCE_MODELS_GEMMA4_DEFAULT_MAX_NEW_TOKENS",
    default=512,
)
# Falls back to the package-wide sampling default rather than a literal.
INFERENCE_MODELS_GEMMA4_DEFAULT_DO_SAMPLE = get_boolean_from_env(
    variable_name="INFERENCE_MODELS_GEMMA4_DEFAULT_DO_SAMPLE",
    default=INFERENCE_MODELS_DEFAULT_DO_SAMPLE,
)
# Thinking/reasoning traces disabled by default.
INFERENCE_MODELS_GEMMA4_DEFAULT_ENABLE_THINKING = get_boolean_from_env(
    variable_name="INFERENCE_MODELS_GEMMA4_DEFAULT_ENABLE_THINKING",
    default=False,
)
# Strip special tokens from decoded output by default.
INFERENCE_MODELS_GEMMA4_DEFAULT_SKIP_SPECIAL_TOKENS = get_boolean_from_env(
    variable_name="INFERENCE_MODELS_GEMMA4_DEFAULT_SKIP_SPECIAL_TOKENS",
    default=True,
)
# Official Gemma 4 sampling recommendations when ``do_sample`` is True (HF model cards).
# NOTE(review): verify these values against the actual model card for the
# registered checkpoints — they are only used when sampling is enabled.
INFERENCE_MODELS_GEMMA4_DEFAULT_TEMPERATURE = get_float_from_env(
    variable_name="INFERENCE_MODELS_GEMMA4_DEFAULT_TEMPERATURE",
    default=1.0,
)
INFERENCE_MODELS_GEMMA4_DEFAULT_TOP_P = get_float_from_env(
    variable_name="INFERENCE_MODELS_GEMMA4_DEFAULT_TOP_P",
    default=0.95,
)
INFERENCE_MODELS_GEMMA4_DEFAULT_TOP_K = get_integer_from_env(
    variable_name="INFERENCE_MODELS_GEMMA4_DEFAULT_TOP_K",
    default=64,
)
INFERENCE_MODELS_RESNET_DEFAULT_CONFIDENCE = get_float_from_env(
variable_name="INFERENCE_MODELS_RESNET_DEFAULT_CONFIDENCE",
default=INFERENCE_MODELS_DEFAULT_CONFIDENCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,22 @@ class RegistryEntry:
module_name="inference_models.models.qwen3_5.qwen3_5_hf",
class_name="Qwen35HF",
),
("gemma-4-e2b-it", VLM_TASK, BackendType.HF): LazyClass(
module_name="inference_models.models.gemma4.gemma4_hf",
class_name="Gemma4HF",
),
("gemma-4-e4b-it", VLM_TASK, BackendType.HF): LazyClass(
module_name="inference_models.models.gemma4.gemma4_hf",
class_name="Gemma4HF",
),
("gemma-4-31b-it", VLM_TASK, BackendType.HF): LazyClass(
module_name="inference_models.models.gemma4.gemma4_hf",
class_name="Gemma4HF",
),
("gemma-4-26b-a4b-it", VLM_TASK, BackendType.HF): LazyClass(
module_name="inference_models.models.gemma4.gemma4_hf",
class_name="Gemma4HF",
),
("florence-2", VLM_TASK, BackendType.HF): LazyClass(
module_name="inference_models.models.florence2.florence2_hf",
class_name="Florence2HF",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Gemma 4 multimodal (Hugging Face) implementations
Loading
Loading