
[GRPO/RLOO] Tokenize before vLLM generation call#5238

Merged
qgallouedec merged 38 commits into main from vllm-generate-with-token-ids
Mar 10, 2026
Conversation


@qgallouedec (Member) commented Mar 7, 2026

Context

Part of the series to fix the re-tokenization bug in GRPO multi-turn tool calling (see #5224).

When the model generates a completion in a tool-calling loop, the decoded text is re-tokenized via apply_chat_template, which can produce different token IDs due to BPE merge ambiguities. To fix this, we need a token-in / token-out pipeline: tokenize once, then pass raw token IDs through every subsequent generation call — never decoding and re-tokenizing.
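As a toy illustration of the hazard (a hypothetical three-entry vocabulary, not the real BPE), decoding token IDs to text and re-encoding that text can yield different IDs even though the text is byte-identical:

```python
# Hypothetical toy tokenizer: "ab" exists both as two tokens and as one merged token.
VOCAB = {"a": 0, "b": 1, "ab": 2}
INV = {v: k for k, v in VOCAB.items()}

def decode(ids):
    return "".join(INV[i] for i in ids)

def encode(text):
    # Greedy longest-match, the way BPE-style tokenizers prefer merged tokens.
    ids, i = [], 0
    while i < len(text):
        if text[i : i + 2] in VOCAB:
            ids.append(VOCAB[text[i : i + 2]])
            i += 2
        else:
            ids.append(VOCAB[text[i]])
            i += 1
    return ids

generated = [0, 1]                      # the model emitted "a" then "b" separately
roundtrip = encode(decode(generated))   # re-tokenization picks the merged token
print(generated, roundtrip)             # [0, 1] vs [2]: same text, different IDs
```

This is exactly why the pipeline must carry token IDs end-to-end instead of round-tripping through text.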

This PR moves tokenization out of VLLMGeneration.generate and into the trainers' _generate_single_turn, so that vLLM always receives raw token IDs instead of text or chat messages.

Changes

  • VLLMGeneration.generate(): Replace the prompts (text/messages) parameter with prompts (token ID lists) + images (optional PIL image lists). Remove the internal chat() / text-tokenization paths — the method now always forwards pre-tokenized IDs to vLLM. For colocate mode, build {"prompt_token_ids": ids, "multi_modal_data": ...} dicts. Remove unused imports (json, is_conversational, prepare_multimodal_messages_vllm).
  • GRPO _generate_single_turn(): Tokenize prompts before calling vLLM. For conversational prompts, use apply_chat_template(tokenize=True) and extract PIL images from message content blocks. For text prompts, use processing_class(text=prompts). Pass the resulting token IDs and images to vllm_generation.generate().
  • RLOO _generate_single_turn(): Same tokenization pattern applied.
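The new tokenization step in the trainers can be sketched roughly as follows (the helper name `tokenize_prompts` is illustrative; the real trainers do this inline in `_generate_single_turn`):

```python
def tokenize_prompts(prompts, processing_class, conversational):
    """Turn prompts into per-prompt token-ID lists for vLLM (illustrative sketch)."""
    if conversational:
        # tokenize=True makes apply_chat_template return token IDs directly,
        # so no decoded text is ever re-tokenized downstream.
        return [
            processing_class.apply_chat_template(
                p, tokenize=True, add_generation_prompt=True
            )
            for p in prompts
        ]
    # Plain-text prompts go through the tokenizer/processor directly.
    return processing_class(text=prompts)["input_ids"]
```

The resulting ID lists (plus any extracted PIL images) are what `vllm_generation.generate()` now receives.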

Why

Previously, VLLMGeneration.generate received raw text or chat messages and was responsible for tokenization (server mode) or delegating it to vLLM's chat() (colocate mode). This made it impossible to guarantee token-level consistency between the trainer's tokenization and vLLM's — a prerequisite for the token-in/token-out pipeline.

By tokenizing once in _generate_single_turn and passing only token IDs downstream, we ensure that:

  1. The trainer controls tokenization (chat template, tools, special tokens) in one place.
  2. vLLM receives exactly the token IDs the trainer expects, eliminating BPE ambiguity.
  3. The VLM case is handled correctly: images are extracted from messages and passed separately via vLLM's multi_modal_data mechanism.

Backward compatibility

The VLLMGeneration.generate signature changes from accepting text/messages to accepting token ID lists. This is an internal API — all call sites (GRPO and RLOO trainers) are updated in this PR. No user-facing API changes.


Note

Medium Risk
Touches core generation paths for GRPO/RLOO and changes vLLM input semantics, so regressions would affect sampling behavior and distributed generation (especially multimodal/image batches).

Overview
Implements a token-in/token-out vLLM generation path for GRPOTrainer and RLOOTrainer by moving all prompt tokenization out of VLLMGeneration.generate() and into the trainers’ _generate_single_turn().

VLLMGeneration.generate() now always receives pre-tokenized prompt_token_ids (plus optional per-prompt image lists for VLMs), removes the chat/template/tool-handling branches and JSON tool-arg coercion, and adds distributed-safe gathering of images to avoid collectives deadlocking when some ranks have no images. Colocate mode now builds vLLM prompt dicts (prompt_token_ids + multi_modal_data) before calling llm.generate().
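The colocate-mode prompt dicts can be sketched like this (an illustrative helper, assuming per-prompt image lists aligned with the token-ID lists):

```python
def build_vllm_prompts(prompt_token_ids, images=None):
    """Build vLLM-style prompt dicts from pre-tokenized IDs (sketch)."""
    prompts = []
    for i, ids in enumerate(prompt_token_ids):
        entry = {"prompt_token_ids": ids}
        # Attach images via vLLM's multi_modal_data only when present,
        # so text-only prompts stay plain token-ID dicts.
        if images is not None and images[i]:
            entry["multi_modal_data"] = {"image": images[i]}
        prompts.append(entry)
    return prompts
```

These dicts are then passed to `llm.generate()` instead of the old `llm.chat()` path.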

Written by Cursor Bugbot for commit fee553d. This will update automatically on new commits.

@qgallouedec changed the title from "Move tokenization before vLLM generation call" to "[GRPO/RLOO] Tokenize before vLLM generation call" on Mar 7, 2026
@qgallouedec (Member, Author) commented Mar 7, 2026

Before vs After

# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("BlackBeenie/simple-math", split="train")

# Keep only the easiest problems for a quick sanity run.
dataset = dataset.filter(lambda x: x["level"] == "Level 1")

# Expose the ground-truth answer under "solution" for accuracy_reward.
def extract_solution(example):
    return {"solution": example["reward_model"]["ground_truth"]}

dataset = dataset.map(extract_solution)

training_args = GRPOConfig(
    output_dir="/tmp/grpo_test",
    use_vllm=True,
    vllm_mode="colocate",
    max_steps=50,
    logging_steps=1,
    report_to="trackio",
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
    args=training_args,
)
trainer.train()

https://qgallouedec-trackio.hf.space?project=huggingface&runs=qgallouedec-1772852549,qgallouedec-1772852661&sidebar=hidden&navbar=hidden

The training curves are exactly the same before and after this change.


@qgallouedec (Member, Author) commented:

@codex review

@chatgpt-codex-connector (bot) left a comment:

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 09128d6711


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Base automatically changed from move-rollout-func to main March 10, 2026 00:03
@albertvillanova (Member) left a comment:

Thanks. Some suggestions below.

Regarding tests (I know this is always a non-trivial issue): the old prepare_multimodal_messages_vllm path in colocate mode (calling llm.chat()) is replaced by the multi_modal_data: {"image": ...} dict path with llm.generate(). Do you think it would be worth adding a unit test for this path (maybe mocking llm)?

@cursor (bot) left a comment:

Cursor Bugbot has reviewed your changes and found 1 potential issue.


qgallouedec and others added 4 commits March 10, 2026 18:14
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
@qgallouedec qgallouedec merged commit b77f36f into main Mar 10, 2026
14 checks passed
@qgallouedec qgallouedec deleted the vllm-generate-with-token-ids branch March 10, 2026 18:48
