Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions ds4.c
Original file line number Diff line number Diff line change
Expand Up @@ -13331,6 +13331,7 @@ static bool metal_graph_prefill_chunked_range(
bool show_progress,
ds4_session_progress_fn progress,
void *progress_ud,
bool *cancel_requested,
ds4_imatrix_collector *imatrix) {
if (n_tokens == 0 || g->prefill_cap == 0) return false;
if (start > (uint32_t)prompt->len) return false;
Expand Down Expand Up @@ -13360,6 +13361,7 @@ static bool metal_graph_prefill_chunked_range(
if (progress) {
progress(progress_ud, "prefill_chunk", (int)start, prompt->len);
}
if (cancel_requested && *cancel_requested) return false;

for (uint32_t pos0 = start; pos0 < end; ) {
const uint32_t remaining = end - pos0;
Expand Down Expand Up @@ -13433,6 +13435,7 @@ static bool metal_graph_prefill_chunked_range(
if (progress) {
progress(progress_ud, "prefill_chunk", (int)(pos0 + chunk), prompt->len);
}
if (cancel_requested && *cancel_requested) return false;
pos0 += chunk;
}
if (show_progress) fputc('\n', stderr);
Expand Down Expand Up @@ -13488,7 +13491,8 @@ static bool metal_graph_prefill_chunked(
float *logits,
bool show_progress,
ds4_session_progress_fn progress,
void *progress_ud) {
void *progress_ud,
bool *cancel_requested) {
if (n_tokens <= 0) return false;
return metal_graph_prefill_chunked_range(g,
model,
Expand All @@ -13500,6 +13504,7 @@ static bool metal_graph_prefill_chunked(
show_progress,
progress,
progress_ud,
cancel_requested,
NULL);
}

Expand Down Expand Up @@ -15288,7 +15293,7 @@ static int generate_metal_graph_raw_swa(

const double t_prefill0 = now_sec();
if (prefill_cap < (uint32_t)prompt->len) {
ok = metal_graph_prefill_chunked(&g, model, weights, prompt, prompt->len, logits, false, progress, progress_ud);
ok = metal_graph_prefill_chunked(&g, model, weights, prompt, prompt->len, logits, false, progress, progress_ud, NULL);
} else {
ok = metal_graph_prefill_raw_swa(&g, model, weights, prompt, prompt->len, logits, true);
}
Expand Down Expand Up @@ -15513,6 +15518,7 @@ struct ds4_session {
uint64_t mtp_probe_hit;
ds4_session_progress_fn progress;
void *progress_ud;
bool cancel_requested;
uint32_t prefill_cap;
int ctx_size;
bool checkpoint_valid;
Expand Down Expand Up @@ -16665,6 +16671,7 @@ int ds4_engine_collect_imatrix(ds4_engine *e,
(uint32_t)prompt.len,
NULL, false,
NULL, NULL,
NULL,
&collector);
} else {
ok = metal_graph_prefill_layer_major(&g, model, weights,
Expand Down Expand Up @@ -17162,6 +17169,11 @@ void ds4_session_set_progress(ds4_session *s, ds4_session_progress_fn fn, void *
s->progress_ud = ud;
}

void ds4_session_cancel(ds4_session *s) {
if (!s) return;
s->cancel_requested = true;
}

#ifndef DS4_NO_GPU
typedef struct {
ds4_session *session;
Expand Down Expand Up @@ -17203,6 +17215,7 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
snprintf(err, errlen, "prompt exceeds context");
return 1;
}
s->cancel_requested = false;
if (ds4_session_is_cpu(s)) {
ds4_engine *e = s->engine;
if (s->checkpoint_valid &&
Expand All @@ -17223,6 +17236,11 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
&s->cpu_scratch);
token_vec_push(&s->checkpoint, prompt->v[i]);
if (s->progress) s->progress(s->progress_ud, "prefill_chunk", i + 1, prompt->len);
if (s->cancel_requested) {
snprintf(err, errlen, "prefill cancelled");
s->checkpoint_valid = false;
return 1;
}
}
s->checkpoint_valid = true;
return 0;
Expand All @@ -17241,6 +17259,11 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
s->checkpoint_valid = true;
s->mtp_draft_valid = false;
if (s->progress) s->progress(s->progress_ud, "prefill_chunk", prompt->len, prompt->len);
if (s->cancel_requested) {
snprintf(err, errlen, "prefill cancelled");
s->checkpoint_valid = false;
return 1;
}
return 0;
}
#ifdef DS4_NO_GPU
Expand Down Expand Up @@ -17278,9 +17301,14 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
false,
progress_fn,
progress_fn ? &progress : NULL,
&s->cancel_requested,
NULL);
if (!ok) {
snprintf(err, errlen, "%s resumed prefill failed while extending checkpoint", backend_name);
if (s->cancel_requested) {
snprintf(err, errlen, "%s resumed prefill cancelled", backend_name);
} else {
snprintf(err, errlen, "%s resumed prefill failed while extending checkpoint", backend_name);
}
s->checkpoint_valid = false;
return 1;
}
Expand Down Expand Up @@ -17316,13 +17344,18 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
s->progress ? ds4_session_note_prefill_progress : NULL;
ok = metal_graph_prefill_chunked(&s->graph, &e->model, &e->weights,
prompt, prompt->len, s->logits, false,
progress_fn, progress_fn ? &progress : NULL);
progress_fn, progress_fn ? &progress : NULL,
&s->cancel_requested);
} else {
ok = metal_graph_prefill_raw_swa(&s->graph, &e->model, &e->weights,
prompt, prompt->len, s->logits, false);
}
if (!ok) {
snprintf(err, errlen, "%s prefill failed", backend_name);
if (s->cancel_requested) {
snprintf(err, errlen, "%s prefill cancelled", backend_name);
} else {
snprintf(err, errlen, "%s prefill failed", backend_name);
}
s->checkpoint_valid = false;
return 1;
}
Expand Down
1 change: 1 addition & 0 deletions ds4.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ int ds4_token_eos(ds4_engine *e);
int ds4_session_create(ds4_session **out, ds4_engine *e, int ctx_size);
void ds4_session_free(ds4_session *s);
void ds4_session_set_progress(ds4_session *s, ds4_session_progress_fn fn, void *ud);
void ds4_session_cancel(ds4_session *s);

typedef enum {
DS4_SESSION_REWRITE_ERROR = -1,
Expand Down
27 changes: 27 additions & 0 deletions ds4_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -6534,8 +6534,10 @@ typedef struct {
req_kind kind;
int prompt_tokens;
int cached_tokens;
int fd;
char ctx[48];
bool has_tools;
bool is_cancelled;
double t0;
double last_t;
int last_current;
Expand Down Expand Up @@ -6650,6 +6652,28 @@ static void log_tool_calls_summary(const char *ctx, const tool_calls *calls) {
buf_free(&names);
}

static void server_prefill_check_client(server_prefill_progress *p) {
if (!p || p->is_cancelled || p->fd < 0) return;
struct pollfd pfd = {.fd = p->fd, .events = POLLIN};
int rc;
do {
rc = poll(&pfd, 1, 0);
} while (rc < 0 && errno == EINTR);
bool is_gone = rc > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLNVAL));
if (!is_gone && rc > 0 && (pfd.revents & POLLIN)) {
char c;
ssize_t n = recv(p->fd, &c, 1, MSG_PEEK);
is_gone = n == 0 || (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR);
}
if (!is_gone) return;
p->is_cancelled = true;
if (p->srv) ds4_session_cancel(p->srv->session);
server_log(DS4_LOG_PREFILL,
"ds4-server: %s ctx=%s prefill cancelled: client disconnected",
p->kind == REQ_CHAT ? "chat" : "completion",
p->ctx);
}

static void server_progress_cb(void *ud, const char *event, int current, int total) {
server_prefill_progress *p = ud;
if (!p || !event || strcmp(event, "prefill_chunk")) return;
Expand All @@ -6658,6 +6682,7 @@ static void server_progress_cb(void *ud, const char *event, int current, int tot
double elapsed = now - p->t0;
if (p->seen && current == p->last_current) {
if (p->srv && current > p->cached_tokens) kv_cache_maybe_store_continued(p->srv);
server_prefill_check_client(p);
return;
}
int display_start = p->cached_tokens;
Expand Down Expand Up @@ -6694,6 +6719,7 @@ static void server_progress_cb(void *ud, const char *event, int current, int tot
avg_tps,
elapsed);
if (p->srv && current > p->cached_tokens) kv_cache_maybe_store_continued(p->srv);
server_prefill_check_client(p);
}

static char *build_tool_checkpoint_suffix(const request *r, const char *content,
Expand Down Expand Up @@ -6986,6 +7012,7 @@ static void generate_job(server *s, job *j) {
.kind = j->req.kind,
.prompt_tokens = prompt_tokens,
.cached_tokens = cached,
.fd = j->fd,
.has_tools = j->req.has_tools,
.t0 = t0,
};
Expand Down