Skip to content

feat: --no-cache-thoughts for reasoning-model multi-turn cache alignment#29

Draft
DavidBellamy wants to merge 9 commits into
prodfrom
feat/no-cache-thoughts-llm360-fork
Draft

feat: --no-cache-thoughts for reasoning-model multi-turn cache alignment#29
DavidBellamy wants to merge 9 commits into
prodfrom
feat/no-cache-thoughts-llm360-fork

Conversation

@DavidBellamy
Copy link
Copy Markdown
Collaborator

Summary

Adds a --no-cache-thoughts server flag that, for reasoning requests, inserts the answer slice into the prefix cache at the original RoPE positions (skipping the thought slice) instead of caching the entire generated trajectory at contiguous positions. Lets a TITO multi-turn rollout reuse turn N's answer KV as a prefix match on turn N+1, without polluting the radix tree with request-specific reasoning content.

End-to-end validated on bbq-8b-mid3 in the agentic-rl image:

  • Baseline server (no flag) — turn 2 cached_tokens = 22 (prompt prefix only)
  • --no-cache-thoughts — turn 2 cached_tokens = 191 (prompt + answer, delta = answer length)

How it works (server side)

Stage Where What
Detect </think> Req.update_reasoning_tokens (schedule_batch.py) Sets req.answer_start_position when the reasoning parser emits its end token
Build the split split_kv_for_no_cache_thoughts (common.py) Returns (virtual_token_ids, virtual_kv_indices, virtual_positions, thought_kv_indices_to_free) — input prompt + post-</think> answer, with answer carrying its original (non-contiguous) positions; thought slots earmarked for free
Route at finish release_kv_cache (common.py) When the flag is on and the request has reasoning, builds the split and calls cache_finished_req(split=…)
Insert + free RadixCache.cache_finished_req (radix_cache.py) Inserts virtual token ids into the radix with original_positions attached, frees the thought slice directly
Round-trip positions RadixCache.{insert,_split_node,_match_prefix_helper,match_prefix}, TreeNode.positions, InsertParams.original_positions, MatchResult.original_positions Per-token position metadata travels through the tree; match_prefix returns it on cache hits
Plumb to Req Req.init_next_round_input (schedule_batch.py) Stores match_result.original_positions on req.cached_positions
Aggregate per batch collect_cached_positions, prepare_for_extend ScheduleBatch.cached_positions_per_req + ScheduleBatch.position_offsets populated
Forward batch positions build_extend_positions, clamp_position (forward_batch_info.py) Prefill and decode positions both honor non-contiguous cached positions; new tokens continue from max(cached) + 1
Backend stubs ChunkCache, SWARadixCache, MambaRadixCache, RadixCacheCpp, LMCRadixCache, SessionAwareCache Accept the new split= kwarg; non-RadixCache backends log a one-time warning and fall back to standard cache_finished_req (graceful no-op)

Integration spec for callers (TITO multi-turn rollouts)

The chat-completions text path cannot deliver multi-turn answer reuse even with this flag, because the chat template re-tokenizes prior assistant content via BPE and the first answer token ID after the assistant header doesn't match the model's emitted ID (leading-whitespace BPE merges differ). Callers must use the TITO protocol:

  1. Turn 1: send via /v1/chat/completions with return_prompt_token_ids=True, return_completion_token_ids=True, return_meta_info=True. Capture from the response:

    • choices[0].prompt_token_ids — turn 1's input IDs
    • choices[0].completion_token_ids — turn 1's output IDs (full, including thoughts)
    • usage.reasoning_tokens — count of thought tokens (including </think>)
  2. Strip the thoughts before appending to the running buffer:
    ```python
    answer_ids = completion_token_ids[reasoning_tokens:]
    running_ids = prompt_token_ids + answer_ids
    ```
    The buffer must NOT contain the thought slice; the cached entry doesn't.

  3. Turn N+1: construct input_ids = running_ids + tokenize(env_delta) where env_delta is the new user message + new assistant priming, including the K2V3 boundary \n (model stops on <|im_end|> but the chat template emits <|im_end|>\n, so the delta starts with \n). Send via /v1/chat/completions with input_ids set in the request body (chat template is bypassed when input_ids is provided).

  4. Verify cache hits: turn N+1's usage.prompt_tokens_details.cached_tokens should grow turn-over-turn by the prior answer's length.

