Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 116 additions & 5 deletions ds4_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -3633,6 +3633,65 @@ static bool dsml_decode_state_uses_payload_sampling(dsml_decode_state state) {
return state == DSML_DECODE_STRING_BODY || state == DSML_DECODE_JSON_STRING;
}

/* Tool syntax is plain generated text. Keep ordinary text generation out of
 * the final context slice so a tool block opened near the end still has room to
 * close before the hard session limit. This buys room for the structural
 * envelope and small payloads; oversized tool arguments can still hit the hard
 * limit and need a separate constrained-decoding or truncation strategy.
 *
 * The value is a token count. decode_soft_token_limit() subtracts it from the
 * remaining context room, and only when the room is strictly larger than the
 * reserve; smaller rooms are left unreserved. */
#define DS4_TOOL_CALL_CONTEXT_RESERVE_TOKENS 256

/* Compute the absolute cap on generated tokens for one request.
 *
 * requested_max  token budget asked for by the caller; negative means "none".
 * room           tokens still available in the context window.
 *
 * Returns the smaller of the two, or 0 when the request is negative or the
 * context has no room left. The result is never negative. */
static int decode_hard_token_limit(int requested_max, int room) {
    if (requested_max < 0) return 0;
    if (room <= 0) return 0;
    if (room <= requested_max) return room;
    return requested_max;
}

/* Compute the cap that applies to ordinary (non-tool) text generation.
 *
 * hard_limit         absolute token cap for the request.
 * room               tokens still available in the context window.
 * reserve_for_tools  whether a slice of the context should be held back.
 *
 * When reserving is off, the hard limit is not positive, or the room does not
 * exceed DS4_TOOL_CALL_CONTEXT_RESERVE_TOKENS, the hard limit is returned
 * unchanged. Otherwise the result is the hard limit capped at the room minus
 * the reserve. */
static int decode_soft_token_limit(int hard_limit, int room, bool reserve_for_tools) {
    const bool can_reserve = reserve_for_tools &&
                             hard_limit > 0 &&
                             room > DS4_TOOL_CALL_CONTEXT_RESERVE_TOKENS;
    if (!can_reserve) return hard_limit;

    const int text_budget = room - DS4_TOOL_CALL_CONTEXT_RESERVE_TOKENS;
    return text_budget <= hard_limit ? text_budget : hard_limit;
}

/* Tell whether generation is currently inside a tool call.
 *
 * True when a tool start has been seen without a matching end, or when
 * dsml_decode_state_is_tool() reports that `state` is a tool state. The state
 * check is only consulted when the start/end flags do not already decide it. */
static bool decode_tool_call_open(bool saw_tool_start,
                                  bool saw_tool_end,
                                  dsml_decode_state state) {
    if (saw_tool_start && !saw_tool_end) return true;
    return dsml_decode_state_is_tool(state);
}

/* Report whether the tail of `raw` is a proper, non-empty prefix of any
 * syntax's tool_calls_start marker, i.e. a tool block may be in the middle of
 * being opened.
 *
 * raw      generated text; only the last bytes inside raw_len are inspected,
 *          so it does not have to be NUL-terminated.
 * raw_len  number of valid bytes at raw.
 *
 * Returns false for NULL or empty input, and never matches a complete marker:
 * those are handled by the normal DSML scanner. */
static bool dsml_suffix_partial_tool_start(const char *raw, size_t raw_len) {
    if (!raw || raw_len == 0) return false;

    const size_t syntax_count = sizeof(dsml_syntaxes) / sizeof(dsml_syntaxes[0]);
    for (size_t i = 0; i < syntax_count; i++) {
        const char *lit = dsml_syntaxes[i].tool_calls_start;
        /* A missing or empty marker has no proper prefix. Skipping it also
         * keeps `lit_len - 1` below from wrapping around to SIZE_MAX, which
         * would send memcmp() past the end of both buffers. */
        if (!lit || lit[0] == '\0') continue;

        const size_t lit_len = strlen(lit);
        /* Longest candidate: everything we have, but always strictly shorter
         * than the marker so that a complete marker is not reported here. */
        const size_t max = raw_len < lit_len ? raw_len : lit_len - 1;
        for (size_t n = 1; n <= max; n++) {
            if (memcmp(raw + raw_len - n, lit, n) == 0) return true;
        }
    }
    return false;
}

/* Decide whether the reserved tail of the context may be spent right now.
 *
 * The reserve is active when a tool call is already open (per
 * decode_tool_call_open()), or when the soft limit has been reached and the
 * last text_len bytes of `text` end in a partial tool start marker (per
 * dsml_suffix_partial_tool_start()). The text is not examined unless the soft
 * limit has been reached. */
