diff --git a/ds4_server.c b/ds4_server.c index bc8abbbd..8498b77f 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -3633,6 +3633,65 @@ static bool dsml_decode_state_uses_payload_sampling(dsml_decode_state state) { return state == DSML_DECODE_STRING_BODY || state == DSML_DECODE_JSON_STRING; } +/* Tool syntax is plain generated text. Keep ordinary text generation out of + * the final context slice so a tool block opened near the end still has room to + * close before the hard session limit. This buys room for the structural + * envelope and small payloads; oversized tool arguments can still hit the hard + * limit and need a separate constrained-decoding or truncation strategy. */ +#define DS4_TOOL_CALL_CONTEXT_RESERVE_TOKENS 256 + +static int decode_hard_token_limit(int requested_max, int room) { + if (requested_max < 0 || room <= 0) return 0; + return requested_max < room ? requested_max : room; +} + +static int decode_soft_token_limit(int hard_limit, int room, bool reserve_for_tools) { + if (!reserve_for_tools || hard_limit <= 0 || + room <= DS4_TOOL_CALL_CONTEXT_RESERVE_TOKENS) + return hard_limit; + int ordinary_room = room - DS4_TOOL_CALL_CONTEXT_RESERVE_TOKENS; + return hard_limit < ordinary_room ? hard_limit : ordinary_room; +} + +static bool decode_tool_call_open(bool saw_tool_start, + bool saw_tool_end, + dsml_decode_state state) { + return (saw_tool_start && !saw_tool_end) || + dsml_decode_state_is_tool(state); +} + +static bool dsml_suffix_partial_tool_start(const char *raw, size_t raw_len) { + if (!raw || raw_len == 0) return false; + for (size_t i = 0; i < sizeof(dsml_syntaxes) / sizeof(dsml_syntaxes[0]); i++) { + const char *lit = dsml_syntaxes[i].tool_calls_start; + size_t lit_len = strlen(lit); + /* Complete markers are handled by the normal DSML scanner. */ + size_t max = raw_len < lit_len ? raw_len : lit_len - 1; + for (size_t n = 1; n <= max; n++) { + if (!memcmp(raw + raw_len - n, lit, n)) return true; + } + } + return false; +} + +static bool decode_tool_reserve_active(bool saw_tool_start, + bool saw_tool_end, + dsml_decode_state state, + const char *text, + size_t text_len, + bool soft_limit_reached) { + return decode_tool_call_open(saw_tool_start, saw_tool_end, state) || + (soft_limit_reached && dsml_suffix_partial_tool_start(text, text_len)); +} + +static int decode_remaining_for_state(int completion, + int hard_limit, + int soft_limit, + bool tool_call_open) { + int limit = tool_call_open ? hard_limit : soft_limit; + return completion < limit ? limit - completion : 0; +} + static void dsml_decode_tracker_init(dsml_decode_tracker *dt) { memset(dt, 0, sizeof(*dt)); dt->mode = DSML_TRACK_SEARCH; @@ -7086,8 +7145,10 @@ static void generate_job(server *s, job *j) { size_t stop_scan_from = 0; const char *finish = "length"; int completion = 0; - int max_tokens = j->req.max_tokens; int room = ds4_session_ctx(s->session) - ds4_session_pos(s->session); + const bool tool_budget = j->req.kind == REQ_CHAT && j->req.has_tools; + int max_tokens = decode_hard_token_limit(j->req.max_tokens, room); + int soft_max_tokens = decode_soft_token_limit(max_tokens, room, tool_budget); bool saw_tool_start = false; bool saw_tool_end = false; bool saw_orphan_tool_end = false; @@ -7096,9 +7157,8 @@ static void generate_job(server *s, job *j) { int next_decode_log = 50; uint64_t rng = j->req.seed ? j->req.seed : (((uint64_t)time(NULL) << 32) ^ ((uint64_t)s->seq << 1) ^ (uint64_t)(uintptr_t)j); - if (max_tokens < 0) max_tokens = 0; - if (max_tokens > room) max_tokens = room; - trace_event(s, trace_id, "prefill done; decode_max=%d ctx_room=%d", max_tokens, room); + trace_event(s, trace_id, "prefill done; decode_max=%d decode_soft=%d ctx_room=%d", + max_tokens, soft_max_tokens, room); const double decode_t0 = now_sec(); double last_decode_log_t = decode_t0; int last_decode_log_completion = 0; @@ -7111,6 +7171,16 @@ static void generate_job(server *s, job *j) { dsml_decode_state dsml_state = j->req.kind == REQ_CHAT && j->req.has_tools ? dsml_tracker.decode : DSML_DECODE_OUTSIDE; const bool in_tool_call = dsml_decode_state_is_tool(dsml_state); + const bool tool_call_open = tool_budget && + decode_tool_reserve_active(saw_tool_start, + saw_tool_end, + dsml_state, + text.ptr, + text.len, + completion >= soft_max_tokens); + const int decode_remaining = + decode_remaining_for_state(completion, max_tokens, soft_max_tokens, tool_call_open); + if (decode_remaining <= 0) break; if (!(j->req.kind == REQ_CHAT && j->req.has_tools && (saw_tool_start || in_tool_call))) { kv_cache_maybe_store_continued(s); } @@ -7141,7 +7211,7 @@ static void generate_job(server *s, job *j) { { ntok = ds4_session_eval_speculative_argmax(s->session, token, - max_tokens - completion, + decode_remaining, ds4_token_eos(s->engine), toks, (int)(sizeof(toks) / sizeof(toks[0])), @@ -9422,6 +9492,46 @@ static void test_dsml_decode_state_separates_structure_and_payload(void) { TEST_ASSERT(tracker.decode == DSML_DECODE_OUTSIDE); } +static void test_decode_budget_reserves_tool_call_room(void) { + TEST_ASSERT(decode_hard_token_limit(393216, 1280) == 1280); + TEST_ASSERT(decode_hard_token_limit(128, 1280) == 128); + TEST_ASSERT(decode_hard_token_limit(-1, 1280) == 0); + + TEST_ASSERT(decode_soft_token_limit(1280, 1280, false) == 1280); + TEST_ASSERT(decode_soft_token_limit(1280, 1280, true) == 1024); + TEST_ASSERT(decode_soft_token_limit(128, 1280, true) == 128); + TEST_ASSERT(decode_soft_token_limit(200, 200, true) == 200); + TEST_ASSERT(decode_soft_token_limit(300, 300, true) == 44); + + TEST_ASSERT(!decode_tool_call_open(false, false, DSML_DECODE_OUTSIDE)); + TEST_ASSERT(decode_tool_call_open(true, false, DSML_DECODE_OUTSIDE)); + TEST_ASSERT(!decode_tool_call_open(true, true, DSML_DECODE_OUTSIDE)); + TEST_ASSERT(decode_tool_call_open(false, false, DSML_DECODE_STRUCTURAL)); + TEST_ASSERT(dsml_suffix_partial_tool_start("<", 1)); + TEST_ASSERT(dsml_suffix_partial_tool_start("\n\n" DS4_TOOL_CALLS_START_SHORT "\n<", strlen("\n\n" DS4_TOOL_CALLS_START_SHORT "\n<"))); + TEST_ASSERT(!dsml_suffix_partial_tool_start(DS4_TOOL_CALLS_START, strlen(DS4_TOOL_CALLS_START))); + TEST_ASSERT(!dsml_suffix_partial_tool_start("\n<|DSML|parameter name=\"command\" string=\"true\">a\n\n"; const char *b_dsml = "\n\n<|DSML|tool_calls>\n<|DSML|invoke name=\"bash\">\n<|DSML|parameter name=\"command\" string=\"true\">b\n\n"; @@ -10307,6 +10417,7 @@ static void ds4_server_unit_tests_run(void) { test_tool_memory_replays_sampled_dsml(); test_exact_dsml_tool_replay_can_be_disabled(); test_dsml_decode_state_separates_structure_and_payload(); + test_decode_budget_reserves_tool_call_room(); test_tool_memory_max_ids_prunes_oldest(); test_kv_tool_map_filters_by_dsml_text(); test_kv_tool_map_restores_before_prompt_render();