For RL360: this is the same tito-patched flow already used by miles/rollout/generate_hub/agentic_tool_call.py; no client changes needed beyond ensuring the thought-stripping step runs for --no-cache-thoughts-enabled servers.

Tests

  • 34 unit tests under test/registered/radix_cache/ (1 env-skipped for C++ extension). Covers: CLI flag, Req.answer_start_position, split helper, radix position round-trip, release_kv_cache routing, compute_position/clamp_position non-contiguous support, derive helpers, Req.cached_positions, ScheduleBatch.cached_positions_per_req, and the 6 non-RadixCache backend signature acceptance.
  • TITO E2E at test/manual/test_no_cache_thoughts_e2e.py. Launches two bbq-8b-mid3 servers (with/without flag), drives a 2-turn protocol via raw input_ids, asserts the cached_tokens delta on turn 2. Passes on agentic-rl image.
  • The pre-existing test/registered/distributed/test_parallelism_context_integration.py lacks register_cuda_ci(...) and breaks test/run_suite.py collection — separate concern, not introduced by this PR.

Known limitations / follow-up

  • Chat-completions text path doesn't benefit (BPE retokenization breaks alignment regardless of the flag). The flag is a TITO-protocol feature; PR description above documents the required client behavior.
  • 6 non-RadixCache backends only stub the split= kwarg (warn + fall back). Full per-backend implementation deferred until empirical results justify the engineering.
  • The mid3-bundled chat_template.jinja is stale relative to LLM360/bbq-chat-template:main (the bundled copy has a 'think' default for think_tag that the upstream removed). E2E uses an explicit --chat-template override pointing at a vendored upstream copy under test/manual/chat_templates/bbq_upstream.jinja. If/when Mid3 ships with the upstream template, the override can be dropped.

…x entries

Adds the data-model and helper pieces needed to skip reasoning tokens from the
shared prefix cache while preserving thought-infused answer K/V across turns.

* server_args: --no-cache-thoughts boolean flag
* schedule_batch.Req: answer_start_position field, set by update_reasoning_tokens
  when the </think> boundary is detected
* base_prefix_cache: original_positions on InsertParams and MatchResult
  (backwards-compat default: None == today's contiguous-position behavior)
* radix_cache (RadixCache only): round-trip per-token positions through insert,
  match_prefix, _split_node, and TreeNode
* common.split_kv_for_no_cache_thoughts: pure helper that splits a finished
  request's KV into the radix-bound slice (prompt + post-</think> answer with
  original positions) and the thought slice to free directly

11 unit tests cover the new behavior. Remaining work (not in this commit):
wiring release_kv_cache to invoke the helper, non-contiguous positions in the
scheduler's compute_position/clamp_position paths, schema propagation to the
6 non-RadixCache backends, and the e2e server-fixture test.
Extends the foundation toward --no-cache-thoughts being end-to-end functional:

* RadixCache.cache_finished_req(split=...): when a NoCacheThoughtsSplit is
  passed, use its virtual token ids / kv_indices / positions for the radix
  insert instead of looking them up from the per-request KV slot; free the
  thought slice directly.
* common.release_kv_cache: when --no-cache-thoughts is on and the request has
  a recorded answer_start_position, build the split via
  split_kv_for_no_cache_thoughts and route it into cache_finished_req.
* common.derive_extend_position_start: helper that turns per-request cached
  RoPE positions (returned by match_prefix on cache hits) into the per-request
  starting position for extend tokens. Returns None when no request has cached
  positions (signals legacy contiguous behavior).
* forward_batch_info.compute_position(_torch): new extend_position_start kwarg
  overrides the implicit "positions start at extend_prefix_lens" assumption,
  enabling non-contiguous positions on cache hits. Triton path falls through
  to the torch path when the override is set; native triton support deferred.

