From 5d5c26e221192512af7f5925429e3f8f2f7464b5 Mon Sep 17 00:00:00 2001
From: mnachin
Date: Fri, 15 May 2026 06:48:26 -0700
Subject: [PATCH] Apply Gemma 4 IT chat template in inference.py and C++ runner

Gemma 4 31B-IT is instruction-tuned and produces degenerate output
without the chat template wrapping. Auto-wrap --prompt with the IT
template (<|turn>user\n{prompt}\n<|turn>model\n<|channel>thought\n) by
default; either --raw-prompt (Python) or --raw_prompt (C++ runner)
skips wrapping for pre-formatted input.
---
 examples/models/gemma4_31b/README.md    |  6 ++++++
 examples/models/gemma4_31b/inference.py | 25 ++++++++++++++++++++++++-
 examples/models/gemma4_31b/main.cpp     | 12 ++++++++++++
 3 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md
index 6f567d739b7..94783c8f823 100644
--- a/examples/models/gemma4_31b/README.md
+++ b/examples/models/gemma4_31b/README.md
@@ -79,6 +79,9 @@ Writes `model.pte` and `model.ptd` into `--output-dir`.
 
 ## Eager inference
 
+The prompt is automatically wrapped with the Gemma 4 IT chat template.
+Pass `--raw-prompt` to skip template wrapping for pre-formatted input.
+
 ```bash
 python examples/models/gemma4_31b/inference.py \
   --prequantized ./gemma4_31b_int4 \
@@ -109,6 +112,9 @@ The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.
 
 ## Run the .pte
 
+The prompt is automatically wrapped with the Gemma 4 IT chat template.
+Pass `--raw_prompt` to skip template wrapping for pre-formatted input.
+
 ```bash
 ./gemma4_31b_runner \
   --model_path ./gemma4_31b_exports/model.pte \
diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py
index 12785450d8c..62dfe5956a7 100644
--- a/examples/models/gemma4_31b/inference.py
+++ b/examples/models/gemma4_31b/inference.py
@@ -13,6 +13,11 @@
 Packs for the target backend (--backend cuda), materializes runtime buffers,
 optionally compiles with ``torch.compile``, and generates text autoregressively.
 
+Gemma 4 31B-IT is instruction-tuned and requires chat-template formatting.
+The ``--prompt`` is automatically wrapped with the Gemma 4 chat template
+(``<|turn>user\\n{prompt}\\n<|turn>model\\n<|channel>thought\\n``; BOS is prepended separately).
+Pass ``--raw-prompt`` to skip template wrapping (e.g., for pre-formatted input).
+
 Usage:
     python inference.py \\
         --prequantized ./gemma4_31b_int4 \\
@@ -63,6 +68,17 @@ def _move_to_cuda(model, config) -> None:
     materialize_runtime_buffers(model, dtype=torch.bfloat16, device="cuda")
 
 
+def apply_chat_template(prompt: str) -> str:
+    """Wrap a user prompt in the Gemma 4 IT chat template.
+
+    Does not include BOS; ``generate()`` prepends it at the token-ID level.
+    """
+    return (
+        "<|turn>user\n" + prompt
+        + "\n<|turn>model\n<|channel>thought\n"
+    )
+
+
 def generate(
     model,
     tokenizer,
@@ -155,6 +171,11 @@
         default=4096,
         help="KV cache length to allocate for this run.",
     )
+    parser.add_argument(
+        "--raw-prompt",
+        action="store_true",
+        help="Skip chat-template wrapping (use if the prompt is already formatted).",
+    )
     parser.add_argument(
         "--no-compile",
         action="store_true",
@@ -204,6 +225,8 @@
     # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106).
     eos_token_ids = {1, 50, 106}
 
+    prompt = args.prompt if args.raw_prompt else apply_chat_template(args.prompt)
+
     print(f"\nPrompt: {args.prompt}")
     print("-" * 40)
 
@@ -211,7 +234,7 @@
     output = generate(
         model,
         tokenizer,
-        args.prompt,
+        prompt,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         eos_token_ids=eos_token_ids,
diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp
index 0be2fef517c..3ddf64e410f 100644
--- a/examples/models/gemma4_31b/main.cpp
+++ b/examples/models/gemma4_31b/main.cpp
@@ -65,6 +65,10 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
 DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2).");
 DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1).");
+DEFINE_bool(
+    raw_prompt,
+    false,
+    "Skip chat-template wrapping (use if the prompt is already formatted).");
 DEFINE_bool(
     cuda_graph,
     false,
@@ -232,6 +236,14 @@ int main(int argc, char** argv) {
         (std::istreambuf_iterator<char>(f)),
         std::istreambuf_iterator<char>());
   }
+  // Wrap with Gemma 4 IT chat template unless --raw_prompt is set.
+  // BOS is prepended separately below; this adds the turn structure and the
+  // empty thought block required by the instruction-tuned model.
+  if (!FLAGS_raw_prompt) {
+    prompt_text = "<|turn>user\n" + prompt_text +
+        "\n<|turn>model\n<|channel>thought\n";
+  }
+
   // Encode prompt
   auto encode_result = tokenizer->encode(prompt_text);
   if (!encode_result.ok()) {
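
-- 
Note (below the signature cut, so not part of the applied diff): the
template wrapping is plain string concatenation and can be sanity-checked
outside the runner. Below is a minimal standalone sketch; the function
name `wrap_gemma4_it` and the sample prompt are illustrative only, while
the template literal is copied verbatim from apply_chat_template() in the
diff above.

```python
# Standalone sketch of the wrapping this patch adds; it mirrors
# apply_chat_template() in inference.py. BOS is deliberately absent:
# generate() and the C++ runner prepend it at the token-ID level.
# The function name and sample prompt are illustrative only.

def wrap_gemma4_it(prompt: str) -> str:
    # Open a user turn, close it, open the model turn, and start the
    # thought channel expected by the instruction-tuned checkpoint.
    return "<|turn>user\n" + prompt + "\n<|turn>model\n<|channel>thought\n"


if __name__ == "__main__":
    wrapped = wrap_gemma4_it("Explain KV caching in one sentence.")
    expected = (
        "<|turn>user\nExplain KV caching in one sentence.\n"
        "<|turn>model\n<|channel>thought\n"
    )
    assert wrapped == expected
    print(repr(wrapped))
```

Because the Python path and the C++ runner each hard-code the same
template literal, any future change to the Gemma 4 IT format has to be
mirrored in both inference.py and main.cpp.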