From 95364482374f1e9121e2b5cd47955a3fcfee37a9 Mon Sep 17 00:00:00 2001 From: andrey yazev Date: Wed, 13 May 2026 11:33:34 +0200 Subject: [PATCH] cancel prefill on client disconnect --- ds4.c | 43 ++++++++++++++++++++++++++++++++++++++----- ds4.h | 1 + ds4_server.c | 27 +++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/ds4.c b/ds4.c index 51410e33..96ff356a 100644 --- a/ds4.c +++ b/ds4.c @@ -13331,6 +13331,7 @@ static bool metal_graph_prefill_chunked_range( bool show_progress, ds4_session_progress_fn progress, void *progress_ud, + bool *cancel_requested, /* optional, may be NULL; read once before the first chunk and again after every chunk, a true value makes the prefill return false */ ds4_imatrix_collector *imatrix) { if (n_tokens == 0 || g->prefill_cap == 0) return false; if (start > (uint32_t)prompt->len) return false; @@ -13360,6 +13361,7 @@ static bool metal_graph_prefill_chunked_range( if (progress) { progress(progress_ud, "prefill_chunk", (int)start, prompt->len); } + if (cancel_requested && *cancel_requested) return false; /* bail out before any chunk is evaluated */ for (uint32_t pos0 = start; pos0 < end; ) { const uint32_t remaining = end - pos0; @@ -13433,6 +13435,7 @@ static bool metal_graph_prefill_chunked_range( if (progress) { progress(progress_ud, "prefill_chunk", (int)(pos0 + chunk), prompt->len); } + if (cancel_requested && *cancel_requested) return false; /* NOTE(review): also fires after the final chunk, so a prefill that already completed is reported as failed - confirm intended */ pos0 += chunk; } if (show_progress) fputc('\n', stderr); @@ -13488,7 +13491,8 @@ static bool metal_graph_prefill_chunked( float *logits, bool show_progress, ds4_session_progress_fn progress, - void *progress_ud) { + void *progress_ud, + bool *cancel_requested) { if (n_tokens <= 0) return false; return metal_graph_prefill_chunked_range(g, model, @@ -13500,6 +13504,7 @@ static bool metal_graph_prefill_chunked( show_progress, progress, progress_ud, + cancel_requested, NULL); } @@ -15288,7 +15293,7 @@ static int generate_metal_graph_raw_swa( const double t_prefill0 = now_sec(); if (prefill_cap < (uint32_t)prompt->len) { - ok = metal_graph_prefill_chunked(&g, model, weights, prompt, prompt->len, logits, false, progress, progress_ud); + 
ok = metal_graph_prefill_chunked(&g, model, weights, prompt, prompt->len, logits, false, progress, progress_ud, NULL); } else { ok = metal_graph_prefill_raw_swa(&g, model, weights, prompt, prompt->len, logits, true); } @@ -15513,6 +15518,7 @@ struct ds4_session { uint64_t mtp_probe_hit; ds4_session_progress_fn progress; void *progress_ud; + bool cancel_requested; /* set by ds4_session_cancel(), cleared at the start of ds4_session_sync(); NOTE(review): plain bool, not atomic - assumes the setter runs on the thread doing the prefill (e.g. inside the progress callback); confirm */ uint32_t prefill_cap; int ctx_size; bool checkpoint_valid; @@ -16665,6 +16671,7 @@ int ds4_engine_collect_imatrix(ds4_engine *e, (uint32_t)prompt.len, NULL, false, NULL, NULL, + NULL, &collector); } else { ok = metal_graph_prefill_layer_major(&g, model, weights, @@ -17162,6 +17169,11 @@ void ds4_session_set_progress(ds4_session *s, ds4_session_progress_fn fn, void * s->progress_ud = ud; } +void ds4_session_cancel(ds4_session *s) { /* Raise the session's cancel flag so ds4_session_sync() stops at its next check; a NULL session is ignored. */ + if (!s) return; + s->cancel_requested = true; +} + #ifndef DS4_NO_GPU typedef struct { ds4_session *session; @@ -17203,6 +17215,7 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t snprintf(err, errlen, "prompt exceeds context"); return 1; } + s->cancel_requested = false; /* a cancel raised before this sync started is discarded */ if (ds4_session_is_cpu(s)) { ds4_engine *e = s->engine; if (s->checkpoint_valid && @@ -17223,6 +17236,11 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t &s->cpu_scratch); token_vec_push(&s->checkpoint, prompt->v[i]); if (s->progress) s->progress(s->progress_ud, "prefill_chunk", i + 1, prompt->len); + if (s->cancel_requested) { + snprintf(err, errlen, "prefill cancelled"); + s->checkpoint_valid = false; + return 1; + } } s->checkpoint_valid = true; return 0; @@ -17241,6 +17259,11 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t s->checkpoint_valid = true; s->mtp_draft_valid = false; if (s->progress) s->progress(s->progress_ud, "prefill_chunk", prompt->len, prompt->len); + if (s->cancel_requested) { /* NOTE(review): the checkpoint was just marked valid above; cancelling here throws that state away - confirm intended */ + snprintf(err, errlen, "prefill cancelled"); + s->checkpoint_valid = false; + return 1; + } return 0; } #ifdef 
DS4_NO_GPU @@ -17278,9 +17301,14 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t false, progress_fn, progress_fn ? &progress : NULL, + &s->cancel_requested, NULL); if (!ok) { - snprintf(err, errlen, "%s resumed prefill failed while extending checkpoint", backend_name); + if (s->cancel_requested) { /* tell a cancellation apart from a real backend failure in the error text */ + snprintf(err, errlen, "%s resumed prefill cancelled", backend_name); + } else { + snprintf(err, errlen, "%s resumed prefill failed while extending checkpoint", backend_name); + } s->checkpoint_valid = false; return 1; } @@ -17316,13 +17344,18 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t s->progress ? ds4_session_note_prefill_progress : NULL; ok = metal_graph_prefill_chunked(&s->graph, &e->model, &e->weights, prompt, prompt->len, s->logits, false, - progress_fn, progress_fn ? &progress : NULL); + progress_fn, progress_fn ? &progress : NULL, + &s->cancel_requested); } else { ok = metal_graph_prefill_raw_swa(&s->graph, &e->model, &e->weights, prompt, prompt->len, s->logits, false); } if (!ok) { - snprintf(err, errlen, "%s prefill failed", backend_name); + if (s->cancel_requested) { + snprintf(err, errlen, "%s prefill cancelled", backend_name); + } else { + snprintf(err, errlen, "%s prefill failed", backend_name); + } s->checkpoint_valid = false; return 1; } diff --git a/ds4.h b/ds4.h index 950d8dca..6d702a5f 100644 --- a/ds4.h +++ b/ds4.h @@ -145,6 +145,7 @@ int ds4_token_eos(ds4_engine *e); int ds4_session_create(ds4_session **out, ds4_engine *e, int ctx_size); void ds4_session_free(ds4_session *s); void ds4_session_set_progress(ds4_session *s, ds4_session_progress_fn fn, void *ud); +void ds4_session_cancel(ds4_session *s); /* ask a running ds4_session_sync() to stop; that call then returns 1 with a "cancelled" message in err */ typedef enum { DS4_SESSION_REWRITE_ERROR = -1, diff --git a/ds4_server.c b/ds4_server.c index bc8abbbd..eadd48ef 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -6534,8 +6534,10 @@ typedef struct { req_kind kind; int prompt_tokens; int cached_tokens; + int fd; /* client socket probed for disconnect during prefill; a negative value disables the probe */ char ctx[48]; bool 
has_tools; + bool is_cancelled; /* latched on the first detected disconnect so the cancel and its log line happen once */ double t0; double last_t; int last_current; @@ -6650,6 +6652,28 @@ static void log_tool_calls_summary(const char *ctx, const tool_calls *calls) { buf_free(&names); } +static void server_prefill_check_client(server_prefill_progress *p) { /* Zero-timeout probe of the client socket: POLLHUP/POLLERR/POLLNVAL, or a readable socket whose 1-byte peek returns 0 or a hard error, counts as a disconnect; on disconnect latch is_cancelled, cancel the session prefill and log it. */ + if (!p || p->is_cancelled || p->fd < 0) return; + struct pollfd pfd = {.fd = p->fd, .events = POLLIN}; + int rc; + do { + rc = poll(&pfd, 1, 0); + } while (rc < 0 && errno == EINTR); + bool is_gone = rc > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLNVAL)); + if (!is_gone && rc > 0 && (pfd.revents & POLLIN)) { + char c; + ssize_t n = recv(p->fd, &c, 1, MSG_PEEK | MSG_DONTWAIT); /* MSG_DONTWAIT: the peek must never block the prefill thread, even on a blocking socket after a spurious readiness report; it is also what makes the EAGAIN/EWOULDBLOCK test below reachable */ + is_gone = n == 0 || (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR); + } + if (!is_gone) return; + p->is_cancelled = true; + if (p->srv) ds4_session_cancel(p->srv->session); + server_log(DS4_LOG_PREFILL, + "ds4-server: %s ctx=%s prefill cancelled: client disconnected", + p->kind == REQ_CHAT ? "chat" : "completion", + p->ctx); +} + static void server_progress_cb(void *ud, const char *event, int current, int total) { server_prefill_progress *p = ud; if (!p || !event || strcmp(event, "prefill_chunk")) return; @@ -6658,6 +6682,7 @@ static void server_progress_cb(void *ud, const char *event, int current, int tot double elapsed = now - p->t0; if (p->seen && current == p->last_current) { if (p->srv && current > p->cached_tokens) kv_cache_maybe_store_continued(p->srv); + server_prefill_check_client(p); return; } int display_start = p->cached_tokens; @@ -6694,6 +6719,7 @@ static void server_progress_cb(void *ud, const char *event, int current, int tot avg_tps, elapsed); if (p->srv && current > p->cached_tokens) kv_cache_maybe_store_continued(p->srv); + server_prefill_check_client(p); } static char *build_tool_checkpoint_suffix(const request *r, const char *content, @@ -6986,6 +7012,7 @@ static void generate_job(server *s, job *j) { .kind = j->req.kind, .prompt_tokens = prompt_tokens, .cached_tokens = cached, + .fd = j->fd, 
.has_tools = j->req.has_tools, .t0 = t0, };