8 new unit tests (19 total, all green). Not yet wired: ForwardBatch.init_new
call site that would actually pass extend_position_start, Req-side cached
positions tracking, and the decode-path clamp_position changes.
Wires the call site so per-request cached non-contiguous positions feed into
the positions tensor handed to the model. Behavior is identical to today when
batch.cached_positions_per_req is None (the default), since
build_extend_positions falls through to compute_position in that case.

* build_extend_positions: new helper in forward_batch_info that bridges
  scheduler state (per-req cached positions, prefix lens) to compute_position's
  extend_position_start tensor. Lives next to compute_position so the call
  site only needs to swap which function it calls.
* ForwardBatch.init_new: extend-mode path now calls build_extend_positions,
  reading getattr(batch, "cached_positions_per_req", None). Uses getattr so
  pre-existing ScheduleBatch instances without the field continue to work.

The matching field on ScheduleBatch (cached_positions_per_req) and its
population from match_prefix results is the next integration step; this
commit is a no-op until that field is wired.

2 new unit tests on build_extend_positions (21 total, all green). The
test_compute_position_noncontig hard-coded backend name moved from "aiter"
to "torch_native" because support_triton() returns True for everything except
{torch_native, intel_amx, ascend} and the test must force the torch path.
Closes the gap between the radix tree (which returns non-contiguous positions on
cache hits) and ForwardBatch.init_new (which consumes them via the previously
wired build_extend_positions helper).

* Req.cached_positions: new field. init_next_round_input now stores
  match_result.original_positions on the Req alongside the existing prefix_indices
  unpacking. None when the cache hit was on a legacy entry without positions.
* ScheduleBatch.cached_positions_per_req: new field. prepare_for_extend
  populates it via collect_cached_positions(reqs), which aggregates per-req
  cached_positions or returns None if no req has any (signaling that
  ForwardBatch should take the legacy contiguous-positions path).
* collect_cached_positions: module-level helper, isolated so it can be unit-
  tested without standing up a real ScheduleBatch.

3 new unit tests (24 total, all green). The prefill cache-hit path is now
wired end-to-end. Remaining: decode-path clamp_position must use max_position
+ 1 (not seq_len - 1) so per-token positions during decode continue from
where the cached prefix left off.
After a prefill cache hit whose cached entry carried non-contiguous RoPE
positions, the request's decode tokens must continue from max(cached) + extend
+ 1, not from seq_len - 1. This commit closes that gap.

* clamp_position(seq_lens, position_offsets=None): new optional offsets arg.
  When provided, output is clamp(seq_lens - 1) + offsets. Behavior unchanged
  when offsets is None. The CUDA-path wrapper is a thin Python add on top of
  the existing jit_kernel; the CUDA kernel itself is untouched.
* common.derive_position_offsets: helper computing per-request offset as
  max(cached_positions) - (prefix_len - 1), or 0 for requests without cached
  positions. Returns None if no req in the batch needs an offset.
* ScheduleBatch.position_offsets: new tensor field, populated in
  prepare_for_extend alongside cached_positions_per_req.
* ForwardBatch.init_new (decode path): passes batch.position_offsets through
  to clamp_position via getattr (legacy callers still work).

4 new unit tests covering clamp_position offsets and the offsets helper
(28 total, all green).
Every prefix-cache backend's cache_finished_req now accepts the split kwarg so
release_kv_cache won't raise TypeError when --no-cache-thoughts fires on a
non-RadixCache backend. The behavior on those backends is a one-time warning
plus fall-through to the default cache_finished_req (thoughts get cached as
usual). Full split-insertion support per backend is deferred until empirical
results justify the engineering.

* chunk_cache.ChunkCache, swa_radix_cache.SWARadixCache,
  mamba_radix_cache.MambaRadixCache, radix_cache_cpp.RadixCacheCpp: accept
  split=None; warn-and-ignore if non-None.