static bool decode_tool_reserve_active(bool saw_tool_start,
                                       bool saw_tool_end,
                                       dsml_decode_state state,
                                       const char *text,
                                       size_t text_len,
                                       bool soft_limit_reached) {
    if (decode_tool_call_open(saw_tool_start, saw_tool_end, state)) return true;
    if (!soft_limit_reached) return false;
    return dsml_suffix_partial_tool_start(text, text_len);
}

/* Number of tokens that may still be generated.
 *
 * completion      tokens generated so far.
 * hard_limit      cap used while a tool call is open.
 * soft_limit      cap used for ordinary text.
 * tool_call_open  selects which of the two caps applies.
 *
 * Returns the distance from `completion` to the selected cap, or 0 once the
 * cap has been reached or passed. */
static int decode_remaining_for_state(int completion,
                                      int hard_limit,
                                      int soft_limit,
                                      bool tool_call_open) {
    int budget = soft_limit;
    if (tool_call_open) budget = hard_limit;

    if (completion >= budget) return 0;
    return budget - completion;
}

static void dsml_decode_tracker_init(dsml_decode_tracker *dt) {
memset(dt, 0, sizeof(*dt));
dt->mode = DSML_TRACK_SEARCH;
Expand Down Expand Up @@ -7086,8 +7145,10 @@ static void generate_job(server *s, job *j) {
size_t stop_scan_from = 0;
const char *finish = "length";
int completion = 0;
int max_tokens = j->req.max_tokens;
int room = ds4_session_ctx(s->session) - ds4_session_pos(s->session);
const bool tool_budget = j->req.kind == REQ_CHAT && j->req.has_tools;
int max_tokens = decode_hard_token_limit(j->req.max_tokens, room);
int soft_max_tokens = decode_soft_token_limit(max_tokens, room, tool_budget);
bool saw_tool_start = false;
bool saw_tool_end = false;
bool saw_orphan_tool_end = false;
Expand All @@ -7096,9 +7157,8 @@ static void generate_job(server *s, job *j) {
int next_decode_log = 50;
uint64_t rng = j->req.seed ? j->req.seed :
(((uint64_t)time(NULL) << 32) ^ ((uint64_t)s->seq << 1) ^ (uint64_t)(uintptr_t)j);
if (max_tokens < 0) max_tokens = 0;
if (max_tokens > room) max_tokens = room;
trace_event(s, trace_id, "prefill done; decode_max=%d ctx_room=%d", max_tokens, room);
trace_event(s, trace_id, "prefill done; decode_max=%d decode_soft=%d ctx_room=%d",
max_tokens, soft_max_tokens, room);
const double decode_t0 = now_sec();
double last_decode_log_t = decode_t0;
int last_decode_log_completion = 0;
Expand All @@ -7111,6 +7171,16 @@ static void generate_job(server *s, job *j) {
dsml_decode_state dsml_state = j->req.kind == REQ_CHAT && j->req.has_tools ?
dsml_tracker.decode : DSML_DECODE_OUTSIDE;
const bool in_tool_call = dsml_decode_state_is_tool(dsml_state);
const bool tool_call_open = tool_budget &&
decode_tool_reserve_active(saw_tool_start,
saw_tool_end,
dsml_state,
text.ptr,
text.len,
completion >= soft_max_tokens);
const int decode_remaining =
decode_remaining_for_state(completion, max_tokens, soft_max_tokens, tool_call_open);
if (decode_remaining <= 0) break;
if (!(j->req.kind == REQ_CHAT && j->req.has_tools && (saw_tool_start || in_tool_call))) {
kv_cache_maybe_store_continued(s);
}
Expand Down Expand Up @@ -7141,7 +7211,7 @@ static void generate_job(server *s, job *j) {
{
ntok = ds4_session_eval_speculative_argmax(s->session,
token,
max_tokens - completion,
decode_remaining,
ds4_token_eos(s->engine),
toks,
(int)(sizeof(toks) / sizeof(toks[0])),
Expand Down Expand Up @@ -9422,6 +9492,46 @@ static void test_dsml_decode_state_separates_structure_and_payload(void) {
TEST_ASSERT(tracker.decode == DSML_DECODE_OUTSIDE);
}

/* Unit test for the decode-budget helpers: the hard and soft token limits,
 * detection of an open tool call, detection of a partially emitted tool start
 * marker, and the remaining-token computation that combines them. */
static void test_decode_budget_reserves_tool_call_room(void) {
/* Hard limit: the request is clamped to the available room, and a negative
 * request yields 0. */
TEST_ASSERT(decode_hard_token_limit(393216, 1280) == 1280);
TEST_ASSERT(decode_hard_token_limit(128, 1280) == 128);
TEST_ASSERT(decode_hard_token_limit(-1, 1280) == 0);

/* Soft limit: equals the hard limit when no reserve is requested, when the
 * hard limit is already below room minus the reserve, or when the room (200)
 * does not exceed the reserve. Otherwise ordinary text is capped at room minus
 * the reserve (1280 -> 1024, 300 -> 44). */
TEST_ASSERT(decode_soft_token_limit(1280, 1280, false) == 1280);
TEST_ASSERT(decode_soft_token_limit(1280, 1280, true) == 1024);
TEST_ASSERT(decode_soft_token_limit(128, 1280, true) == 128);
TEST_ASSERT(decode_soft_token_limit(200, 200, true) == 200);
TEST_ASSERT(decode_soft_token_limit(300, 300, true) == 44);

/* A tool call counts as open after a start with no end, or when the decode
 * state alone is expected to indicate a tool (DSML_DECODE_STRUCTURAL here). */
TEST_ASSERT(!decode_tool_call_open(false, false, DSML_DECODE_OUTSIDE));
TEST_ASSERT(decode_tool_call_open(true, false, DSML_DECODE_OUTSIDE));
TEST_ASSERT(!decode_tool_call_open(true, true, DSML_DECODE_OUTSIDE));
TEST_ASSERT(decode_tool_call_open(false, false, DSML_DECODE_STRUCTURAL));
/* Partial start markers: a trailing '<' is expected to match even when other
 * text precedes it; a complete start marker and an unrelated "<not_tool" are
 * expected not to match. */
TEST_ASSERT(dsml_suffix_partial_tool_start("<", 1));
TEST_ASSERT(dsml_suffix_partial_tool_start("\n\n" DS4_TOOL_CALLS_START_SHORT "\n<", strlen("\n\n" DS4_TOOL_CALLS_START_SHORT "\n<")));
TEST_ASSERT(!dsml_suffix_partial_tool_start(DS4_TOOL_CALLS_START, strlen(DS4_TOOL_CALLS_START)));
TEST_ASSERT(!dsml_suffix_partial_tool_start("<not_tool", strlen("<not_tool")));
/* With no tool call open, a partial marker only activates the reserve once
 * the soft limit has been reached (last argument). */
TEST_ASSERT(!decode_tool_reserve_active(false, false, DSML_DECODE_OUTSIDE, "<", 1, false));
TEST_ASSERT(decode_tool_reserve_active(false, false, DSML_DECODE_OUTSIDE, "<", 1, true));

/* Remaining tokens: measured against the soft limit for ordinary text and
 * against the hard limit while a tool call is open. */
TEST_ASSERT(decode_remaining_for_state(1024, 1280, 1024, false) == 0);
TEST_ASSERT(decode_remaining_for_state(1024, 1280, 1024, true) == 256);
TEST_ASSERT(decode_remaining_for_state(1000, 1280, 1024, false) == 24);

/* Combined behavior at the soft limit: empty text leaves nothing to spend. */
bool reserve = decode_tool_reserve_active(false, false, DSML_DECODE_OUTSIDE,
"", 0, true);
TEST_ASSERT(!reserve);
TEST_ASSERT(decode_remaining_for_state(1024, 1280, 1024, reserve) == 0);
/* A trailing '<' unlocks the 256 reserved tokens. */
reserve = decode_tool_reserve_active(false, false, DSML_DECODE_OUTSIDE,
"<", 1, true);
TEST_ASSERT(reserve);
TEST_ASSERT(decode_remaining_for_state(1024, 1280, 1024, reserve) == 256);
/* Once the tail ("<x") is no longer a marker prefix, the reserve is
 * withdrawn again. */
reserve = decode_tool_reserve_active(false, false, DSML_DECODE_OUTSIDE,
"<x", 2, true);
TEST_ASSERT(!reserve);
TEST_ASSERT(decode_remaining_for_state(1024, 1280, 1024, reserve) == 0);
}

static void test_tool_memory_max_ids_prunes_oldest(void) {
const char *a_dsml = "\n\n<|DSML|tool_calls>\n<|DSML|invoke name=\"bash\">\n<|DSML|parameter name=\"command\" string=\"true\">a</|DSML|parameter>\n</|DSML|invoke>\n</|DSML|tool_calls>";
const char *b_dsml = "\n\n<|DSML|tool_calls>\n<|DSML|invoke name=\"bash\">\n<|DSML|parameter name=\"command\" string=\"true\">b</|DSML|parameter>\n</|DSML|invoke>\n</|DSML|tool_calls>";
Expand Down Expand Up @@ -10307,6 +10417,7 @@ static void ds4_server_unit_tests_run(void) {
test_tool_memory_replays_sampled_dsml();
test_exact_dsml_tool_replay_can_be_disabled();
test_dsml_decode_state_separates_structure_and_payload();
test_decode_budget_reserves_tool_call_room();
test_tool_memory_max_ids_prunes_oldest();
test_kv_tool_map_filters_by_dsml_text();
test_kv_tool_map_restores_before_prompt_render();
Expand Down