Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 61 additions & 33 deletions ds4_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -4501,7 +4501,14 @@ static void append_tool_call_deltas_json(buf *b, const tool_calls *calls, const
buf_putc(b, ']');
}

static bool http_response(int fd, int code, const char *type, const char *body) {
static void append_cors_headers(buf *h) {
buf_puts(h,
"Access-Control-Allow-Origin: *\r\n"
"Access-Control-Allow-Methods: GET, POST, OPTIONS\r\n"
"Access-Control-Allow-Headers: *\r\n");
}

static bool http_response(int fd, bool enable_cors, int code, const char *type, const char *body) {
const char *reason = code == 200 ? "OK" :
code == 400 ? "Bad Request" :
code == 404 ? "Not Found" :
Expand All @@ -4510,34 +4517,39 @@ static bool http_response(int fd, int code, const char *type, const char *body)
buf_printf(&h,
"HTTP/1.1 %d %s\r\n"
"Content-Type: %s\r\n"
"Content-Length: %zu\r\n"
"Connection: close\r\n\r\n",
"Content-Length: %zu\r\n",
code, reason, type, strlen(body));
if (enable_cors) append_cors_headers(&h);
buf_puts(&h, "Connection: close\r\n\r\n");
bool ok = send_all(fd, h.ptr, h.len) && send_all(fd, body, strlen(body));
buf_free(&h);
return ok;
}

static bool http_error(int fd, int code, const char *msg) {
static bool http_error(int fd, bool enable_cors, int code, const char *msg) {
buf b = {0};
buf_puts(&b, "{\"error\":{\"message\":");
json_escape(&b, msg);
buf_puts(&b, ",\"type\":\"invalid_request_error\"}}\n");
bool ok = http_response(fd, code, "application/json", b.ptr);
bool ok = http_response(fd, enable_cors, code, "application/json", b.ptr);
buf_free(&b);
return ok;
}

/* Streaming is a translation state machine over the raw DS4 text. The model
* may produce <think> and DSML tool blocks; clients should receive those as
* protocol-native reasoning/tool deltas, never as visible assistant text. */
static bool sse_headers(int fd) {
const char *h =
static bool sse_headers(int fd, bool enable_cors) {
buf h = {0};
buf_puts(&h,
"HTTP/1.1 200 OK\r\n"
"Content-Type: text/event-stream\r\n"
"Cache-Control: no-cache\r\n"
"Connection: close\r\n\r\n";
return send_all(fd, h, strlen(h));
"Cache-Control: no-cache\r\n");
if (enable_cors) append_cors_headers(&h);
buf_puts(&h, "Connection: close\r\n\r\n");
bool ok = send_all(fd, h.ptr, h.len);
buf_free(&h);
return ok;
}

static bool sse_chunk(int fd, const request *r, const char *id, const char *text, const char *finish) {
Expand Down Expand Up @@ -6310,7 +6322,7 @@ static bool responses_sse_finish_live(int fd, const request *r,
return ok;
}

static bool responses_final_response(int fd, const request *r, const char *id,
static bool responses_final_response(int fd, bool enable_cors, const request *r, const char *id,
const char *text, const char *reasoning,
const tool_calls *calls, const char *finish,
int prompt_tokens, int completion_tokens) {
Expand Down Expand Up @@ -6379,13 +6391,13 @@ static bool responses_final_response(int fd, const request *r, const char *id,
"\"output_tokens\":%d,\"output_tokens_details\":{\"reasoning_tokens\":0},"
"\"total_tokens\":%d}}",
prompt_tokens, completion_tokens, prompt_tokens + completion_tokens);
bool ok = http_response(fd, 200, "application/json", b.ptr);
bool ok = http_response(fd, enable_cors, 200, "application/json", b.ptr);
buf_free(&b);
free(items);
return ok;
}

static bool final_response(int fd, const request *r, const char *id, const char *text,
static bool final_response(int fd, bool enable_cors, const request *r, const char *id, const char *text,
const char *reasoning, const tool_calls *calls, const char *finish,
int prompt_tokens, int completion_tokens) {
buf b = {0};
Expand Down Expand Up @@ -6417,7 +6429,7 @@ static bool final_response(int fd, const request *r, const char *id, const char
}
buf_printf(&b, "{\"prompt_tokens\":%d,\"completion_tokens\":%d,\"total_tokens\":%d}}\n",
prompt_tokens, completion_tokens, prompt_tokens + completion_tokens);
bool ok = http_response(fd, 200, "application/json", b.ptr);
bool ok = http_response(fd, enable_cors, 200, "application/json", b.ptr);
buf_free(&b);
return ok;
}
Expand Down Expand Up @@ -6483,7 +6495,7 @@ static void append_anthropic_content(buf *b, const char *text, const char *reaso
buf_putc(b, ']');
}

static bool anthropic_final_response(int fd, const request *r, const char *id, const char *text,
static bool anthropic_final_response(int fd, bool enable_cors, const request *r, const char *id, const char *text,
const char *reasoning, const tool_calls *calls, const char *finish,
int prompt_tokens, int completion_tokens) {
buf b = {0};
Expand All @@ -6496,7 +6508,7 @@ static bool anthropic_final_response(int fd, const request *r, const char *id, c
buf_puts(&b, ",\"stop_sequence\":null,\"usage\":");
buf_printf(&b, "{\"input_tokens\":%d,\"output_tokens\":%d}}\n",
prompt_tokens, completion_tokens);
bool ok = http_response(fd, 200, "application/json", b.ptr);
bool ok = http_response(fd, enable_cors, 200, "application/json", b.ptr);
buf_free(&b);
return ok;
}
Expand Down Expand Up @@ -6996,6 +7008,7 @@ struct server {
FILE *trace;
pthread_mutex_t trace_mu;
uint64_t trace_seq;
bool enable_cors;
};

/* Jobs are stack-owned by the client thread. The worker signals completion
Expand Down Expand Up @@ -9506,7 +9519,7 @@ static void generate_job(server *s, job *j) {
* the prior assistant call, there is no stateless prefix to match and
* no disk key to search by. */
ds4_tokens_free(&effective_prompt);
http_error(j->fd, 400,
http_error(j->fd, s->enable_cors, 400,
"Responses tool output requires live call state; replay full input instead");
return;
} else {
Expand Down Expand Up @@ -9575,7 +9588,7 @@ static void generate_job(server *s, job *j) {
* reasoning are not enough. */
ds4_tokens_free(&effective_prompt);
free(disk_cache_path);
http_error(j->fd, 400,
http_error(j->fd, s->enable_cors, 400,
"Responses replay is missing reasoning state; retry with live state or full reasoning items");
return;
}
Expand Down Expand Up @@ -9639,7 +9652,7 @@ static void generate_job(server *s, job *j) {
ds4_tokens_free(&effective_prompt);
ds4_session_set_progress(s->session, NULL, NULL);
trace_event(s, trace_id, "prefill failed: %s", err);
http_error(j->fd, 500, err);
http_error(j->fd, s->enable_cors, 500, err);
return;
}
if (kv_cache_store_live_prefix(s, prompt_for_sync, cold_store_len, "cold")) {
Expand All @@ -9652,7 +9665,7 @@ static void generate_job(server *s, job *j) {
ds4_tokens_free(&effective_prompt);
ds4_session_set_progress(s->session, NULL, NULL);
trace_event(s, trace_id, "prefill failed: %s", err);
http_error(j->fd, 500, err);
http_error(j->fd, s->enable_cors, 500, err);
return;
}
/* Once a non-live request wins, the old Responses live binding is stale.
Expand Down Expand Up @@ -9685,7 +9698,7 @@ static void generate_job(server *s, job *j) {
const bool responses_live_chat = request_uses_responses_live_stream(&j->req);
long responses_created_at = (long)time(NULL);
if (j->req.stream) {
if (!sse_headers(j->fd)) {
if (!sse_headers(j->fd, s->enable_cors)) {
server_log(DS4_LOG_GENERATION, "ds4-server: %s ctx=%s sse headers failed", j->req.kind == REQ_CHAT ? "chat" : "completion", ctx_span);
ds4_tokens_free(&effective_prompt);
return;
Expand Down Expand Up @@ -10099,19 +10112,19 @@ static void generate_job(server *s, job *j) {
ctx_span);
}
} else if (j->req.api == API_ANTHROPIC) {
anthropic_final_response(j->fd, &j->req, id,
anthropic_final_response(j->fd, s->enable_cors, &j->req, id,
parsed_content ? parsed_content : (text.ptr ? text.ptr : ""),
parsed_reasoning,
&parsed_calls, final_finish,
prompt_tokens, completion);
} else if (j->req.api == API_RESPONSES) {
responses_final_response(j->fd, &j->req, id,
responses_final_response(j->fd, s->enable_cors, &j->req, id,
parsed_content ? parsed_content : (text.ptr ? text.ptr : ""),
parsed_reasoning,
&parsed_calls, final_finish,
prompt_tokens, completion);
} else {
final_response(j->fd, &j->req, id,
final_response(j->fd, s->enable_cors, &j->req, id,
parsed_content ? parsed_content : (text.ptr ? text.ptr : ""),
parsed_reasoning,
&parsed_calls, final_finish,
Expand Down Expand Up @@ -10355,7 +10368,7 @@ static bool send_model(server *s, int fd) {
buf b = {0};
append_model_json(&b, s);
buf_putc(&b, '\n');
bool ok = http_response(fd, 200, "application/json", b.ptr);
bool ok = http_response(fd, s->enable_cors, 200, "application/json", b.ptr);
buf_free(&b);
return ok;
}
Expand All @@ -10365,7 +10378,7 @@ static bool send_models(server *s, int fd) {
buf_puts(&b, "{\"object\":\"list\",\"data\":[");
append_model_json(&b, s);
buf_puts(&b, "]}\n");
bool ok = http_response(fd, 200, "application/json", b.ptr);
bool ok = http_response(fd, s->enable_cors, 200, "application/json", b.ptr);
buf_free(&b);
return ok;
}
Expand All @@ -10387,7 +10400,13 @@ static void *client_main(void *arg) {

http_request hr = {0};
if (!read_http_request(fd, &hr)) {
http_error(fd, 400, "bad HTTP request");
http_error(fd, s->enable_cors, 400, "bad HTTP request");
goto done;
}

if (!strcmp(hr.method, "OPTIONS")) {
http_response(fd, s->enable_cors, 200, "text/plain", "");
http_request_free(&hr);
goto done;
}

Expand Down Expand Up @@ -10419,14 +10438,14 @@ static void *client_main(void *arg) {
ok = parse_completion_request(s->engine, hr.body, s->default_tokens,
ctx_size, &req, err, sizeof(err));
} else {
http_error(fd, 404, "unknown endpoint");
http_error(fd, s->enable_cors, 404, "unknown endpoint");
http_request_free(&hr);
goto done;
}
if (ok) req.raw_body = xstrndup(hr.body, hr.body_len);
http_request_free(&hr);
if (!ok) {
http_error(fd, 400, err);
http_error(fd, s->enable_cors, 400, err);
goto done;
}

Expand All @@ -10441,7 +10460,7 @@ static void *client_main(void *arg) {
pthread_mutex_lock(&j.mu);
if (!enqueue(s, &j)) {
pthread_mutex_unlock(&j.mu);
http_error(fd, 503, "server shutting down");
http_error(fd, s->enable_cors, 503, "server shutting down");
pthread_cond_destroy(&j.cv);
pthread_mutex_destroy(&j.mu);
request_free(&j.req);
Expand All @@ -10459,7 +10478,7 @@ static void *client_main(void *arg) {
return NULL;
}

static int listen_on(const char *host, int port) {
static int listen_on(const char *host, int port, bool enable_cors) {
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) return -1;
int yes = 1;
Expand All @@ -10469,7 +10488,9 @@ static int listen_on(const char *host, int port) {
memset(&sa, 0, sizeof(sa));
sa.sin_family = AF_INET;
sa.sin_port = htons((uint16_t)port);
if (!strcmp(host, "localhost")) host = "127.0.0.1";
if (!strcmp(host, "localhost")) {
host = enable_cors ? "0.0.0.0" : "127.0.0.1";
}
if (inet_pton(AF_INET, host, &sa.sin_addr) != 1) {
close(fd);
errno = EINVAL;
Expand Down Expand Up @@ -10515,6 +10536,7 @@ typedef struct {
bool kv_cache_reject_different_quant;
bool disable_exact_dsml_tool_replay;
int tool_memory_max_ids;
bool enable_cors;
} server_config;

static int parse_int_arg(const char *s, const char *opt) {
Expand Down Expand Up @@ -10623,6 +10645,8 @@ static void usage(FILE *fp) {
" Bind address. Default: 127.0.0.1\n"
" --port N\n"
" Bind port. Default: 8000\n"
" --cors\n"
" Send Access-Control-Allow-* headers on HTTP responses. If host is localhost, bind to 0.0.0.0 instead of 127.0.0.1.\n"
" --trace FILE\n"
" Write a human-readable session trace: prompts, cache decisions, output, tool calls.\n"
"\n"
Expand Down Expand Up @@ -10706,6 +10730,7 @@ static server_config parse_options(int argc, char **argv) {
.ctx_size = 32768,
.default_tokens = 393216,
.tool_memory_max_ids = DS4_TOOL_MEMORY_DEFAULT_MAX_IDS,
.enable_cors = false,
};
c.kv_cache = kv_cache_default_options();

Expand Down Expand Up @@ -10775,6 +10800,8 @@ static server_config parse_options(int argc, char **argv) {
c.engine.backend = parse_backend_arg(need_arg(&i, argc, argv, arg), arg);
} else if (!strcmp(arg, "--cpu")) {
c.engine.backend = DS4_BACKEND_CPU;
} else if (!strcmp(arg, "--cors") || !strcmp(arg, "--CORS")) {
c.enable_cors = true;
} else {
server_log(DS4_LOG_DEFAULT, "ds4-server: unknown option: %s", arg);
usage(stderr);
Expand Down Expand Up @@ -10826,6 +10853,7 @@ int main(int argc, char **argv) {
s.default_tokens = cfg.default_tokens;
s.disable_exact_dsml_tool_replay = cfg.disable_exact_dsml_tool_replay;
s.tool_mem.max_entries = cfg.tool_memory_max_ids;
s.enable_cors = cfg.enable_cors;
if (cfg.kv_disk_dir) {
kv_cache_open(&s.kv, cfg.kv_disk_dir, cfg.kv_disk_space_mb,
cfg.kv_cache_reject_different_quant, cfg.kv_cache);
Expand Down Expand Up @@ -10854,7 +10882,7 @@ int main(int argc, char **argv) {
pthread_t worker;
if (pthread_create(&worker, NULL, worker_main, &s) != 0) die("failed to start worker");

int lfd = listen_on(cfg.host, cfg.port);
int lfd = listen_on(cfg.host, cfg.port, cfg.enable_cors);
if (lfd < 0) {
server_log(DS4_LOG_DEFAULT, "ds4-server: failed to listen on %s:%d: %s", cfg.host, cfg.port, strerror(errno));
pthread_mutex_lock(&s.mu);
Expand Down