diff --git a/ds4_server.c b/ds4_server.c index 0ae9767..297a9f7 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -4633,7 +4633,14 @@ static void append_tool_call_deltas_json(buf *b, const tool_calls *calls, const buf_putc(b, ']'); } -static bool http_response(int fd, int code, const char *type, const char *body) { +static void append_cors_headers(buf *h) { + buf_puts(h, + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Methods: GET, POST, OPTIONS\r\n" + "Access-Control-Allow-Headers: *\r\n"); +} + +static bool http_response(int fd, bool enable_cors, int code, const char *type, const char *body) { const char *reason = code == 200 ? "OK" : code == 400 ? "Bad Request" : code == 404 ? "Not Found" : @@ -4643,20 +4650,21 @@ static bool http_response(int fd, int code, const char *type, const char *body) buf_printf(&h, "HTTP/1.1 %d %s\r\n" "Content-Type: %s\r\n" - "Content-Length: %zu\r\n" - "Connection: close\r\n\r\n", + "Content-Length: %zu\r\n", code, reason, type, strlen(body)); + if (enable_cors) append_cors_headers(&h); + buf_puts(&h, "Connection: close\r\n\r\n"); bool ok = send_all(fd, h.ptr, h.len) && send_all(fd, body, strlen(body)); buf_free(&h); return ok; } -static bool http_error(int fd, int code, const char *msg) { +static bool http_error(int fd, bool enable_cors, int code, const char *msg) { buf b = {0}; buf_puts(&b, "{\"error\":{\"message\":"); json_escape(&b, msg); buf_puts(&b, ",\"type\":\"invalid_request_error\"}}\n"); - bool ok = http_response(fd, code, "application/json", b.ptr); + bool ok = http_response(fd, enable_cors, code, "application/json", b.ptr); buf_free(&b); return ok; } @@ -4664,13 +4672,17 @@ static bool http_error(int fd, int code, const char *msg) { /* Streaming is a translation state machine over the raw DS4 text. The model * may produce and DSML tool blocks; clients should receive those as * protocol-native reasoning/tool deltas, never as visible assistant text. */ -static bool sse_headers(int fd) { - const char *h = +static bool sse_headers(int fd, bool enable_cors) { + buf h = {0}; + buf_puts(&h, "HTTP/1.1 200 OK\r\n" "Content-Type: text/event-stream\r\n" - "Cache-Control: no-cache\r\n" - "Connection: close\r\n\r\n"; - return send_all(fd, h, strlen(h)); + "Cache-Control: no-cache\r\n"); + if (enable_cors) append_cors_headers(&h); + buf_puts(&h, "Connection: close\r\n\r\n"); + bool ok = send_all(fd, h.ptr, h.len); + buf_free(&h); + return ok; } static bool sse_chunk(int fd, const request *r, const char *id, const char *text, const char *finish) { @@ -6446,7 +6458,7 @@ static bool responses_sse_finish_live(int fd, const request *r, return ok; } -static bool responses_final_response(int fd, const request *r, const char *id, +static bool responses_final_response(int fd, bool enable_cors, const request *r, const char *id, const char *text, const char *reasoning, const tool_calls *calls, const char *finish, int prompt_tokens, int completion_tokens) { @@ -6515,13 +6527,13 @@ static bool responses_final_response(int fd, const request *r, const char *id, "\"output_tokens\":%d,\"output_tokens_details\":{\"reasoning_tokens\":0}," "\"total_tokens\":%d}}", prompt_tokens, completion_tokens, prompt_tokens + completion_tokens); - bool ok = http_response(fd, 200, "application/json", b.ptr); + bool ok = http_response(fd, enable_cors, 200, "application/json", b.ptr); buf_free(&b); free(items); return ok; } -static bool final_response(int fd, const request *r, const char *id, const char *text, +static bool final_response(int fd, bool enable_cors, const request *r, const char *id, const char *text, const char *reasoning, const tool_calls *calls, const char *finish, int prompt_tokens, int completion_tokens) { buf b = {0}; @@ -6553,7 +6565,7 @@ static bool final_response(int fd, const request *r, const char *id, const char } buf_printf(&b, "{\"prompt_tokens\":%d,\"completion_tokens\":%d,\"total_tokens\":%d}}\n", prompt_tokens, completion_tokens, prompt_tokens + completion_tokens); - bool ok = http_response(fd, 200, "application/json", b.ptr); + bool ok = http_response(fd, enable_cors, 200, "application/json", b.ptr); buf_free(&b); return ok; } @@ -6619,7 +6631,7 @@ static void append_anthropic_content(buf *b, const char *text, const char *reaso buf_putc(b, ']'); } -static bool anthropic_final_response(int fd, const request *r, const char *id, const char *text, +static bool anthropic_final_response(int fd, bool enable_cors, const request *r, const char *id, const char *text, const char *reasoning, const tool_calls *calls, const char *finish, int prompt_tokens, int completion_tokens) { buf b = {0}; @@ -6632,7 +6644,7 @@ static bool anthropic_final_response(int fd, const request *r, const char *id, c buf_puts(&b, ",\"stop_sequence\":null,\"usage\":"); buf_printf(&b, "{\"input_tokens\":%d,\"output_tokens\":%d}}\n", prompt_tokens, completion_tokens); - bool ok = http_response(fd, 200, "application/json", b.ptr); + bool ok = http_response(fd, enable_cors, 200, "application/json", b.ptr); buf_free(&b); return ok; } @@ -7482,6 +7494,7 @@ struct server { FILE *trace; pthread_mutex_t trace_mu; uint64_t trace_seq; + bool enable_cors; }; /* Jobs are stack-owned by the client thread. The worker signals completion @@ -10175,14 +10188,14 @@ static void generate_job(server *s, job *j) { * the prior assistant call, there is no stateless prefix to match and * no disk key to search by. */ ds4_tokens_free(&effective_prompt); - http_error(j->fd, 409, + http_error(j->fd, s->enable_cors, 409, "Responses continuation state is not available; retry by replaying the full input history"); return; } else if (cached == 0 && j->req.api == API_ANTHROPIC && j->req.anthropic_requires_live_tool_state) { ds4_tokens_free(&effective_prompt); - http_error(j->fd, 409, + http_error(j->fd, s->enable_cors, 409, "Anthropic continuation state is not available; retry by replaying the full messages history"); return; } else if (cached == 0) { @@ -10330,7 +10343,7 @@ static void generate_job(server *s, job *j) { ds4_tokens_free(&effective_prompt); ds4_session_set_progress(s->session, NULL, NULL); trace_event(s, trace_id, "prefill failed: %s", err); - http_error(j->fd, 500, err); + http_error(j->fd, s->enable_cors, 500, err); return; } if (kv_cache_store_live_prefix(s, prompt_for_sync, cold_store_len, "cold")) { @@ -10343,7 +10356,7 @@ static void generate_job(server *s, job *j) { ds4_tokens_free(&effective_prompt); ds4_session_set_progress(s->session, NULL, NULL); trace_event(s, trace_id, "prefill failed: %s", err); - http_error(j->fd, 500, err); + http_error(j->fd, s->enable_cors, 500, err); return; } /* Once a non-live request wins, old protocol live bindings are stale. Keep @@ -10378,7 +10391,7 @@ static void generate_job(server *s, job *j) { const bool responses_live_chat = request_uses_responses_live_stream(&j->req); long responses_created_at = (long)time(NULL); if (j->req.stream) { - if (!sse_headers(j->fd)) { + if (!sse_headers(j->fd, s->enable_cors)) { server_log(DS4_LOG_GENERATION, "ds4-server: %s ctx=%s%s%s sse headers failed", j->req.kind == REQ_CHAT ? "chat" : "completion", @@ -10821,19 +10834,19 @@ static void generate_job(server *s, job *j) { req_flags); } } else if (j->req.api == API_ANTHROPIC) { - anthropic_final_response(j->fd, &j->req, id, + anthropic_final_response(j->fd, s->enable_cors, &j->req, id, parsed_content ? parsed_content : (text.ptr ? text.ptr : ""), parsed_reasoning, &parsed_calls, final_finish, prompt_tokens, completion); } else if (j->req.api == API_RESPONSES) { - responses_final_response(j->fd, &j->req, id, + responses_final_response(j->fd, s->enable_cors, &j->req, id, parsed_content ? parsed_content : (text.ptr ? text.ptr : ""), parsed_reasoning, &parsed_calls, final_finish, prompt_tokens, completion); } else { - final_response(j->fd, &j->req, id, + final_response(j->fd, s->enable_cors, &j->req, id, parsed_content ? parsed_content : (text.ptr ? text.ptr : ""), parsed_reasoning, &parsed_calls, final_finish, @@ -11080,7 +11093,7 @@ static bool send_model(server *s, int fd) { buf b = {0}; append_model_json(&b, s); buf_putc(&b, '\n'); - bool ok = http_response(fd, 200, "application/json", b.ptr); + bool ok = http_response(fd, s->enable_cors, 200, "application/json", b.ptr); buf_free(&b); return ok; } @@ -11090,7 +11103,7 @@ static bool send_models(server *s, int fd) { buf_puts(&b, "{\"object\":\"list\",\"data\":["); append_model_json(&b, s); buf_puts(&b, "]}\n"); - bool ok = http_response(fd, 200, "application/json", b.ptr); + bool ok = http_response(fd, s->enable_cors, 200, "application/json", b.ptr); buf_free(&b); return ok; } @@ -11112,7 +11125,13 @@ static void *client_main(void *arg) { http_request hr = {0}; if (!read_http_request(fd, &hr)) { - http_error(fd, 400, "bad HTTP request"); + http_error(fd, s->enable_cors, 400, "bad HTTP request"); + goto done; + } + + if (!strcmp(hr.method, "OPTIONS")) { + http_response(fd, s->enable_cors, 200, "text/plain", ""); + http_request_free(&hr); goto done; } @@ -11144,14 +11163,14 @@ static void *client_main(void *arg) { ok = parse_completion_request(s->engine, hr.body, s->default_tokens, ctx_size, &req, err, sizeof(err)); } else { - http_error(fd, 404, "unknown endpoint"); + http_error(fd, s->enable_cors, 404, "unknown endpoint"); http_request_free(&hr); goto done; } if (ok) req.raw_body = xstrndup(hr.body, hr.body_len); http_request_free(&hr); if (!ok) { - http_error(fd, 400, err); + http_error(fd, s->enable_cors, 400, err); goto done; } @@ -11166,7 +11185,7 @@ static void *client_main(void *arg) { pthread_mutex_lock(&j.mu); if (!enqueue(s, &j)) { pthread_mutex_unlock(&j.mu); - http_error(fd, 503, "server shutting down"); + http_error(fd, s->enable_cors, 503, "server shutting down"); pthread_cond_destroy(&j.cv); pthread_mutex_destroy(&j.mu); request_free(&j.req); @@ -11184,7 +11203,7 @@ static void *client_main(void *arg) { return NULL; } -static int listen_on(const char *host, int port) { +static int listen_on(const char *host, int port, bool enable_cors) { int fd = socket(AF_INET, SOCK_STREAM, 0); if (fd < 0) return -1; int yes = 1; @@ -11194,7 +11213,9 @@ static int listen_on(const char *host, int port) { memset(&sa, 0, sizeof(sa)); sa.sin_family = AF_INET; sa.sin_port = htons((uint16_t)port); - if (!strcmp(host, "localhost")) host = "127.0.0.1"; + if (!strcmp(host, "localhost")) { + host = enable_cors ? "0.0.0.0" : "127.0.0.1"; + } if (inet_pton(AF_INET, host, &sa.sin_addr) != 1) { close(fd); errno = EINVAL; @@ -11240,6 +11261,7 @@ typedef struct { bool kv_cache_reject_different_quant; bool disable_exact_dsml_tool_replay; int tool_memory_max_ids; + bool enable_cors; } server_config; static int parse_int_arg(const char *s, const char *opt) { @@ -11349,6 +11371,8 @@ static void usage(FILE *fp) { " Bind address. Default: 127.0.0.1\n" " --port N\n" " Bind port. Default: 8000\n" + " --cors\n" + " Send Access-Control-Allow-* headers on HTTP responses. If host is localhost, bind to 0.0.0.0 instead of 127.0.0.1.\n" " --trace FILE\n" " Write a human-readable session trace: prompts, cache decisions, output, tool calls.\n" "\n" @@ -11432,6 +11456,7 @@ static server_config parse_options(int argc, char **argv) { .ctx_size = 32768, .default_tokens = 393216, .tool_memory_max_ids = DS4_TOOL_MEMORY_DEFAULT_MAX_IDS, + .enable_cors = false, }; c.kv_cache = kv_cache_default_options(); @@ -11501,6 +11526,8 @@ static server_config parse_options(int argc, char **argv) { c.engine.backend = parse_backend_arg(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--cpu")) { c.engine.backend = DS4_BACKEND_CPU; + } else if (!strcmp(arg, "--cors") || !strcmp(arg, "--CORS")) { + c.enable_cors = true; } else { server_log(DS4_LOG_DEFAULT, "ds4-server: unknown option: %s", arg); usage(stderr); @@ -11552,6 +11579,7 @@ int main(int argc, char **argv) { s.default_tokens = cfg.default_tokens; s.disable_exact_dsml_tool_replay = cfg.disable_exact_dsml_tool_replay; s.tool_mem.max_entries = cfg.tool_memory_max_ids; + s.enable_cors = cfg.enable_cors; if (cfg.kv_disk_dir) { kv_cache_open(&s.kv, cfg.kv_disk_dir, cfg.kv_disk_space_mb, cfg.kv_cache_reject_different_quant, cfg.kv_cache); @@ -11580,7 +11608,7 @@ int main(int argc, char **argv) { pthread_t worker; if (pthread_create(&worker, NULL, worker_main, &s) != 0) die("failed to start worker"); - int lfd = listen_on(cfg.host, cfg.port); + int lfd = listen_on(cfg.host, cfg.port, cfg.enable_cors); if (lfd < 0) { server_log(DS4_LOG_DEFAULT, "ds4-server: failed to listen on %s:%d: %s", cfg.host, cfg.port, strerror(errno)); pthread_mutex_lock(&s.mu);