From 1c6b2b18ccf9b92806325f03893c1ffb7c372947 Mon Sep 17 00:00:00 2001
From: "jinwei.han"
Date: Sun, 19 Apr 2026 23:23:22 -0700
Subject: [PATCH] server : do not cap slot context to training context

The per-slot cap overrides the user-requested context size even when it
was explicitly extended via RoPE scaling (YaRN), which is the whole
point of YaRN-aware models such as Qwen3. The KV cache is already
allocated for the full n_ctx_seq, so capping slot.n_ctx only throws
away addressable cells that the user paid memory for.

llama_context already warns about "possible training context overflow"
when n_ctx_seq > n_ctx_train, so dropping the server-side cap keeps the
safety signal without silently ignoring --ctx-size.

Closes #22140
---
 tools/server/server-context.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index e134b3cfb26..6f5fead9cec 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -759,10 +759,11 @@
 
         const int n_ctx_train = llama_model_n_ctx_train(model);
 
-        int n_ctx_slot = llama_n_ctx_seq(ctx);
+        const int n_ctx_slot = llama_n_ctx_seq(ctx);
         if (n_ctx_slot > n_ctx_train) {
-            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
-            n_ctx_slot = n_ctx_train;
+            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - "
+                    "generation quality may degrade beyond the training context unless RoPE scaling is configured\n",
+                    n_ctx_slot, n_ctx_train);
         }
 
         slots.clear();