Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 145 additions & 15 deletions ds4_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <math.h>
#include <netinet/in.h>
#include <poll.h>
#include <pthread.h>
Expand Down Expand Up @@ -214,9 +215,11 @@ static bool json_string(const char **p, char **out) {
while (**p && **p != '"') {
unsigned char c = (unsigned char)*(*p)++;
if (c != '\\') {
if (c < 0x20) goto fail;
buf_putc(&b, (char)c);
continue;
}
if (!**p) goto fail;
c = (unsigned char)*(*p)++;
switch (c) {
case '"': buf_putc(&b, '"'); break;
Expand All @@ -231,8 +234,11 @@ static bool json_string(const char **p, char **out) {
*p -= 2;
uint32_t cp = 0, lo = 0;
if (!json_u16(p, &cp)) goto fail;
if (cp >= 0xd800 && cp <= 0xdbff && json_u16(p, &lo) && lo >= 0xdc00 && lo <= 0xdfff) {
if (cp >= 0xd800 && cp <= 0xdbff) {
if (!json_u16(p, &lo) || lo < 0xdc00 || lo > 0xdfff) goto fail;
cp = 0x10000u + ((cp - 0xd800u) << 10) + (lo - 0xdc00u);
} else if (cp >= 0xdc00 && cp <= 0xdfff) {
goto fail;
}
utf8_put(&b, cp);
break;
Expand All @@ -250,11 +256,47 @@ static bool json_string(const char **p, char **out) {
return false;
}

static const char *json_number_end(const char *p) {
const char *start = p;
if (*p == '-') p++;

if (*p == '0') {
p++;
if (isdigit((unsigned char)*p)) return NULL;
} else {
if (!isdigit((unsigned char)*p)) return NULL;
while (isdigit((unsigned char)*p)) p++;
}

if (*p == '.') {
p++;
if (!isdigit((unsigned char)*p)) return NULL;
while (isdigit((unsigned char)*p)) p++;
}

if (*p == 'e' || *p == 'E') {
p++;
if (*p == '+' || *p == '-') p++;
if (!isdigit((unsigned char)*p)) return NULL;
while (isdigit((unsigned char)*p)) p++;
}

return p > start ? p : NULL;
}

static bool json_number(const char **p, double *out) {
json_ws(p);
char *end = NULL;
double v = strtod(*p, &end);
if (end == *p) return false;
const char *end = json_number_end(*p);
if (!end) return false;

char *num = xstrndup(*p, (size_t)(end - *p));
char *parsed_end = NULL;
errno = 0;
double v = strtod(num, &parsed_end);
bool ok = parsed_end && *parsed_end == '\0' && errno != ERANGE && isfinite(v);
free(num);
if (!ok) return false;

*p = end;
*out = v;
return true;
Expand Down Expand Up @@ -2702,17 +2744,39 @@ static char *dsml_unescape_text(const char *s) {
}

static char *dsml_attr(const char *tag, const char *name) {
char pat[64];
snprintf(pat, sizeof(pat), "%s=\"", name);
const char *p = strstr(tag, pat);
if (!p) return NULL;
p += strlen(pat);
const char *q = strchr(p, '"');
if (!q) return NULL;
char *raw = xstrndup(p, (size_t)(q - p));
char *decoded = dsml_unescape_text(raw);
free(raw);
return decoded;
size_t name_len = strlen(name);
const char *p = tag;

if (*p == '<') p++;
while (*p && !isspace((unsigned char)*p) && *p != '>' && *p != '/') p++;

while (*p) {
while (*p && isspace((unsigned char)*p)) p++;
if (!*p || *p == '>' || (*p == '/' && p[1] == '>')) return NULL;

const char *attr = p;
while (*p && !isspace((unsigned char)*p) && *p != '=' && *p != '>' && *p != '/') p++;
const char *attr_end = p;
while (*p && isspace((unsigned char)*p)) p++;
if (*p != '=') return NULL;
p++;
while (*p && isspace((unsigned char)*p)) p++;
if (*p != '"') return NULL;
p++;

const char *value = p;
while (*p && *p != '"') p++;
if (*p != '"') return NULL;

if ((size_t)(attr_end - attr) == name_len && !memcmp(attr, name, name_len)) {
char *raw = xstrndup(value, (size_t)(p - value));
char *decoded = dsml_unescape_text(raw);
free(raw);
return decoded;
}
p++;
}
return NULL;
}

static void tool_call_json_args_add(buf *args, const char *name, const char *value, const char *is_string) {
Expand Down Expand Up @@ -9548,6 +9612,69 @@ static void test_stop_list_streaming_holds_and_trims_stop_text(void) {
free(stops.v);
}

static void test_json_number_rejects_non_json_forms(void) {
const char *p = "0";
double v = 0.0;
TEST_ASSERT(json_number(&p, &v));
TEST_ASSERT(v == 0.0);
TEST_ASSERT(*p == '\0');

p = "-12.5e+2";
TEST_ASSERT(json_number(&p, &v));
TEST_ASSERT(v == -1250.0);
TEST_ASSERT(*p == '\0');

p = "NaN";
TEST_ASSERT(!json_number(&p, &v));

p = "Infinity";
TEST_ASSERT(!json_number(&p, &v));

p = "+1";
TEST_ASSERT(!json_number(&p, &v));

p = "01";
TEST_ASSERT(!json_number(&p, &v));
}

static void test_json_string_rejects_raw_control_chars(void) {
const char *p = "\"line1\\nline2\"";
char *s = NULL;
static const char raw_newline[] = {'"', 'l', 'i', 'n', 'e', '1', '\n', 'l', 'i', 'n', 'e', '2', '"', '\0'};
static const char raw_tab[] = {'"', 't', 'a', 'b', '\t', 'i', 'n', 's', 'i', 'd', 'e', '"', '\0'};
TEST_ASSERT(json_string(&p, &s));
TEST_ASSERT(!strcmp(s, "line1\nline2"));
TEST_ASSERT(*p == '\0');
free(s);

p = raw_newline;
s = NULL;
TEST_ASSERT(!json_string(&p, &s));
free(s);

p = raw_tab;
s = NULL;
TEST_ASSERT(!json_string(&p, &s));
free(s);
}

static void test_dsml_parser_rejects_prefixed_attribute_names(void) {
const char *generated =
DS4_TOOL_CALLS_START "\n"
DS4_INVOKE_START " xname=\"bash\">\n"
DS4_PARAM_START " xname=\"command\" string=\"true\">pwd" DS4_PARAM_END "\n"
DS4_INVOKE_END "\n"
DS4_TOOL_CALLS_END;

char *content = NULL;
char *reasoning = NULL;
tool_calls calls = {0};
TEST_ASSERT(!parse_generated_message(generated, &content, &reasoning, &calls));
free(content);
free(reasoning);
tool_calls_free(&calls);
}

static char *test_nested_json_array(int depth) {
buf b = {0};
for (int i = 0; i < depth; i++) buf_putc(&b, '[');
Expand Down Expand Up @@ -10319,6 +10446,9 @@ static void ds4_server_unit_tests_run(void) {
test_dsml_prompt_escapes_tool_supplied_text();
test_stop_list_parses_all_sequences();
test_stop_list_streaming_holds_and_trims_stop_text();
test_json_number_rejects_non_json_forms();
test_json_string_rejects_raw_control_chars();
test_dsml_parser_rejects_prefixed_attribute_names();
test_json_skip_has_nesting_limit();
test_model_metadata_clamps_completion_to_context();
test_client_socket_nonblocking_flag();
Expand Down