6 changes: 6 additions & 0 deletions examples/models/gemma4_31b/README.md
@@ -79,6 +79,9 @@ Writes `model.pte` and `model.ptd` into `--output-dir`.

## Eager inference

The prompt is automatically wrapped with the Gemma 4 IT chat template.
Pass `--raw-prompt` to skip template wrapping for pre-formatted input.

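For example, an illustrative prompt such as `Why is the sky blue?` reaches the model as (template per `inference.py`; BOS is prepended separately at the token level):

```
<|turn>user
Why is the sky blue?<turn|>
<|turn>model
<|channel>thought
<channel|>
```
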
```bash
python examples/models/gemma4_31b/inference.py \
--prequantized ./gemma4_31b_int4 \
@@ -109,6 +112,9 @@ The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.

## Run the .pte

The prompt is automatically wrapped with the Gemma 4 IT chat template.
Pass `--raw_prompt` to skip template wrapping for pre-formatted input.

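If the prompt already contains the turn markers, a minimal sketch (combine with the remaining flags from the full command below):

```bash
./gemma4_31b_runner --model_path ./gemma4_31b_exports/model.pte --raw_prompt
```
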
```bash
./gemma4_31b_runner \
--model_path ./gemma4_31b_exports/model.pte \
25 changes: 24 additions & 1 deletion examples/models/gemma4_31b/inference.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -13,6 +13,11 @@
Packs for the target backend (--backend cuda), materializes runtime buffers,
optionally compiles with ``torch.compile``, and generates text autoregressively.

Gemma 4 31B-IT is instruction-tuned and requires chat-template formatting.
The ``--prompt`` is automatically wrapped with the Gemma 4 chat template
(``<|turn>user\\n{prompt}<turn|>\\n<|turn>model\\n<|channel>thought\\n<channel|>``; BOS is prepended separately).
Pass ``--raw-prompt`` to skip template wrapping (e.g., for pre-formatted input).

Usage:
python inference.py \\
--prequantized ./gemma4_31b_int4 \\
@@ -63,6 +68,17 @@
materialize_runtime_buffers(model, dtype=torch.bfloat16, device="cuda")


def apply_chat_template(prompt: str) -> str:
"""Wrap a user prompt in the Gemma 4 IT chat template.

Does not include BOS — ``generate()`` prepends it at the token-ID level.
"""
return (
"<|turn>user\n" + prompt
+ "<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"
)


def generate(
model,
tokenizer,
@@ -155,6 +171,11 @@
default=4096,
help="KV cache length to allocate for this run.",
)
parser.add_argument(
"--raw-prompt",
action="store_true",
help="Skip chat-template wrapping (use if the prompt is already formatted).",
)
parser.add_argument(
"--no-compile",
action="store_true",
@@ -204,14 +225,16 @@
# Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106).
eos_token_ids = {1, 50, 106}

prompt = args.prompt if args.raw_prompt else apply_chat_template(args.prompt)

print(f"\nPrompt: {args.prompt}")
print("-" * 40)

t0 = time.perf_counter()
output = generate(
model,
tokenizer,
args.prompt,
prompt,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
eos_token_ids=eos_token_ids,
12 changes: 12 additions & 0 deletions examples/models/gemma4_31b/main.cpp
@@ -65,6 +65,10 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2).");
DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1).");
DEFINE_bool(
raw_prompt,
false,
"Skip chat-template wrapping (use if the prompt is already formatted).");
DEFINE_bool(
cuda_graph,
false,
@@ -232,6 +236,14 @@ int main(int argc, char** argv) {
(std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
}

// Wrap with Gemma 4 IT chat template unless --raw_prompt is set.
// BOS is prepended separately below; this adds the turn structure and the
// empty thought block required by the instruction-tuned model.
if (!FLAGS_raw_prompt) {
prompt_text = "<|turn>user\n" + prompt_text +
"<turn|>\n<|turn>model\n<|channel>thought\n<channel|>";
}

// Encode prompt
auto encode_result = tokenizer->encode(prompt_text);
if (!encode_result.ok()) {