* lmc_radix_cache.LMCRadixCache: accept split=None, forward to super (which
  is RadixCache, fully implemented); warn that the LMCache offload that
  follows reads req.fill_ids directly and may offload more tokens than the
  radix retained.
* session_aware_cache.SessionAwareCache: already accepted **kwargs and forwards
  to inner, so it works transparently once the inner backend handles split.
* common.warn_split_unsupported_once: module-level one-time warning helper so
  the noise per backend is bounded.

6 new tests across 2 files (34 total; 2 skipped because C++ extensions can't
JIT-compile in the test env — source verified by py_compile and reads).
* test/manual/test_bbq_smoke.py: confirms BBQ-8B-Mid3 loads in the agentic-rl
  container image, the K2-v3 reasoning parser separates <think>...</think>
  from the answer (290 reasoning tokens, 222 answer tokens on a sample prompt),
  and the /v1/chat/completions endpoint applies the chat template (which
  primes the assistant turn with <think>\n).
* test/manual/test_no_cache_thoughts_e2e.py: two-server cached_tokens-delta
  test. Confirms the --no-cache-thoughts split-insert path fires correctly
  (logged "split fired rid=... input_len=18 output_len=128 answer_start=144
  virtual_tokens=20 thoughts_freed=126" during diagnostic runs).

E2E test currently fails the assertion (cached_tokens=22 on both servers).
Root cause: the BBQ chat template renders a prior assistant message with an
EMPTY <think>\n</think>\n block even when reasoning_content is empty, instead
of omitting the block. So turn 2's tokens after the user prompt are
[<think>, \n, </think>, \n, answer...], while the cached path is
[prompt..., answer...] with no <think></think> tokens between them. Both
servers' prefix match diverges at the same slot — at the empty <think> marker
— so neither sees the answer slice as cached.

The feature implementation is correct; the chat template doesn't align with
the no-cache-thoughts cached path for multi-turn rendering. Resolution:
either (a) modify the chat template to drop the empty <think></think> block
when thinking content is empty, or (b) extend the split path to cover the
<think>...</think> boundary tokens with synthetic positions. Tracked as
follow-up.
Adds the missing piece for multi-turn cache alignment: the chat template's
priming tail (typically <think>\n added by add_generation_prompt) is part of
turn 1's input but NOT of turn 2's chat-template-rendered history. Caching it
breaks the cache match at the assistant header. The fix scans the tail of
origin_input_ids for the last <think> token and excludes everything from
there onward from the virtual cached prompt.

* split_kv_for_no_cache_thoughts: new think_start_id kwarg. When provided
  scans backward through origin_input_ids for the last occurrence and uses
  that index as prompt_keep_len; tokens beyond are dropped from the cached
  entry. thought_kv_indices_to_free is also restricted to positions
  [input_len .. answer_start) so the priming slots (which are still owned by
  the radix entry that cache_unfinished_req inserted at prefill time) are
  not double-freed — that would trigger "token_to_kv_pool_allocator memory
  leak detected" on the next free check.
* scheduler.Scheduler: encode <think> alongside </think> at init and stash
  self._think_start_id; copy onto each Req at construction so release_kv_cache
  (which has no scheduler reference) can read it.
* release_kv_cache: pass req._think_start_id to split_kv_for_no_cache_thoughts.
* Tests: new TDD cycle covering the priming-strip path on both the helper
  (test_no_cache_thoughts_split.py) and the routing wrapper
  (test_release_kv_cache_routing.py). All 35 unit tests green; the 14
  new-style unit tests under test/registered/radix_cache/ now carry
  register_cuda_ci entries so they're picked up by test/run_suite.py.

E2E test (test_no_cache_thoughts_e2e.py): now runs cleanly end-to-end (no
crash, no leak) inside the agentic-rl image with the upstream BBQ chat
template (test/manual/chat_templates/bbq_upstream.jinja) overriding Mid3's
stale bundled copy. The assertion still fails because of a remaining
tokenization-alignment issue between turn-1 model-generated answer tokens
(possibly starting with leading whitespace/newline after </think>) and
turn-2 chat-template-rendered answer tokens — separate concern from the
implementation completed here.
… TITO E2E

The earlier priming-strip logic (stash think_start_id on Req, scan
origin_input_ids for the last <think>, drop the tail from the virtual
cached prompt) was the wrong call for the production rollout path.

In TITO ("Token-In, Token-Out") multi-turn rollouts, turn N+1's input is
constructed from turn N's prompt_token_ids verbatim — the priming
<think>\n at the end of the input is carried over as-is. The cached
entry from no-cache-thoughts must therefore also include the priming
tail; otherwise turn N+1's input has [...prompt, <think>, \n, answer...]
while the cached path has [...prompt, answer...] — match dies right
where we want it to extend through.

Stripping the priming was an attempt to align with the OpenAI chat-
completions text path, where the chat template strips reasoning from
historical assistant messages. But that path also has a BPE retokenize
mismatch that prevents the answer slice from aligning anyway (see TITO
doc). So the chat-template path can't deliver multi-turn answer reuse
under any priming-strip variant. The correct integration path is TITO,
where keeping the priming in the cache is the right answer.

Net change:
* split_kv_for_no_cache_thoughts: removes think_start_id kwarg and the
  scan-for-priming logic. prompt_keep_len is always input_len.
* common.release_kv_cache: drops the think_start_id propagation.
* scheduler.Scheduler: drops self._think_start_id init and the per-Req
  _think_start_id assignments at the three Req-construction sites.
* Test removals: the priming-strip-specific cases in
  test_no_cache_thoughts_split.py and test_release_kv_cache_routing.py.
* New TITO E2E: test/manual/test_no_cache_thoughts_e2e.py now sends
  turn 2 with input_ids in the chat-completions body (bypassing chat
  template), strips output_token_ids[:reasoning_tokens] from the prior
  turn before appending to the buffer, and tokenizes only the env-delta
  (new user msg + assistant generation prompt). Assertion now passes:
  baseline cached_tokens=22, --no-cache-thoughts cached_tokens=191
  (delta = answer length, exactly the multi-turn cache reuse the
  feature is supposed to deliver).

All 34 unit tests still green (1 env-skipped for the C++ extension).
@DavidBellamy
Copy link
Copy Markdown
Collaborator Author

Mapping this to the trainer side: agentic inference looks like: prefill phase -> decode phase -> environment -> prefill phase -> decode phase -> environment -> ...

For turn i in [0, 1, 2, ...] let P_i := prefill phase, D_i := decode phase, E_i := environment phase for turn i.

Example:
P_0 = [sys_msg, usr_prompt], more broadly P_i is the full trajectory history up to but not including D_i
D_0 = [think, ans] conditioned on P_0. Note that tool calls are parsed from ans.
E_0 = [tool_output] added to [P_0, D_0]

In inference engines like sglang and vllm, when the prefill phase is entered, the global KV pool (shared across all concurrent inference requests) is checked to see if those tokens' KV vectors are cached. The engine matches the largest prefix from P_i that it can find in the KV pool and returns that, then prefills the remaining suffix of P_i.

D_i then begins, conditioned on P_i. During decode, a local KV cache (scoped per inference request) is constructed for the growing decoded sequence. The global KV pool is not added to during decode. Under standard protocol - once decode finishes, the full decoded sequence is inserted into the global KV pool.

#29 alters only that last step. Once decode finishes, only the ans (i.e. non-think) tokens are added to the global KV pool. So the global KV pool for each request is updated to contain P_i + ans after one request's decode phase completes.

Note that the positional encodings (e.g. RoPE) for the ans token activations are not shifted in my current implementation. So if there are N think tokens in the first decode step, the first ans token's activation will have its original position of '6' embedded into it.

Also note that the decode phase within each turn is conditioning on the think tokens from that same turn during generation, as in standard inference.

So on the trainer side, let's say we have an agentic trajectory and want to attention mask it so the model learns how to handle this inference strategy. For all tokens in turn N, all think tokens from turns N-1, N-2, ..., 0 should be masked. That should do the trick!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant