feat: --no-cache-thoughts for reasoning-model multi-turn cache alignment#29
feat: --no-cache-thoughts for reasoning-model multi-turn cache alignment#29DavidBellamy wants to merge 9 commits into
Conversation
…x entries Adds the data-model and helper pieces needed to skip reasoning tokens from the shared prefix cache while preserving thought-infused answer K/V across turns. * server_args: --no-cache-thoughts boolean flag * schedule_batch.Req: answer_start_position field, set by update_reasoning_tokens when the </think> boundary is detected * base_prefix_cache: original_positions on InsertParams and MatchResult (backwards-compat default: None == today's contiguous-position behavior) * radix_cache (RadixCache only): round-trip per-token positions through insert, match_prefix, _split_node, and TreeNode * common.split_kv_for_no_cache_thoughts: pure helper that splits a finished request's KV into the radix-bound slice (prompt + post-</think> answer with original positions) and the thought slice to free directly 11 unit tests cover the new behavior. Remaining work (not in this commit): wiring release_kv_cache to invoke the helper, non-contiguous positions in the scheduler's compute_position/clamp_position paths, schema propagation to the 6 non-RadixCache backends, and the e2e server-fixture test.
Extends the foundation toward --no-cache-thoughts being end-to-end functional: * RadixCache.cache_finished_req(split=...): when a NoCacheThoughtsSplit is passed, use its virtual token ids / kv_indices / positions for the radix insert instead of looking them up from the per-request KV slot; free the thought slice directly. * common.release_kv_cache: when --no-cache-thoughts is on and the request has a recorded answer_start_position, build the split via split_kv_for_no_cache_thoughts and route it into cache_finished_req. * common.derive_extend_position_start: helper that turns per-request cached RoPE positions (returned by match_prefix on cache hits) into the per-request starting position for extend tokens. Returns None when no request has cached positions (signals legacy contiguous behavior). * forward_batch_info.compute_position(_torch): new extend_position_start kwarg overrides the implicit "positions start at extend_prefix_lens" assumption, enabling non-contiguous positions on cache hits. Triton path falls through to the torch path when the override is set; native triton support deferred. 8 new unit tests (19 total, all green). Not yet wired: ForwardBatch.init_new call site that would actually pass extend_position_start, Req-side cached positions tracking, and the decode-path clamp_position changes.
Wires the call site so per-request cached non-contiguous positions feed into
the positions tensor handed to the model. Behavior is identical to today when
batch.cached_positions_per_req is None (the default), since
build_extend_positions falls through to compute_position in that case.
* build_extend_positions: new helper in forward_batch_info that bridges
scheduler state (per-req cached positions, prefix lens) to compute_position's
extend_position_start tensor. Lives next to compute_position so the call
site only needs to swap which function it calls.
* ForwardBatch.init_new: extend-mode path now calls build_extend_positions,
reading getattr(batch, "cached_positions_per_req", None). Uses getattr so
pre-existing ScheduleBatch instances without the field continue to work.
The matching field on ScheduleBatch (cached_positions_per_req) and its
population from match_prefix results is the next integration step; this
commit is a no-op until that field is wired.
2 new unit tests on build_extend_positions (21 total, all green). The
test_compute_position_noncontig hard-coded backend name moved from "aiter"
to "torch_native" because support_triton() returns True for everything except
{torch_native, intel_amx, ascend} and the test must force the torch path.
Closes the gap between the radix tree (which returns non-contiguous positions on cache hits) and ForwardBatch.init_new (which consumes them via the previously wired build_extend_positions helper). * Req.cached_positions: new field. init_next_round_input now stores match_result.original_positions on the Req alongside the existing prefix_indices unpacking. None when the cache hit was on a legacy entry without positions. * ScheduleBatch.cached_positions_per_req: new field. prepare_for_extend populates it via collect_cached_positions(reqs), which aggregates per-req cached_positions or returns None if no req has any (signaling that ForwardBatch should take the legacy contiguous-positions path). * collect_cached_positions: module-level helper, isolated so it can be unit- tested without standing up a real ScheduleBatch. 3 new unit tests (24 total, all green). The prefill cache-hit path is now wired end-to-end. Remaining: decode-path clamp_position must use max_position + 1 (not seq_len - 1) so per-token positions during decode continue from where the cached prefix left off.
After a prefill cache hit whose cached entry carried non-contiguous RoPE positions, the request's decode tokens must continue from max(cached) + extend + 1, not from seq_len - 1. This commit closes that gap. * clamp_position(seq_lens, position_offsets=None): new optional offsets arg. When provided, output is clamp(seq_lens - 1) + offsets. Behavior unchanged when offsets is None. The CUDA-path wrapper is a thin Python add on top of the existing jit_kernel; the CUDA kernel itself is untouched. * common.derive_position_offsets: helper computing per-request offset as max(cached_positions) - (prefix_len - 1), or 0 for requests without cached positions. Returns None if no req in the batch needs an offset. * ScheduleBatch.position_offsets: new tensor field, populated in prepare_for_extend alongside cached_positions_per_req. * ForwardBatch.init_new (decode path): passes batch.position_offsets through to clamp_position via getattr (legacy callers still work). 4 new unit tests covering clamp_position offsets and the offsets helper (28 total, all green).
Every prefix-cache backend's cache_finished_req now accepts the split kwarg so release_kv_cache won't raise TypeError when --no-cache-thoughts fires on a non-RadixCache backend. The behavior on those backends is a one-time warning plus fall-through to the default cache_finished_req (thoughts get cached as usual). Full split-insertion support per backend is deferred until empirical results justify the engineering. * chunk_cache.ChunkCache, swa_radix_cache.SWARadixCache, mamba_radix_cache.MambaRadixCache, radix_cache_cpp.RadixCacheCpp: accept split=None; warn-and-ignore if non-None. * lmc_radix_cache.LMCRadixCache: accept split=None, forward to super (which is RadixCache, fully implemented); warn that the LMCache offload that follows reads req.fill_ids directly and may offload more tokens than the radix retained. * session_aware_cache.SessionAwareCache: already accepted **kwargs and forwards to inner, so it works transparently once the inner backend handles split. * common.warn_split_unsupported_once: module-level one-time warning helper so the noise per backend is bounded. 6 new tests across 2 files (34 total; 2 skipped because C++ extensions can't JIT-compile in the test env — source verified by py_compile and reads).
* test/manual/test_bbq_smoke.py: confirms BBQ-8B-Mid3 loads in the agentic-rl container image, the K2-v3 reasoning parser separates <think>...</think> from the answer (290 reasoning tokens, 222 answer tokens on a sample prompt), and the /v1/chat/completions endpoint applies the chat template (which primes the assistant turn with <think>\n). * test/manual/test_no_cache_thoughts_e2e.py: two-server cached_tokens-delta test. Confirms the --no-cache-thoughts split-insert path fires correctly (logged "split fired rid=... input_len=18 output_len=128 answer_start=144 virtual_tokens=20 thoughts_freed=126" during diagnostic runs). E2E test currently fails the assertion (cached_tokens=22 on both servers). Root cause: the BBQ chat template renders a prior assistant message with an EMPTY <think>\n</think>\n block even when reasoning_content is empty, instead of omitting the block. So turn 2's tokens after the user prompt are [<think>, \n, </think>, \n, answer...], while the cached path is [prompt..., answer...] with no <think></think> tokens between them. Both servers' prefix match diverges at the same slot — at the empty <think> marker — so neither sees the answer slice as cached. The feature implementation is correct; the chat template doesn't align with the no-cache-thoughts cached path for multi-turn rendering. Resolution: either (a) modify the chat template to drop the empty <think></think> block when thinking content is empty, or (b) extend the split path to cover the <think>...</think> boundary tokens with synthetic positions. Tracked as follow-up.
Adds the missing piece for multi-turn cache alignment: the chat template's priming tail (typically <think>\n added by add_generation_prompt) is part of turn 1's input but NOT of turn 2's chat-template-rendered history. Caching it breaks the cache match at the assistant header. The fix scans the tail of origin_input_ids for the last <think> token and excludes everything from there onward from the virtual cached prompt. * split_kv_for_no_cache_thoughts: new think_start_id kwarg. When provided scans backward through origin_input_ids for the last occurrence and uses that index as prompt_keep_len; tokens beyond are dropped from the cached entry. thought_kv_indices_to_free is also restricted to positions [input_len .. answer_start) so the priming slots (which are still owned by the radix entry that cache_unfinished_req inserted at prefill time) are not double-freed — that would trigger "token_to_kv_pool_allocator memory leak detected" on the next free check. * scheduler.Scheduler: encode <think> alongside </think> at init and stash self._think_start_id; copy onto each Req at construction so release_kv_cache (which has no scheduler reference) can read it. * release_kv_cache: pass req._think_start_id to split_kv_for_no_cache_thoughts. * Tests: new TDD cycle covering the priming-strip path on both the helper (test_no_cache_thoughts_split.py) and the routing wrapper (test_release_kv_cache_routing.py). All 35 unit tests green; the 14 new-style unit tests under test/registered/radix_cache/ now carry register_cuda_ci entries so they're picked up by test/run_suite.py. E2E test (test_no_cache_thoughts_e2e.py): now runs cleanly end-to-end (no crash, no leak) inside the agentic-rl image with the upstream BBQ chat template (test/manual/chat_templates/bbq_upstream.jinja) overriding Mid3's stale bundled copy. The assertion still fails because of a remaining tokenization-alignment issue between turn-1 model-generated answer tokens (possibly starting with leading whitespace/newline after </think>) and turn-2 chat-template-rendered answer tokens — separate concern from the implementation completed here.
… TITO E2E
The earlier priming-strip logic (stash think_start_id on Req, scan
origin_input_ids for the last <think>, drop the tail from the virtual
cached prompt) was the wrong call for the production rollout path.
In TITO ("Token-In, Token-Out") multi-turn rollouts, turn N+1's input is
constructed from turn N's prompt_token_ids verbatim — the priming
<think>\n at the end of the input is carried over as-is. The cached
entry from no-cache-thoughts must therefore also include the priming
tail; otherwise turn N+1's input has [...prompt, <think>, \n, answer...]
while the cached path has [...prompt, answer...] — match dies right
where we want it to extend through.
Stripping the priming was an attempt to align with the OpenAI chat-
completions text path, where the chat template strips reasoning from
historical assistant messages. But that path also has a BPE retokenize
mismatch that prevents the answer slice from aligning anyway (see TITO
doc). So the chat-template path can't deliver multi-turn answer reuse
under any priming-strip variant. The correct integration path is TITO,
where keeping the priming in the cache is the right answer.
Net change:
* split_kv_for_no_cache_thoughts: removes think_start_id kwarg and the
scan-for-priming logic. prompt_keep_len is always input_len.
* common.release_kv_cache: drops the think_start_id propagation.
* scheduler.Scheduler: drops self._think_start_id init and the per-Req
_think_start_id assignments at the three Req-construction sites.
* Test removals: the priming-strip-specific cases in
test_no_cache_thoughts_split.py and test_release_kv_cache_routing.py.
* New TITO E2E: test/manual/test_no_cache_thoughts_e2e.py now sends
turn 2 with input_ids in the chat-completions body (bypassing chat
template), strips output_token_ids[:reasoning_tokens] from the prior
turn before appending to the buffer, and tokenizes only the env-delta
(new user msg + assistant generation prompt). Assertion now passes:
baseline cached_tokens=22, --no-cache-thoughts cached_tokens=191
(delta = answer length, exactly the multi-turn cache reuse the
feature is supposed to deliver).
All 34 unit tests still green (1 env-skipped for the C++ extension).
|
Mapping this to the trainer side: agentic inference looks like: prefill phase -> decode phase -> environment -> prefill phase -> decode phase -> environment -> ... For turn i in [0, 1, 2, ...] let P_i := prefill phase, D_i := decode phase, E_i := environment phase for turn i. Example: In inference engines like sglang and vllm, when the prefill phase is entered, the global KV pool (shared across all concurrent inference requests) is checked to see if those tokens' KV vectors are cached. The engine matches the largest prefix from P_i that it can find in the KV pool and returns that, then prefills the remaining suffix of P_i. D_i then begins, conditioned on P_i. During decode, a local KV cache (scoped per inference request) is constructed for the growing decoded sequence. The global KV pool is not added to during decode. Under standard protocol - once decode finishes, the full decoded sequence is inserted into the global KV pool. #29 alters only that last step. Once decode finishes, only the Note that the positional encodings (e.g. RoPE) for the ans token activations are not shifted in my current implementation. So if there are N think tokens in the first decode step, the first ans token's activation will have its original position of '6' embedded into it. Also note that the decode phase within each turn is conditioning on the think tokens from that same turn during generation, as in standard inference. So on the trainer side, let's say we have an agentic trajectory and want to attention mask it so the model learns how to handle this inference strategy. For all tokens in turn N, all think tokens from turns N-1, N-2, ..., 0 should be masked. That should do the trick! |
Summary
Adds a
--no-cache-thoughtsserver flag that, for reasoning requests, inserts the answer slice into the prefix cache at the original RoPE positions (skipping the thought slice) instead of caching the entire generated trajectory at contiguous positions. Lets a TITO multi-turn rollout reuse turn N's answer KV as a prefix match on turn N+1, without polluting the radix tree with request-specific reasoning content.End-to-end validated on
bbq-8b-mid3in the agentic-rl image:cached_tokens = 22(prompt prefix only)--no-cache-thoughts— turn 2cached_tokens = 191(prompt + answer, delta = answer length)How it works (server side)
</think>Req.update_reasoning_tokens(schedule_batch.py)req.answer_start_positionwhen the reasoning parser emits its end tokensplit_kv_for_no_cache_thoughts(common.py)(virtual_token_ids, virtual_kv_indices, virtual_positions, thought_kv_indices_to_free)— input prompt + post-</think>answer, with answer carrying its original (non-contiguous) positions; thought slots earmarked for freerelease_kv_cache(common.py)cache_finished_req(split=…)RadixCache.cache_finished_req(radix_cache.py)original_positionsattached, frees the thought slice directlyRadixCache.{insert,_split_node,_match_prefix_helper,match_prefix},TreeNode.positions,InsertParams.original_positions,MatchResult.original_positionsmatch_prefixreturns it on cache hitsReqReq.init_next_round_input(schedule_batch.py)match_result.original_positionsonreq.cached_positionscollect_cached_positions,prepare_for_extendScheduleBatch.cached_positions_per_req+ScheduleBatch.position_offsetspopulatedbuild_extend_positions,clamp_position(forward_batch_info.py)max(cached) + 1ChunkCache,SWARadixCache,MambaRadixCache,RadixCacheCpp,LMCRadixCache,SessionAwareCachesplit=kwarg; non-RadixCachebackends log a one-time warning and fall back to standardcache_finished_req(graceful no-op)Integration spec for callers (TITO multi-turn rollouts)
The chat-completions text path cannot deliver multi-turn answer reuse even with this flag, because the chat template re-tokenizes prior assistant content via BPE and the first answer token ID after the assistant header doesn't match the model's emitted ID (leading-whitespace BPE merges differ). Callers must use the TITO protocol:
Turn 1: send via
/v1/chat/completionswithreturn_prompt_token_ids=True,return_completion_token_ids=True,return_meta_info=True. Capture from the response:choices[0].prompt_token_ids— turn 1's input IDschoices[0].completion_token_ids— turn 1's output IDs (full, including thoughts)usage.reasoning_tokens— count of thought tokens (including</think>)Strip the thoughts before appending to the running buffer:
```python
answer_ids = completion_token_ids[reasoning_tokens:]
running_ids = prompt_token_ids + answer_ids
```
The buffer must NOT contain the thought slice; the cached entry doesn't.
Turn N+1: construct
input_ids = running_ids + tokenize(env_delta)whereenv_deltais the new user message + new assistant priming, including the K2V3 boundary\n(model stops on<|im_end|>but the chat template emits<|im_end|>\n, so the delta starts with\n). Send via/v1/chat/completionswithinput_idsset in the request body (chat template is bypassed wheninput_idsis provided).Verify cache hits: turn N+1's
usage.prompt_tokens_details.cached_tokensshould grow turn-over-turn by the prior answer's length.For RL360: this is the same
tito-patchedflow already used bymiles/rollout/generate_hub/agentic_tool_call.py; no client changes needed beyond ensuring the thought-stripping step runs for--no-cache-thoughts-enabled servers.Tests
test/registered/radix_cache/(1 env-skipped for C++ extension). Covers: CLI flag,Req.answer_start_position, split helper, radix position round-trip,release_kv_cacherouting,compute_position/clamp_positionnon-contiguous support, derive helpers,Req.cached_positions,ScheduleBatch.cached_positions_per_req, and the 6 non-RadixCachebackend signature acceptance.test/manual/test_no_cache_thoughts_e2e.py. Launches twobbq-8b-mid3servers (with/without flag), drives a 2-turn protocol via rawinput_ids, asserts the cached_tokens delta on turn 2. Passes on agentic-rl image.test/registered/distributed/test_parallelism_context_integration.pylacksregister_cuda_ci(...)and breakstest/run_suite.pycollection — separate concern, not introduced by this PR.Known limitations / follow-up
RadixCachebackends only stub thesplit=kwarg (warn + fall back). Full per-backend implementation deferred until empirical results justify the engineering.chat_template.jinjais stale relative toLLM360/bbq-chat-template:main(the bundled copy has a'think'default forthink_tagthat the upstream removed). E2E uses an explicit--chat-templateoverride pointing at a vendored upstream copy undertest/manual/chat_templates/bbq_upstream.jinja. If/when Mid3 ships with the upstream template, the override can be dropped.