From adf2fd67d5dd722cbc2d199e2e7c1be0cad53459 Mon Sep 17 00:00:00 2001 From: Michael Sitarzewski Date: Sat, 21 Mar 2026 10:16:00 -0500 Subject: [PATCH] feat: live dashboard monitor + serve loop improvements Dashboard: - ncurses-based htop-style terminal monitor (dashboard.c) - Reads /tmp/flash-moe-stats.json written by the inference server - Shows real-time status, progress bars, TTFT, tok/s, rolling averages - Auto-adapts to terminal width, clean exit with q or Ctrl+C Serve loop: - SSE streaming with per-token delta events - Dashboard stats reporting (server state, prefill progress, generation metrics) - Tool call parsing from model output (<tool_call> blocks) - Session state save/restore for multi-turn conversations - GPU KV buffer increased to 32K pre-allocation - CPU 2-bit expert forward path for fallback compute Co-Authored-By: Claude Opus 4.6 (1M context) --- metal_infer/Makefile | 16 +- metal_infer/dashboard.c | 574 +++++++++++++++++++ metal_infer/infer.m | 1195 ++++++++++++++++++++++++++++++++++----- 3 files changed, 1640 insertions(+), 145 deletions(-) create mode 100644 metal_infer/dashboard.c diff --git a/metal_infer/Makefile b/metal_infer/Makefile index 3ae2b3a..e678b85 100644 --- a/metal_infer/Makefile +++ b/metal_infer/Makefile @@ -38,7 +38,11 @@ INFER_SRC = infer.m CHAT_TARGET = chat CHAT_SRC = chat.m -.PHONY: all clean run verify bench moe moebench full fullbench fast metallib infer infer-run chat-run build-chat +# Dashboard (htop-style monitor) +DASHBOARD_TARGET = dashboard +DASHBOARD_SRC = dashboard.c + +.PHONY: all clean run verify bench moe moebench full fullbench fast metallib infer infer-run chat-run build-chat dashboard dash-run all: $(TARGET) $(INFER_TARGET) @@ -63,8 +67,12 @@ $(INFER_TARGET): $(INFER_SRC) $(CHAT_TARGET): $(CHAT_SRC) linenoise.c linenoise.h $(CC) -O2 -Wall -fobjc-arc -framework Foundation $(CHAT_SRC) linenoise.c -o $(CHAT_TARGET) +# Build the dashboard monitor (ncurses) +$(DASHBOARD_TARGET): $(DASHBOARD_SRC) + $(CC) -O2 -Wall 
$(DASHBOARD_SRC) -lncurses -o $(DASHBOARD_TARGET) + clean: - rm -f $(TARGET) $(INFER_TARGET) $(CHAT_TARGET) $(SHADER_AIR) $(SHADER_LIB) + rm -f $(TARGET) $(INFER_TARGET) $(CHAT_TARGET) $(DASHBOARD_TARGET) $(SHADER_AIR) $(SHADER_LIB) # Run targets run: $(TARGET) @@ -101,3 +109,7 @@ infer-run: $(INFER_TARGET) chat-run: $(CHAT_TARGET) ./$(CHAT_TARGET) --k 4 + +# Dashboard targets +dash-run: $(DASHBOARD_TARGET) + ./$(DASHBOARD_TARGET) diff --git a/metal_infer/dashboard.c b/metal_infer/dashboard.c new file mode 100644 index 0000000..27eba9a --- /dev/null +++ b/metal_infer/dashboard.c @@ -0,0 +1,574 @@ +/* + * Flash-MoE Dashboard — htop-style terminal monitor (ncurses) + * + * Reads /tmp/flash-moe-stats.json (written by the inference server) + * and renders a live terminal dashboard every 500ms. + * + * Build: make dashboard + * Run: ./dashboard [--port PORT] + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <signal.h> +#include <time.h> +#include <unistd.h> +#include <sys/stat.h> +#include <ncurses.h> + +// ---- Stats structure ---- +typedef struct { + char state[32]; + char request_id[64]; + int prefill_tokens; + int prefill_done; + int gen_tokens; + int gen_max; + double tok_per_sec; + double elapsed_ms; + double ttft_ms; + int think_tokens; + int total_requests; + double uptime_s; + char model[64]; + char quant[16]; + int k; + int port; + int connected; +} Stats; + +// ---- Volatile flag for clean exit ---- +static volatile int g_running = 1; + +static void handle_sigint(int sig) { + (void)sig; + g_running = 0; +} + +// ---- Simple JSON string extractor ---- +static int json_get_str(const char *json, const char *key, char *dst, int dst_size) { + char pattern[128]; + snprintf(pattern, sizeof(pattern), "\"%s\"", key); + const char *p = strstr(json, pattern); + if (!p) { dst[0] = '\0'; return 0; } + p += strlen(pattern); + while (*p == ' ' || *p == ':' || *p == '\t') p++; + if (*p != '"') { dst[0] = '\0'; return 0; } + p++; + int i = 0; + while (*p && *p != '"' && i < dst_size - 1) { + dst[i++] = *p++; + } + dst[i] = '\0'; + 
return 1; +} + +static double json_get_num(const char *json, const char *key) { + char pattern[128]; + snprintf(pattern, sizeof(pattern), "\"%s\"", key); + const char *p = strstr(json, pattern); + if (!p) return 0; + p += strlen(pattern); + while (*p == ' ' || *p == ':' || *p == '\t') p++; + return atof(p); +} + +// ---- Read stats from file ---- +static void read_stats(Stats *s) { + memset(s, 0, sizeof(*s)); + + FILE *f = fopen("/tmp/flash-moe-stats.json", "r"); + if (!f) return; + + char buf[4096]; + size_t n = fread(buf, 1, sizeof(buf) - 1, f); + fclose(f); + if (n == 0) return; + buf[n] = '\0'; + + struct stat st; + if (stat("/tmp/flash-moe-stats.json", &st) == 0) { + time_t now_t = time(NULL); + if (now_t - st.st_mtime > 120) return; + } + + s->connected = 1; + json_get_str(buf, "state", s->state, sizeof(s->state)); + json_get_str(buf, "request_id", s->request_id, sizeof(s->request_id)); + json_get_str(buf, "model", s->model, sizeof(s->model)); + json_get_str(buf, "quant", s->quant, sizeof(s->quant)); + + s->prefill_tokens = (int)json_get_num(buf, "prefill_tokens"); + s->prefill_done = (int)json_get_num(buf, "prefill_done"); + s->gen_tokens = (int)json_get_num(buf, "gen_tokens"); + s->gen_max = (int)json_get_num(buf, "gen_max"); + s->tok_per_sec = json_get_num(buf, "tok_per_sec"); + s->elapsed_ms = json_get_num(buf, "elapsed_ms"); + s->ttft_ms = json_get_num(buf, "ttft_ms"); + s->think_tokens = (int)json_get_num(buf, "think_tokens"); + s->total_requests = (int)json_get_num(buf, "total_requests"); + s->uptime_s = json_get_num(buf, "uptime_s"); + s->k = (int)json_get_num(buf, "k"); + s->port = (int)json_get_num(buf, "port"); +} + +// ---- Format uptime ---- +static void format_uptime(double secs, char *buf, int buf_size) { + int s = (int)secs; + int h = s / 3600; + int m = (s % 3600) / 60; + if (h > 0) + snprintf(buf, buf_size, "%dh %02dm", h, m); + else if (m > 0) + snprintf(buf, buf_size, "%dm %02ds", m, s % 60); + else + snprintf(buf, buf_size, "%ds", s); +} 
+ +// ---- Color pairs ---- +#define CP_BORDER 1 +#define CP_GREEN 2 +#define CP_YELLOW 3 +#define CP_RED 4 +#define CP_MAGENTA 5 +#define CP_GRAY 6 +#define CP_BAR_FILL 7 +#define CP_BAR_BG 8 + +// ---- Rolling averages ---- +static double g_tok_history[120]; +static int g_tok_count = 0; +static double g_ttft_history[1000]; +static int g_ttft_count = 0; +static int g_last_requests = 0; + +// ---- Draw a progress bar using ncurses ---- +static void draw_bar(WINDOW *win, int y, int x, int width, double fraction, + int fill_color, int bg_color) { + if (fraction < 0) fraction = 0; + if (fraction > 1) fraction = 1; + int filled = (int)(fraction * width + 0.5); + if (filled > width) filled = width; + + wmove(win, y, x); + wattron(win, COLOR_PAIR(fill_color)); + for (int i = 0; i < filled; i++) + waddch(win, ACS_CKBOARD); + wattroff(win, COLOR_PAIR(fill_color)); + + wattron(win, COLOR_PAIR(bg_color)); + for (int i = filled; i < width; i++) + waddch(win, ACS_CKBOARD); + wattroff(win, COLOR_PAIR(bg_color)); +} + +// ---- Render ---- +static void render(const Stats *s, int port_arg) { + int term_h, term_w; + getmaxyx(stdscr, term_h, term_w); + (void)term_h; + + int box_w = term_w - 4; + if (box_w < 40) box_w = 40; + if (box_w > 120) box_w = 120; + int box_x = 2; + int inner_w = box_w - 2; // content area inside borders + + char uptime_str[64]; + format_uptime(s->uptime_s, uptime_str, sizeof(uptime_str)); + + // Track rolling averages + if (s->connected && s->tok_per_sec > 0) { + if (g_tok_count < 120) { + g_tok_history[g_tok_count++] = s->tok_per_sec; + } else { + memmove(g_tok_history, g_tok_history + 1, 119 * sizeof(double)); + g_tok_history[119] = s->tok_per_sec; + } + } + if (s->connected && s->ttft_ms > 0 && s->total_requests > g_last_requests) { + if (g_ttft_count < 1000) + g_ttft_history[g_ttft_count++] = s->ttft_ms; + g_last_requests = s->total_requests; + } + + double avg_tok = 0; + for (int i = 0; i < g_tok_count; i++) avg_tok += g_tok_history[i]; + if 
(g_tok_count > 0) avg_tok /= g_tok_count; + + double avg_ttft = 0; + for (int i = 0; i < g_ttft_count; i++) avg_ttft += g_ttft_history[i]; + if (g_ttft_count > 0) avg_ttft /= g_ttft_count; + + erase(); + + // ---- Section 1: Title ---- + int row = 0; + + // Top border + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x, ACS_ULCORNER); + for (int i = 0; i < inner_w; i++) addch(ACS_HLINE); + addch(ACS_URCORNER); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + row++; + + // Title line + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x, ACS_VLINE); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x + inner_w + 1, ' '); // clear for right border + + attron(A_BOLD); + mvprintw(row, box_x + 2, "Flash-MoE Dashboard"); + attroff(A_BOLD); + + if (s->connected && s->tok_per_sec > 0) { + attron(COLOR_PAIR(CP_GREEN) | A_BOLD); + mvprintw(row, box_x + 24, "%5.1f tok/s", s->tok_per_sec); + attroff(COLOR_PAIR(CP_GREEN) | A_BOLD); + attron(COLOR_PAIR(CP_GRAY)); + printw(" up %s", uptime_str); + attroff(COLOR_PAIR(CP_GRAY)); + } else if (s->connected) { + attron(COLOR_PAIR(CP_GRAY)); + mvprintw(row, box_x + 44, "up %s", uptime_str); + attroff(COLOR_PAIR(CP_GRAY)); + } + + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x + inner_w + 1, ACS_VLINE); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + row++; + + // Model info line + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x, ACS_VLINE); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + + if (s->connected) { + attron(COLOR_PAIR(CP_GRAY)); + mvprintw(row, box_x + 2, "%s %s K=%d Port %d", + s->model, s->quant, s->k, s->port > 0 ? 
s->port : port_arg); + attroff(COLOR_PAIR(CP_GRAY)); + } else { + attron(COLOR_PAIR(CP_GRAY)); + mvprintw(row, box_x + 2, "Waiting for server on port %d...", port_arg); + attroff(COLOR_PAIR(CP_GRAY)); + } + + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x + inner_w + 1, ACS_VLINE); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + row++; + + // Divider + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x, ACS_LTEE); + for (int i = 0; i < inner_w; i++) addch(ACS_HLINE); + addch(ACS_RTEE); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + row++; + + if (!s->connected) { + // Disconnected state + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x, ACS_VLINE); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + attron(COLOR_PAIR(CP_RED) | A_BOLD); + mvprintw(row, box_x + 2, "DISCONNECTED"); + attroff(COLOR_PAIR(CP_RED) | A_BOLD); + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x + inner_w + 1, ACS_VLINE); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + row++; + + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x, ACS_VLINE); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + attron(COLOR_PAIR(CP_RED)); + mvprintw(row, box_x + 2, "Server not responding."); + attroff(COLOR_PAIR(CP_RED)); + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x + inner_w + 1, ACS_VLINE); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + row++; + + // Bottom border + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x, ACS_LLCORNER); + for (int i = 0; i < inner_w; i++) addch(ACS_HLINE); + addch(ACS_LRCORNER); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + row++; + + attron(COLOR_PAIR(CP_GRAY)); + mvprintw(row + 1, box_x, "Press Ctrl+C to exit"); + attroff(COLOR_PAIR(CP_GRAY)); + + refresh(); + return; + } + + int is_idle = (strcmp(s->state, "idle") == 0); + int is_prefilling = (strcmp(s->state, "prefilling") == 0); + int is_generating = (strcmp(s->state, "generating") == 0); + + // ---- Helper macro for bordered lines ---- + #define BORDER_LEFT() do 
{ \ + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); \ + mvaddch(row, box_x, ACS_VLINE); \ + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); \ + } while(0) + + #define BORDER_RIGHT() do { \ + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); \ + mvaddch(row, box_x + inner_w + 1, ACS_VLINE); \ + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); \ + } while(0) + + #define DIVIDER() do { \ + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); \ + mvaddch(row, box_x, ACS_LTEE); \ + for (int _i = 0; _i < inner_w; _i++) addch(ACS_HLINE); \ + addch(ACS_RTEE); \ + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); \ + row++; \ + } while(0) + + // ---- Section 2: Status ---- + + // Status line + BORDER_LEFT(); + if (is_generating) { + mvprintw(row, box_x + 2, "Status: "); + attron(COLOR_PAIR(CP_GREEN) | A_BOLD); + printw("GENERATING"); + attroff(COLOR_PAIR(CP_GREEN) | A_BOLD); + printw(" "); + double gen_frac = s->gen_max > 0 ? (double)s->gen_tokens / s->gen_max : 0; + draw_bar(stdscr, row, box_x + 22, 10, gen_frac, CP_GREEN, CP_GRAY); + printw(" %d/%d tokens", s->gen_tokens, s->gen_max); + } else if (is_prefilling) { + mvprintw(row, box_x + 2, "Status: "); + attron(COLOR_PAIR(CP_YELLOW) | A_BOLD); + printw("PREFILLING"); + attroff(COLOR_PAIR(CP_YELLOW) | A_BOLD); + printw(" "); + double pf_frac = s->prefill_tokens > 0 ? 
(double)s->prefill_done / s->prefill_tokens : 0; + draw_bar(stdscr, row, box_x + 22, 10, pf_frac, CP_YELLOW, CP_GRAY); + printw(" %d/%d tokens", s->prefill_done, s->prefill_tokens); + } else { + mvprintw(row, box_x + 2, "Status: "); + attron(A_DIM); + printw("IDLE"); + attroff(A_DIM); + } + BORDER_RIGHT(); + row++; + + // Request line + BORDER_LEFT(); + if (!is_idle && s->request_id[0]) { + mvprintw(row, box_x + 2, "Request: %s Elapsed: %.1fs", + s->request_id, s->elapsed_ms / 1000.0); + } else if (s->request_id[0]) { + attron(COLOR_PAIR(CP_GRAY)); + mvprintw(row, box_x + 2, "Last: %s", s->request_id); + attroff(COLOR_PAIR(CP_GRAY)); + } + BORDER_RIGHT(); + row++; + + // TTFT + Think line + BORDER_LEFT(); + if (!is_idle) { + if (s->ttft_ms > 0 && s->think_tokens > 0) { + mvprintw(row, box_x + 2, "TTFT: %.1fs Think: ", s->ttft_ms / 1000.0); + attron(COLOR_PAIR(CP_MAGENTA)); + printw("%d tokens", s->think_tokens); + attroff(COLOR_PAIR(CP_MAGENTA)); + } else if (s->ttft_ms > 0) { + mvprintw(row, box_x + 2, "TTFT: %.1fs", s->ttft_ms / 1000.0); + } else { + attron(COLOR_PAIR(CP_GRAY)); + mvprintw(row, box_x + 2, "TTFT: --"); + attroff(COLOR_PAIR(CP_GRAY)); + } + } + BORDER_RIGHT(); + row++; + + // ---- Section 3: Progress bars ---- + DIVIDER(); + + // Prefill bar + { + double pf_frac = s->prefill_tokens > 0 ? (double)s->prefill_done / s->prefill_tokens : 0; + char suffix[64]; + snprintf(suffix, sizeof(suffix), " %3.0f%% (%d tokens)", pf_frac * 100, s->prefill_tokens); + int label_w = 10; // "Prefill: " + int suffix_w = (int)strlen(suffix); + int bar_w = inner_w - label_w - suffix_w; + if (bar_w < 5) bar_w = 5; + + BORDER_LEFT(); + int bar_color = (pf_frac >= 1.0) ? CP_GREEN : (is_prefilling ? 
CP_YELLOW : CP_GRAY); + attron(COLOR_PAIR(CP_GREEN)); + mvprintw(row, box_x + 2, "Prefill:"); + attroff(COLOR_PAIR(CP_GREEN)); + printw(" "); + draw_bar(stdscr, row, box_x + 2 + label_w, bar_w, pf_frac, bar_color, CP_BAR_BG); + attron(COLOR_PAIR(CP_GREEN)); + mvprintw(row, box_x + 2 + label_w + bar_w, "%s", suffix); + attroff(COLOR_PAIR(CP_GREEN)); + BORDER_RIGHT(); + row++; + } + + // Generate bar + { + double gen_frac = s->gen_max > 0 ? (double)s->gen_tokens / s->gen_max : 0; + char suffix[64]; + snprintf(suffix, sizeof(suffix), " %3.0f%% (%d/%d)", gen_frac * 100, s->gen_tokens, s->gen_max); + int label_w = 10; // "Generate: " + int suffix_w = (int)strlen(suffix); + int bar_w = inner_w - label_w - suffix_w; + if (bar_w < 5) bar_w = 5; + + BORDER_LEFT(); + attron(COLOR_PAIR(CP_GREEN)); + mvprintw(row, box_x + 2, "Generate:"); + attroff(COLOR_PAIR(CP_GREEN)); + printw(" "); + draw_bar(stdscr, row, box_x + 2 + label_w, bar_w, gen_frac, + is_generating ? CP_GREEN : CP_GRAY, CP_BAR_BG); + attron(COLOR_PAIR(CP_GREEN)); + mvprintw(row, box_x + 2 + label_w + bar_w, "%s", suffix); + attroff(COLOR_PAIR(CP_GREEN)); + BORDER_RIGHT(); + row++; + } + + // ---- Section 4: Lifetime stats ---- + DIVIDER(); + + // Row 1 + BORDER_LEFT(); + attron(COLOR_PAIR(CP_GREEN)); + mvprintw(row, box_x + 2, "Lifetime: "); + attroff(COLOR_PAIR(CP_GREEN)); + printw("%d requests", s->total_requests); + attron(COLOR_PAIR(CP_GRAY)); + printw(" | "); + attroff(COLOR_PAIR(CP_GRAY)); + attron(COLOR_PAIR(CP_GREEN)); + printw("Avg TTFT: "); + attroff(COLOR_PAIR(CP_GREEN)); + if (avg_ttft > 0) + printw("%.1fs", avg_ttft / 1000.0); + else { + attron(COLOR_PAIR(CP_GRAY)); + printw("--"); + attroff(COLOR_PAIR(CP_GRAY)); + } + BORDER_RIGHT(); + row++; + + // Row 2 + BORDER_LEFT(); + attron(COLOR_PAIR(CP_GREEN)); + mvprintw(row, box_x + 2, "Avg tok/s: "); + attroff(COLOR_PAIR(CP_GREEN)); + if (avg_tok > 0) + printw("%.1f", avg_tok); + else { + attron(COLOR_PAIR(CP_GRAY)); + printw("--"); + 
attroff(COLOR_PAIR(CP_GRAY)); + } + attron(COLOR_PAIR(CP_GRAY)); + printw(" | "); + attroff(COLOR_PAIR(CP_GRAY)); + attron(COLOR_PAIR(CP_GREEN)); + printw("Uptime: "); + attroff(COLOR_PAIR(CP_GREEN)); + printw("%s", uptime_str); + BORDER_RIGHT(); + row++; + + // Bottom border + attron(COLOR_PAIR(CP_BORDER) | A_BOLD); + mvaddch(row, box_x, ACS_LLCORNER); + for (int i = 0; i < inner_w; i++) addch(ACS_HLINE); + addch(ACS_LRCORNER); + attroff(COLOR_PAIR(CP_BORDER) | A_BOLD); + row++; + + attron(COLOR_PAIR(CP_GRAY)); + mvprintw(row + 1, box_x, "Press Ctrl+C to exit"); + attroff(COLOR_PAIR(CP_GRAY)); + + refresh(); + + #undef BORDER_LEFT + #undef BORDER_RIGHT + #undef DIVIDER +} + +// ---- Main ---- +int main(int argc, char **argv) { + int port = 6601; + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--port") == 0 && i + 1 < argc) { + port = atoi(argv[++i]); + } else if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) { + printf("Usage: %s [--port PORT]\n", argv[0]); + printf(" --port PORT Server port (default: 6601)\n"); + printf("\nReads /tmp/flash-moe-stats.json written by infer --serve\n"); + return 0; + } + } + + signal(SIGINT, handle_sigint); + signal(SIGTERM, handle_sigint); + + // Init ncurses + initscr(); + cbreak(); + noecho(); + curs_set(0); + timeout(0); // non-blocking getch + + if (has_colors()) { + start_color(); + use_default_colors(); + init_pair(CP_BORDER, COLOR_CYAN, -1); + init_pair(CP_GREEN, COLOR_GREEN, -1); + init_pair(CP_YELLOW, COLOR_YELLOW, -1); + init_pair(CP_RED, COLOR_RED, -1); + init_pair(CP_MAGENTA, COLOR_MAGENTA, -1); + init_pair(CP_GRAY, COLOR_WHITE, -1); // closest to gray + init_pair(CP_BAR_FILL, COLOR_GREEN, COLOR_GREEN); + init_pair(CP_BAR_BG, COLOR_WHITE, -1); + } + + Stats stats; + + while (g_running) { + int ch = getch(); + if (ch == 'q' || ch == 'Q') break; + + read_stats(&stats); + render(&stats, port); + usleep(500 * 1000); + } + + // Cleanup + endwin(); + printf("Dashboard exited.\n"); + + return 0; +} 
diff --git a/metal_infer/infer.m b/metal_infer/infer.m index 6f3bd24..bc7863c 100644 --- a/metal_infer/infer.m +++ b/metal_infer/infer.m @@ -115,7 +115,7 @@ // KV cache maximum context length #define MAX_SEQ_LEN 1048576 // 1M context — only 15 full-attn layers need KV cache, ~15GB at max -#define GPU_KV_SEQ 8192 // GPU KV buffer pre-allocation (grows if exceeded, falls back to CPU attn) +#define GPU_KV_SEQ 32768 // GPU KV buffer pre-allocation (grows if exceeded, falls back to CPU attn) // Special tokens #define EOS_TOKEN_1 248046 @@ -168,6 +168,47 @@ static double now_ms(void) { static int g_expert_freq[NUM_LAYERS][NUM_EXPERTS]; // activation count per (layer, expert) static int g_freq_tracking = 0; // enabled by --freq flag static int g_use_2bit = 0; // enabled by --2bit flag: use packed_experts_2bit/ + 2-bit kernel + +// ---- Dashboard stats (written to /tmp/flash-moe-stats.json) ---- +static volatile int g_serve_state = 0; // 0=idle, 1=prefilling, 2=generating +static volatile int g_prefill_total = 0; +static volatile int g_prefill_done = 0; +static volatile int g_gen_tokens = 0; +static volatile int g_gen_max = 0; +static volatile int g_think_tokens_serve = 0; +static volatile double g_tok_per_sec = 0; +static volatile double g_elapsed_ms = 0; +static volatile double g_ttft_ms = 0; +static volatile int g_total_requests = 0; +static double g_serve_start_time = 0; +static int g_serve_port = 0; +static int g_serve_K = 0; +static char g_current_request_id[64] = {0}; + +static void write_stats_file(void) { + double uptime = g_serve_start_time > 0 ? 
(now_ms() - g_serve_start_time) / 1000.0 : 0; + char tmp_path[] = "/tmp/flash-moe-stats.json.tmp.XXXXXX"; + int fd = mkstemp(tmp_path); + if (fd < 0) return; + char buf[2048]; + int n = snprintf(buf, sizeof(buf), + "{\"state\":\"%s\",\"request_id\":\"%s\"," + "\"prefill_tokens\":%d,\"prefill_done\":%d," + "\"gen_tokens\":%d,\"gen_max\":%d," + "\"tok_per_sec\":%.2f,\"elapsed_ms\":%.1f,\"ttft_ms\":%.1f," + "\"think_tokens\":%d,\"total_requests\":%d,\"uptime_s\":%.1f," + "\"model\":\"qwen3.5-397b-a17b\",\"quant\":\"%s\",\"k\":%d,\"port\":%d}\n", + g_serve_state == 0 ? "idle" : g_serve_state == 1 ? "prefilling" : "generating", + g_current_request_id, + g_prefill_total, g_prefill_done, + g_gen_tokens, g_gen_max, + g_tok_per_sec, g_elapsed_ms, g_ttft_ms, + g_think_tokens_serve, g_total_requests, uptime, + g_use_2bit ? "2-bit" : "4-bit", g_serve_K, g_serve_port); + write(fd, buf, n); + close(fd); + rename(tmp_path, "/tmp/flash-moe-stats.json"); +} static int g_cache_telemetry_enabled = 0; // enabled by --cache-telemetry flag static int g_think_budget = 2048; // max thinking tokens before force-emitting @@ -701,6 +742,44 @@ static void cpu_dequant_matvec( } } +// 2-bit dequant matvec: out[out_dim] = W * x[in_dim] +// W is stored as packed uint32 (16 x 2-bit values per uint32) +// scales/biases are bfloat16 per group +static void cpu_dequant_matvec_2bit( + const uint32_t *W, const uint16_t *scales, const uint16_t *biases, + const float *x, float *out, + int out_dim, int in_dim, int group_size +) { + int num_groups = in_dim / group_size; + int packed_per_group = group_size / 16; // 16 values per uint32 at 2-bit + int packed_cols = in_dim / 16; + + for (int row = 0; row < out_dim; row++) { + float acc = 0.0f; + const uint32_t *w_row = W + row * packed_cols; + const uint16_t *s_row = scales + row * num_groups; + const uint16_t *b_row = biases + row * num_groups; + + for (int g = 0; g < num_groups; g++) { + float scale = bf16_to_f32(s_row[g]); + float bias = 
bf16_to_f32(b_row[g]); + int base_packed = g * packed_per_group; + int base_x = g * group_size; + + for (int p = 0; p < packed_per_group; p++) { + uint32_t packed = w_row[base_packed + p]; + int x_base = base_x + p * 16; + + for (int n = 0; n < 16; n++) { + uint32_t val = (packed >> (n * 2)) & 0x3; + acc += ((float)val * scale + bias) * x[x_base + n]; + } + } + } + out[row] = acc; + } +} + // RMS normalization: out = x * w / rms(x) static void cpu_rms_norm(const float *x, const uint16_t *w_bf16, float *out, int dim, float eps) { float sum_sq = 0.0f; @@ -2699,8 +2778,9 @@ static void moe_forward( off_t expert_offset = (off_t)eidx * esz; if (g_metal && g_metal->buf_expert_data) { - // GPU path: pread directly into Metal buffer, run gate+up+swiglu+down on GPU + // GPU path: load expert into Metal buffer, run gate+up+swiglu+down on GPU void *expert_buf_ptr = [g_metal->buf_expert_data contents]; + ssize_t nread = pread(packed_fd, expert_buf_ptr, esz, expert_offset); if (nread != (ssize_t)esz) { fprintf(stderr, "WARNING: layer %d expert %d pread: %zd/%zu\n", @@ -3788,6 +3868,64 @@ static void discard_deferred_experts(void) { } } +// ============================================================================ +// ============================================================================ +// CPU expert forward: dequant matvec gate+up -> SwiGLU -> down, using cached +// expert data in CPU malloc. Supports both 4-bit and 2-bit quantization. 
+// Output is accumulated: output += weight * expert_output +// ============================================================================ + +static void cpu_expert_forward( + const void *expert_data, // expert weight bytes (EXPERT_SIZE or EXPERT_SIZE_2BIT) + const float *input, // [HIDDEN_DIM=4096] + float *output, // [HIDDEN_DIM=4096] accumulated (+=) + float weight, // routing weight + int use_2bit // 1 for 2-bit, 0 for 4-bit +) { + // Select offsets based on quantization + size_t gate_w_off, gate_s_off, gate_b_off; + size_t up_w_off, up_s_off, up_b_off; + size_t down_w_off, down_s_off, down_b_off; + if (use_2bit) { + gate_w_off = GATE_W_OFF_2; gate_s_off = GATE_S_OFF_2; gate_b_off = GATE_B_OFF_2; + up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2; + down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2; + } else { + gate_w_off = 0; gate_s_off = 2097152; gate_b_off = 2228224; + up_w_off = 2359296; up_s_off = 4456448; up_b_off = 4587520; + down_w_off = 4718592; down_s_off = 6815744; down_b_off = 6946816; + } + + uint32_t *gw = (uint32_t *)((const char *)expert_data + gate_w_off); + uint16_t *gs = (uint16_t *)((const char *)expert_data + gate_s_off); + uint16_t *gb = (uint16_t *)((const char *)expert_data + gate_b_off); + uint32_t *uw = (uint32_t *)((const char *)expert_data + up_w_off); + uint16_t *us = (uint16_t *)((const char *)expert_data + up_s_off); + uint16_t *ub = (uint16_t *)((const char *)expert_data + up_b_off); + uint32_t *dw = (uint32_t *)((const char *)expert_data + down_w_off); + uint16_t *ds = (uint16_t *)((const char *)expert_data + down_s_off); + uint16_t *db = (uint16_t *)((const char *)expert_data + down_b_off); + + float gate_out[MOE_INTERMEDIATE]; + float up_out[MOE_INTERMEDIATE]; + float act_out[MOE_INTERMEDIATE]; + float expert_out[HIDDEN_DIM]; + + if (use_2bit) { + cpu_dequant_matvec_2bit(gw, gs, gb, input, gate_out, MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE); + cpu_dequant_matvec_2bit(uw, us, 
ub, input, up_out, MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE); + cpu_swiglu(gate_out, up_out, act_out, MOE_INTERMEDIATE); + cpu_dequant_matvec_2bit(dw, ds, db, act_out, expert_out, HIDDEN_DIM, MOE_INTERMEDIATE, GROUP_SIZE); + } else { + cpu_dequant_matvec(gw, gs, gb, input, gate_out, MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE); + cpu_dequant_matvec(uw, us, ub, input, up_out, MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE); + cpu_swiglu(gate_out, up_out, act_out, MOE_INTERMEDIATE); + cpu_dequant_matvec(dw, ds, db, act_out, expert_out, HIDDEN_DIM, MOE_INTERMEDIATE, GROUP_SIZE); + } + + cpu_vec_madd(output, expert_out, weight, HIDDEN_DIM); +} + // ============================================================================ // Fused layer forward: GPU/CPU overlap + deferred expert pipeline // @@ -3951,6 +4089,7 @@ static void fused_layer_forward( float *residual = s_residual; id cmd1 = nil; int gpu_linear_attn = 0; // set to 1 if GPU handles entire linear attention pipeline + int merge_cmd12 = 0; // set to 1 to merge CMD1+CMD2 into single command buffer (prefill optimization) // Pre-compute linear_layer_idx for GPU linear attention encoding in CMD1 int linear_layer_idx = -1; @@ -4073,24 +4212,34 @@ static void fused_layer_forward( gpu_linear_attn = 1; } - [cmd1 commit]; + // When GPU handles entire linear attention pipeline, merge CMD1+CMD2 + // into one command buffer to eliminate a dispatch overhead (~0.35ms/layer). + // Safe because CMD1 writes batch_out[6] and CMD2 reads it — within one + // command buffer, encoders execute sequentially on the GPU. 
+ // Don't merge in fast path — deferred state from previous layer needs finalization + // and the CMD3→CMD1 pipeline overlap already provides the dispatch savings + merge_cmd12 = 0; + if (!merge_cmd12) { + [cmd1 commit]; - if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd1_submit += t1 - t0; } + if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd1_submit += t1 - t0; } - // Wait for CMD1 (implies CMD3(N-1) also done, since queue is serial) - if (g_timing_enabled) { t0 = now_ms(); } - [cmd1 waitUntilCompleted]; - if (!gpu_linear_attn) { - gpu_flush_batch_results(g_metal, attn_specs, num_attn_specs); - } - if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd1_wait += t1 - t0; } + // Wait for CMD1 (implies CMD3(N-1) also done, since queue is serial) + if (g_timing_enabled) { t0 = now_ms(); } + [cmd1 waitUntilCompleted]; + if (!gpu_linear_attn) { + gpu_flush_batch_results(g_metal, attn_specs, num_attn_specs); + } + if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd1_wait += t1 - t0; } - // Now CMD3(N-1) is done. Read back hidden state from GPU. - if (g_timing_enabled) { t0 = now_ms(); } - finalize_deferred_experts(); // reads buf_moe_hidden -> hidden - // Set up residual for CMD2 (residual = hidden before this layer's attention) - cpu_vec_copy(residual, hidden, HIDDEN_DIM); - if (g_timing_enabled) { t1 = now_ms(); g_timing.deferred_cpu += t1 - t0; } + // Now CMD3(N-1) is done. Read back hidden state from GPU. + if (g_timing_enabled) { t0 = now_ms(); } + finalize_deferred_experts(); // reads buf_moe_hidden -> hidden + // Set up residual for CMD2 (residual = hidden before this layer's attention) + cpu_vec_copy(residual, hidden, HIDDEN_DIM); + if (g_timing_enabled) { t1 = now_ms(); g_timing.deferred_cpu += t1 - t0; } + } + // If merge_cmd12: cmd1 stays uncommitted, residual copy handled below // No input_norm needed — CMD3 already computed it into buf_input. // normed is only needed if speculative routing is enabled (currently disabled). 
@@ -4210,7 +4359,18 @@ static void fused_layer_forward( gpu_linear_attn = 1; } - [cmd1 commit]; + // Merge CMD1+CMD2 for GPU linear attention layers: + // No CPU work needed between CMD1 and CMD2, so keep cmd1 + // uncommitted and reuse it as cmd_fused below. + if (gpu_linear_attn) { + merge_cmd12 = 1; + // Copy residual now before CMD2 encoding overwrites buf_input. + // In the slow path (no prev_gpu_combined), residual == hidden + // and g_deferred.active == 0, so finalize_deferred_experts() is a no-op. + memcpy([g_metal->buf_residual contents], hidden, HIDDEN_DIM * sizeof(float)); + } else { + [cmd1 commit]; + } } else { for (int i = 0; i < num_attn_specs; i++) { BatchMatvecSpec *s = &attn_specs[i]; @@ -4220,9 +4380,9 @@ static void fused_layer_forward( } if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd1_submit += t1 - t0; } - // Wait for CMD1 + // Wait for CMD1 (skip if merging — cmd1 stays uncommitted) if (g_timing_enabled) { t0 = now_ms(); } - if (cmd1) { + if (cmd1 && !merge_cmd12) { [cmd1 waitUntilCompleted]; if (!gpu_linear_attn) { gpu_flush_batch_results(g_metal, attn_specs, num_attn_specs); @@ -4684,11 +4844,19 @@ static void fused_layer_forward( } // gpu_linear_attn: batch_out[6] already has the result from CMD1 gated_rms_norm // Copy residual into GPU buffer for residual_add kernel - memcpy([g_metal->buf_residual contents], residual, HIDDEN_DIM * sizeof(float)); + if (merge_cmd12) { + // In merged mode, residual = hidden (the pre-attention hidden state). + // Fast path: copy now (slow path already copied in the merge block above, + // but the value is the same so this is a harmless idempotent copy). 
+ memcpy([g_metal->buf_residual contents], hidden, HIDDEN_DIM * sizeof(float)); + } else { + memcpy([g_metal->buf_residual contents], residual, HIDDEN_DIM * sizeof(float)); + } attn_out_for_oproj = NULL; - id cmd_fused = [g_metal->queue commandBuffer]; + // Reuse cmd1 if merging CMD1+CMD2, otherwise create new command buffer + id cmd_fused = merge_cmd12 ? cmd1 : [g_metal->queue commandBuffer]; // ---- GPU attention dispatches (only for full-attn layers with GPU path) ---- if (gpu_attn_fuse) { @@ -5535,6 +5703,396 @@ static void server_save_turn(const char *session_id, const char *role, const cha fclose(f); } +// Extract "stream" boolean from JSON body. Returns 1 for true (default), 0 for false. +static int extract_stream_flag(const char *buf) { + const char *p = strstr(buf, "\"stream\""); + if (!p) return 1; // default: streaming + p += 8; // skip "stream" + while (*p == ' ' || *p == '\t' || *p == ':' || *p == ' ') p++; + if (*p == 'f' || *p == '0') return 0; + return 1; +} + +// Unescape a JSON string value in-place. Handles \n, \t, \", \\, \uXXXX (ASCII subset). +// Returns length of unescaped string. +static int json_unescape_inplace(char *s, int len) { + char *r = s, *w = s; + char *end = s + len; + while (r < end) { + if (*r == '\\' && r + 1 < end) { + r++; + switch (*r) { + case 'n': *w++ = '\n'; r++; break; + case 't': *w++ = '\t'; r++; break; + case 'r': *w++ = '\r'; r++; break; + case '"': *w++ = '"'; r++; break; + case '\\': *w++ = '\\'; r++; break; + case '/': *w++ = '/'; r++; break; + case 'u': + // \uXXXX — decode hex, emit as UTF-8 if ASCII, else '?' 
+ if (r + 4 < end) { + unsigned val = 0; + int ok = 1; + for (int i = 1; i <= 4; i++) { + char c = r[i]; + val <<= 4; + if (c >= '0' && c <= '9') val |= c - '0'; + else if (c >= 'a' && c <= 'f') val |= c - 'a' + 10; + else if (c >= 'A' && c <= 'F') val |= c - 'A' + 10; + else { ok = 0; break; } + } + if (ok) { + r += 5; + if (val < 0x80) { + *w++ = (char)val; + } else if (val < 0x800) { + *w++ = (char)(0xC0 | (val >> 6)); + *w++ = (char)(0x80 | (val & 0x3F)); + } else { + *w++ = (char)(0xE0 | (val >> 12)); + *w++ = (char)(0x80 | ((val >> 6) & 0x3F)); + *w++ = (char)(0x80 | (val & 0x3F)); + } + } else { + *w++ = '\\'; *w++ = 'u'; + r++; + } + } else { + *w++ = '\\'; *w++ = *r++; + } + break; + default: *w++ = '\\'; *w++ = *r++; break; + } + } else { + *w++ = *r++; + } + } + *w = '\0'; + return (int)(w - s); +} + +// Find the end of a JSON string starting at the opening quote (p points AFTER the opening "). +// Returns pointer to the closing " (unescaped), or NULL. +static char *find_json_string_end(char *p) { + while (*p) { + if (*p == '\\') { p += 2; continue; } + if (*p == '"') return p; + p++; + } + return NULL; +} + +// Extract a JSON string value for key starting at p (p points to the first char after the key's ":"). +// Skips whitespace, expects opening quote. Writes start/len of the raw (still-escaped) value. +// Returns pointer past the closing quote, or NULL on failure. +static char *json_extract_string_after_colon(char *p, char **out_start, int *out_len) { + while (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r') p++; + if (*p != '"') return NULL; + p++; // skip opening quote + *out_start = p; + char *end = find_json_string_end(p); + if (!end) return NULL; + *out_len = (int)(end - p); + return end + 1; // past closing quote +} + +// Build a Qwen chat template prompt from the OpenAI messages array. +// Parses "messages": [...] from body. Each message has "role" and "content". 
+// If tools_section is non-NULL, it gets appended to the first system message content. +// Returns malloc'd string with the full prompt, or NULL on failure. +static char *extract_messages_to_prompt(const char *body, const char *tools_section) { + // Find "messages" array + const char *mp = strstr(body, "\"messages\""); + if (!mp) return NULL; + mp += 10; + while (*mp == ' ' || *mp == '\t' || *mp == ':' || *mp == '\n' || *mp == '\r') mp++; + if (*mp != '[') return NULL; + mp++; // skip [ + + // We'll build the prompt in a dynamic buffer + size_t buf_cap = 64 * 1024; + size_t buf_len = 0; + char *buf = malloc(buf_cap); + if (!buf) return NULL; + buf[0] = '\0'; + + #define PROMPT_APPEND(s, slen) do { \ + while (buf_len + (slen) + 1 > buf_cap) { \ + buf_cap *= 2; \ + buf = realloc(buf, buf_cap); \ + if (!buf) return NULL; \ + } \ + memcpy(buf + buf_len, (s), (slen)); \ + buf_len += (slen); \ + buf[buf_len] = '\0'; \ + } while(0) + + #define PROMPT_APPEND_STR(s) do { \ + size_t _l = strlen(s); \ + PROMPT_APPEND((s), _l); \ + } while(0) + + // Track whether we've seen a system message (for tools injection) + int seen_system = 0; + + // Parse each message object in the array + char *p = (char *)mp; + while (*p) { + // Skip whitespace and commas + while (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r' || *p == ',') p++; + if (*p == ']') break; // end of array + if (*p != '{') break; // not an object + p++; // skip { + + // Extract role and content from this message object + char *role_start = NULL; int role_len = 0; + char *content_start = NULL; int content_len = 0; + int brace_depth = 1; + char *obj_scan = p; + + // Scan through the object to find role and content fields + while (*obj_scan && brace_depth > 0) { + if (*obj_scan == '{') { brace_depth++; obj_scan++; continue; } + if (*obj_scan == '}') { brace_depth--; if (brace_depth == 0) break; obj_scan++; continue; } + if (*obj_scan == '"') { + // Check if this is "role" or "content" at top level (brace_depth == 1) + 
if (brace_depth == 1 && strncmp(obj_scan, "\"role\"", 6) == 0) { + obj_scan += 6; + while (*obj_scan == ' ' || *obj_scan == '\t' || *obj_scan == ':') obj_scan++; + obj_scan = json_extract_string_after_colon(obj_scan - 1, &role_start, &role_len); + if (!obj_scan) goto parse_fail; + // Back up: json_extract expects p after colon, we gave it one before + continue; + } + if (brace_depth == 1 && strncmp(obj_scan, "\"content\"", 9) == 0) { + obj_scan += 9; + while (*obj_scan == ' ' || *obj_scan == '\t' || *obj_scan == ':') obj_scan++; + // Content could be a string or null + if (*obj_scan == 'n') { + // null + obj_scan += 4; // skip "null" + content_start = NULL; + content_len = 0; + continue; + } + if (*obj_scan == '"') { + obj_scan++; // skip opening quote + content_start = obj_scan; + char *end = find_json_string_end(obj_scan); + if (!end) goto parse_fail; + content_len = (int)(end - obj_scan); + obj_scan = end + 1; + continue; + } + // Content could be an array (multi-part) — skip for now, take as empty + if (*obj_scan == '[') { + // Find the text part within content array + // Look for "text": "..." 
inside + int arr_depth = 1; + char *arr_p = obj_scan + 1; + while (*arr_p && arr_depth > 0) { + if (*arr_p == '[') arr_depth++; + else if (*arr_p == ']') { arr_depth--; if (arr_depth == 0) break; } + else if (*arr_p == '"' && strncmp(arr_p, "\"text\"", 6) == 0) { + arr_p += 6; + while (*arr_p == ' ' || *arr_p == ':' || *arr_p == '\t') arr_p++; + if (*arr_p == '"') { + arr_p++; + content_start = arr_p; + char *end = find_json_string_end(arr_p); + if (end) { + content_len = (int)(end - arr_p); + arr_p = end + 1; + } + break; + } + } + arr_p++; + } + obj_scan = arr_p; + while (*obj_scan && *obj_scan != ']') obj_scan++; + if (*obj_scan == ']') obj_scan++; + continue; + } + obj_scan++; + continue; + } + // Skip other string keys/values + obj_scan++; // skip opening quote + char *end = find_json_string_end(obj_scan); + if (end) obj_scan = end + 1; + else obj_scan++; + continue; + } + obj_scan++; + } + if (*obj_scan == '}') p = obj_scan + 1; + else p = obj_scan; + + // Build the chat template for this message + if (role_start && role_len > 0) { + // Make a null-terminated copy of role + char role[32] = {0}; + int rlen = role_len < 31 ? 
role_len : 31; + memcpy(role, role_start, rlen); + role[rlen] = '\0'; + + PROMPT_APPEND_STR("<|im_start|>"); + PROMPT_APPEND(role, strlen(role)); + PROMPT_APPEND_STR("\n"); + + if (content_start && content_len > 0) { + // Make a copy for unescaping + char *content_copy = malloc(content_len + 1); + memcpy(content_copy, content_start, content_len); + content_copy[content_len] = '\0'; + int unescaped_len = json_unescape_inplace(content_copy, content_len); + + PROMPT_APPEND(content_copy, unescaped_len); + + // Inject tools section after the first system message content + if (strcmp(role, "system") == 0 && !seen_system && tools_section) { + PROMPT_APPEND_STR(tools_section); + seen_system = 1; + } else if (strcmp(role, "system") == 0) { + seen_system = 1; + } + + free(content_copy); + } + + PROMPT_APPEND_STR("<|im_end|>\n"); + (void)0; // role tracking placeholder + } + continue; + + parse_fail: + free(buf); + return NULL; + } + + // If tools_section provided but no system message was seen, prepend a system message + if (tools_section && !seen_system) { + char *old_buf = buf; + size_t ts_len = strlen(tools_section); + size_t prefix_len = 13 + 7 + 1 + ts_len + 12 + 1; // <|im_start|>system\n + tools + <|im_end|>\n + size_t new_cap = prefix_len + buf_len + 1; + char *new_buf = malloc(new_cap); + if (!new_buf) { free(old_buf); return NULL; } + int off = snprintf(new_buf, new_cap, "<|im_start|>system\nYou are a helpful assistant.%s<|im_end|>\n", tools_section); + memcpy(new_buf + off, old_buf, buf_len + 1); + free(old_buf); + buf = new_buf; + buf_len = off + buf_len; + } + + // Always end with <|im_start|>assistant\n to prompt the model to respond + PROMPT_APPEND_STR("<|im_start|>assistant\n"); + + #undef PROMPT_APPEND + #undef PROMPT_APPEND_STR + + return buf; +} + +// Extract tool definitions from the "tools" array in the request body. +// Builds the Qwen-format tools section string for injection into the system prompt. 
+// Returns malloc'd string, or NULL if no tools array found. +static char *extract_tools_section(const char *body) { + const char *tp = strstr(body, "\"tools\""); + if (!tp) return NULL; + tp += 7; + while (*tp == ' ' || *tp == '\t' || *tp == ':' || *tp == '\n' || *tp == '\r') tp++; + if (*tp != '[') return NULL; + + // Find the matching closing bracket for the tools array + int depth = 1; + const char *arr_start = tp; // points to '[' + const char *scan = tp + 1; + while (*scan && depth > 0) { + if (*scan == '[' || *scan == '{') depth++; + else if (*scan == ']' || *scan == '}') depth--; + else if (*scan == '"') { + scan++; + while (*scan && !(*scan == '"' && *(scan-1) != '\\')) scan++; + } + scan++; + } + // scan now points past the closing ] + int arr_len = (int)(scan - arr_start); + + // Extract each tool object's raw JSON from within the array + // We need to find each {"type": "function", "function": {...}} object + // and emit it as-is inside tags + + size_t buf_cap = 4096; + size_t buf_len = 0; + char *buf = malloc(buf_cap); + if (!buf) return NULL; + + #define TOOLS_APPEND(s, slen) do { \ + while (buf_len + (slen) + 1 > buf_cap) { \ + buf_cap *= 2; \ + buf = realloc(buf, buf_cap); \ + if (!buf) return NULL; \ + } \ + memcpy(buf + buf_len, (s), (slen)); \ + buf_len += (slen); \ + buf[buf_len] = '\0'; \ + } while(0) + + #define TOOLS_APPEND_STR(s) do { size_t _l = strlen(s); TOOLS_APPEND((s), _l); } while(0) + + TOOLS_APPEND_STR( + "\n\n# Tools\n\n" + "You may call one or more functions to assist with the user query. 
" + "You are provided with function signatures within XML tags:\n" + "\n" + ); + + // Walk through the tools array and extract each tool object + char *wp = (char *)arr_start + 1; // skip [ + char *arr_end = (char *)arr_start + arr_len - 1; // points to ] + while (wp < arr_end) { + while (wp < arr_end && (*wp == ' ' || *wp == '\t' || *wp == '\n' || *wp == '\r' || *wp == ',')) wp++; + if (wp >= arr_end || *wp != '{') break; + + // Find the matching } for this tool object + char *obj_start = wp; + int od = 1; + wp++; + while (wp < arr_end && od > 0) { + if (*wp == '{') od++; + else if (*wp == '}') od--; + else if (*wp == '"') { + wp++; + while (wp < arr_end && !(*wp == '"' && *(wp-1) != '\\')) wp++; + } + wp++; + } + // wp points past the closing } + int obj_len = (int)(wp - obj_start); + + // Emit the raw JSON object (preserving original formatting) + TOOLS_APPEND(obj_start, obj_len); + TOOLS_APPEND_STR("\n"); + } + + TOOLS_APPEND_STR( + "\n\n" + "For each function call, return a json object with function name and arguments " + "within XML tags:\n" + "\n" + "{\"name\": \"function_name\", \"arguments\": {\"arg1\": \"value1\"}}\n" + "" + ); + + #undef TOOLS_APPEND + #undef TOOLS_APPEND_STR + + return buf; +} + // Extract "session_id" string from JSON body. Copies into out_buf (max out_size). // Returns 1 if found, 0 if missing. static int extract_session_id(const char *buf, char *out_buf, int out_size) { @@ -5602,6 +6160,162 @@ static void sse_send_done(int fd, const char *request_id) { http_write(fd, chunk, n); } +// Send SSE done with tool_calls finish reason (streaming mode) +static void sse_send_done_tool_calls(int fd, const char *request_id) { + char chunk[1024]; + int n = snprintf(chunk, sizeof(chunk), + "data: {\"id\":\"%s\",\"object\":\"chat.completion.chunk\"," + "\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n" + "data: [DONE]\n\n", + request_id); + http_write(fd, chunk, n); +} + +// JSON-escape a string into dst buffer. 
Returns number of bytes written. +static int json_escape_string(const char *src, int src_len, char *dst, int dst_cap) { + char *w = dst; + char *end = dst + dst_cap - 1; + for (int i = 0; i < src_len && w < end - 6; i++) { + switch (src[i]) { + case '"': *w++ = '\\'; *w++ = '"'; break; + case '\\': *w++ = '\\'; *w++ = '\\'; break; + case '\n': *w++ = '\\'; *w++ = 'n'; break; + case '\r': *w++ = '\\'; *w++ = 'r'; break; + case '\t': *w++ = '\\'; *w++ = 't'; break; + default: *w++ = src[i]; break; + } + } + *w = '\0'; + return (int)(w - dst); +} + +// Parse ... blocks from generated text. +// Extracts tool call name and arguments JSON for OpenAI-compatible responses. +// Returns number of tool calls found. Fills tool_calls_json (malloc'd) with JSON array content. +// Caller must free *tool_calls_json. +static int parse_tool_calls(const char *text, char **tool_calls_json) { + *tool_calls_json = NULL; + // Count occurrences + int count = 0; + const char *s = text; + while ((s = strstr(s, "")) != NULL) { count++; s += 11; } + if (count == 0) return 0; + + size_t buf_cap = 4096; + size_t buf_len = 0; + char *buf = malloc(buf_cap); + buf[0] = '\0'; + + #define TC_APPEND(str, slen) do { \ + while (buf_len + (slen) + 1 > buf_cap) { buf_cap *= 2; buf = realloc(buf, buf_cap); } \ + memcpy(buf + buf_len, (str), (slen)); buf_len += (slen); buf[buf_len] = '\0'; \ + } while(0) + #define TC_APPEND_STR(str) do { size_t _l = strlen(str); TC_APPEND(str, _l); } while(0) + + TC_APPEND_STR("["); + + s = text; + int idx = 0; + while ((s = strstr(s, "")) != NULL) { + s += 11; // skip + while (*s == '\n' || *s == ' ' || *s == '\r' || *s == '\t') s++; + const char *end = strstr(s, ""); + if (!end) break; + + // The content between tags should be a JSON object: {"name": "...", "arguments": {...}} + // Extract name and arguments + int block_len = (int)(end - s); + char *block = malloc(block_len + 1); + memcpy(block, s, block_len); + block[block_len] = '\0'; + + // Trim trailing whitespace + 
int bl = block_len; + while (bl > 0 && (block[bl-1] == '\n' || block[bl-1] == ' ' || block[bl-1] == '\r' || block[bl-1] == '\t')) bl--; + block[bl] = '\0'; + + // Extract "name" from the block + char *np = strstr(block, "\"name\""); + char *fn_name = NULL; int fn_name_len = 0; + if (np) { + np += 6; + while (*np == ' ' || *np == ':' || *np == '\t') np++; + if (*np == '"') { + np++; + fn_name = np; + char *ne = find_json_string_end(np); + if (ne) fn_name_len = (int)(ne - np); + } + } + + // Extract "arguments" — could be object or string + char *ap = strstr(block, "\"arguments\""); + char *args_start = NULL; int args_len = 0; + if (ap) { + ap += 11; + while (*ap == ' ' || *ap == ':' || *ap == '\t') ap++; + if (*ap == '{') { + // Find matching } + args_start = ap; + int d = 1; ap++; + while (*ap && d > 0) { + if (*ap == '{') d++; + else if (*ap == '}') d--; + else if (*ap == '"') { ap++; while (*ap && !(*ap == '"' && *(ap-1) != '\\')) ap++; } + ap++; + } + args_len = (int)(ap - args_start); + } else if (*ap == '"') { + // String arguments — pass through as-is + ap++; + args_start = ap - 1; // include the quote + char *ae = find_json_string_end(ap); + if (ae) args_len = (int)(ae - args_start + 1); + } + } + + if (idx > 0) TC_APPEND_STR(","); + + // Build: {"index":N,"id":"call_N","type":"function","function":{"name":"...","arguments":"..."}} + char tc_buf[8192]; + char escaped_args[4096]; + if (args_start && args_len > 0 && *args_start == '{') { + // Escape the args JSON as a string value + json_escape_string(args_start, args_len, escaped_args, sizeof(escaped_args)); + } else if (args_start && args_len > 0) { + memcpy(escaped_args, args_start, args_len < 4095 ? args_len : 4095); + escaped_args[args_len < 4095 ? args_len : 4095] = '\0'; + } else { + strcpy(escaped_args, "{}"); + } + + char name_buf[256] = {0}; + if (fn_name && fn_name_len > 0) { + int nl = fn_name_len < 255 ? 
fn_name_len : 255; + memcpy(name_buf, fn_name, nl); + name_buf[nl] = '\0'; + } + + int tc_len = snprintf(tc_buf, sizeof(tc_buf), + "{\"index\":%d,\"id\":\"call_%d\",\"type\":\"function\"," + "\"function\":{\"name\":\"%s\",\"arguments\":\"%s\"}}", + idx, idx, name_buf, escaped_args); + TC_APPEND(tc_buf, tc_len); + + free(block); + s = end + 12; // skip + idx++; + } + + TC_APPEND_STR("]"); + + #undef TC_APPEND + #undef TC_APPEND_STR + + *tool_calls_json = buf; + return idx; +} + static const char *SSE_HEADERS = "HTTP/1.1 200 OK\r\n" "Content-Type: text/event-stream\r\n" @@ -5760,11 +6474,19 @@ static void serve_loop( } printf("[serve] Listening on http://0.0.0.0:%d\n", port); - printf("[serve] Endpoints: POST /v1/chat/completions, GET /v1/models, GET /health\n"); + printf("[serve] Endpoints: POST /v1/chat/completions, GET /v1/models, GET /health, GET /stats\n"); fflush(stdout); static uint64_t req_counter = 0; + // Initialize dashboard stats + g_serve_start_time = now_ms(); + g_serve_port = port; + g_serve_K = K; + g_serve_state = 0; + g_total_requests = 0; + write_stats_file(); + // ---- System prompt cache: prefill system prompt once at startup ---- // Tokenize the system prompt and run it through all 60 layers. // Save the resulting KV cache + linear attention state as a snapshot. 
@@ -5895,6 +6617,20 @@ static void serve_loop( int session_pos = 0; // RoPE position after last generation for the active session for (;;) { + // Use select() with timeout so we can update stats file while idle + fd_set readfds; + FD_ZERO(&readfds); + FD_SET(server_fd, &readfds); + struct timeval tv = { .tv_sec = 5, .tv_usec = 0 }; + int sel = select(server_fd + 1, &readfds, NULL, NULL, &tv); + if (sel == 0) { + // Timeout — update stats file to keep dashboard alive + g_elapsed_ms = 0; + write_stats_file(); + continue; + } + if (sel < 0) { perror("select"); continue; } + struct sockaddr_in client_addr; socklen_t client_len = sizeof(client_addr); int client_fd = accept(server_fd, (struct sockaddr *)&client_addr, &client_len); @@ -5930,6 +6666,38 @@ static void serve_loop( continue; } + // GET /stats + if (strcmp(method, "GET") == 0 && strcmp(path, "/stats") == 0) { + char stats_json[2048]; + double uptime = g_serve_start_time > 0 ? (now_ms() - g_serve_start_time) / 1000.0 : 0; + int sn = snprintf(stats_json, sizeof(stats_json), + "{\"state\":\"%s\",\"request_id\":\"%s\"," + "\"prefill_tokens\":%d,\"prefill_done\":%d," + "\"gen_tokens\":%d,\"gen_max\":%d," + "\"tok_per_sec\":%.2f,\"elapsed_ms\":%.1f,\"ttft_ms\":%.1f," + "\"think_tokens\":%d,\"total_requests\":%d,\"uptime_s\":%.1f," + "\"model\":\"qwen3.5-397b-a17b\",\"quant\":\"%s\",\"k\":%d,\"port\":%d}\n", + g_serve_state == 0 ? "idle" : g_serve_state == 1 ? "prefilling" : "generating", + g_current_request_id, + g_prefill_total, g_prefill_done, + g_gen_tokens, g_gen_max, + g_tok_per_sec, g_elapsed_ms, g_ttft_ms, + g_think_tokens_serve, g_total_requests, uptime, + g_use_2bit ? 
"2-bit" : "4-bit", K, port); + char stats_hdr[512]; + int sh = snprintf(stats_hdr, sizeof(stats_hdr), + "HTTP/1.1 200 OK\r\n" + "Content-Type: application/json\r\n" + "Content-Length: %d\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Connection: close\r\n" + "\r\n", sn); + http_write(client_fd, stats_hdr, sh); + http_write(client_fd, stats_json, sn); + free(reqbuf); close(client_fd); + continue; + } + // GET /v1/models if (strcmp(method, "GET") == 0 && strcmp(path, "/v1/models") == 0) { const char *resp = @@ -5957,40 +6725,71 @@ static void serve_loop( } body += 4; - // Extract session_id and max_tokens BEFORE content extraction - // (extract_last_content mutates the body buffer in place) + // Extract parameters from body BEFORE any mutation int max_gen = extract_max_tokens(body, 8192); if (max_gen > 32768) max_gen = 32768; + int stream = extract_stream_flag(body); char req_session_id[64] = {0}; int has_session = extract_session_id(body, req_session_id, sizeof(req_session_id)); - // Extract user content from messages (mutates body — must be last) - char *content = extract_last_content(body); - if (!content || strlen(content) == 0) { - http_write_str(client_fd, - "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n" - "{\"error\":\"no content in messages\"}\n"); - free(reqbuf); close(client_fd); continue; - } - int is_continuation = (has_session && - active_session_id[0] != '\0' && - strcmp(req_session_id, active_session_id) == 0); + // Extract tools section (if tools array present in request) + char *tools_section = extract_tools_section(body); - // Session persistence is handled by the client (chat.m) + // Determine request mode: + // - If session_id is present: chat.m client, use session/continuation mechanism + // - If no session_id: OpenClaw/generic client, parse full messages array + int is_continuation = 0; + char *full_prompt = NULL; + char *content = NULL; + + if (has_session) { + // ---- chat.m client path: extract last content, use session continuations 
---- + is_continuation = (active_session_id[0] != '\0' && + strcmp(req_session_id, active_session_id) == 0); + content = extract_last_content(body); + if (!content || strlen(content) == 0) { + http_write_str(client_fd, + "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n" + "{\"error\":\"no content in messages\"}\n"); + if (tools_section) free(tools_section); + free(reqbuf); close(client_fd); continue; + } + } else { + // ---- OpenClaw / generic client path: parse full messages array ---- + full_prompt = extract_messages_to_prompt(body, tools_section); + if (!full_prompt || strlen(full_prompt) == 0) { + // Fallback: try extract_last_content for simple requests + content = extract_last_content(body); + if (!content || strlen(content) == 0) { + http_write_str(client_fd, + "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n" + "{\"error\":\"no content in messages\"}\n"); + if (full_prompt) free(full_prompt); + if (tools_section) free(tools_section); + free(reqbuf); close(client_fd); continue; + } + } + } char request_id[64]; snprintf(request_id, sizeof(request_id), "chatcmpl-%llu", ++req_counter); - fprintf(stderr, "[serve] %s content=%zu chars, max_tokens=%d, session=%s%s\n", - request_id, strlen(content), max_gen, - has_session ? req_session_id : "(none)", - is_continuation ? " [CONTINUE]" : " [NEW]"); + if (full_prompt) { + fprintf(stderr, "[serve] %s full_prompt=%zu chars, max_tokens=%d, stream=%d [FULL_MESSAGES]\n", + request_id, strlen(full_prompt), max_gen, stream); + } else { + fprintf(stderr, "[serve] %s content=%zu chars, max_tokens=%d, session=%s%s\n", + request_id, strlen(content), max_gen, + has_session ? req_session_id : "(none)", + is_continuation ? 
" [CONTINUE]" : " [NEW]"); + } // ---- Tokenize ---- - // Continuation: prefix with <|im_end|>\n to close prior assistant turn - // New session: just the user turn (system prompt restored from snapshot) PromptTokens *pt; - if (is_continuation) { + if (full_prompt) { + // Full messages mode: tokenize the entire assembled prompt + pt = encode_prompt_text_to_tokens(full_prompt); + } else if (is_continuation) { pt = tokenize_continuation_turn(content); } else { pt = tokenize_user_turn(content); @@ -5999,29 +6798,56 @@ static void serve_loop( http_write_str(client_fd, "HTTP/1.1 500 Internal Server Error\r\nConnection: close\r\n\r\n" "{\"error\":\"tokenization failed\"}\n"); + if (full_prompt) free(full_prompt); + if (tools_section) free(tools_section); free(reqbuf); close(client_fd); continue; } fprintf(stderr, "[serve] %s prompt=%d tokens%s\n", request_id, pt->count, + full_prompt ? " (full messages — reset all state)" : is_continuation ? " (continuation — skipping snapshot restore)" : ""); int pos; - if (is_continuation) { + if (full_prompt) { + // ---- Full messages mode: reset ALL state to zero ---- + // OpenClaw sends the entire conversation history, so we prefill from scratch. + // No snapshot restore — the client provides its own system prompt. 
+ for (int i = 0; i < NUM_LAYERS; i++) { + if (kv_caches[i]) { + // Zero out KV caches completely + kv_caches[i]->len = 0; + // Also zero GPU KV mirror + if (g_metal) { + int fa_idx = (i + 1) / FULL_ATTN_INTERVAL - 1; + if (fa_idx >= 0 && fa_idx < NUM_FULL_ATTN_LAYERS) { + size_t kv_sz = GPU_KV_SEQ * kv_dim * sizeof(float); + memset([g_metal->buf_kv_k[fa_idx] contents], 0, kv_sz); + memset([g_metal->buf_kv_v[fa_idx] contents], 0, kv_sz); + } + } + } + if (layer_states[i]) { + LinearAttnState *s = (LinearAttnState *)layer_states[i]; + memset(s->conv_state, 0, conv_state_size); + memset(s->ssm_state, 0, ssm_state_size); + } + } + // Zero GPU delta-net state + reset_delta_net_state(); + pos = 0; // start from position 0 — no cached system prompt + // Invalidate any active session + active_session_id[0] = '\0'; + } else if (is_continuation) { // ---- Continue from existing session state ---- - // The KV caches + linear attention state already contain the full - // conversation history. Just set pos to where we left off. pos = session_pos; } else { // ---- Restore state from system prompt snapshot ---- - // Instead of resetting to zero, restore to the cached system prompt state. - // This skips re-prefilling the system prompt tokens (~20 tokens, ~6s saved). 
for (int i = 0; i < NUM_LAYERS; i++) { if (kv_caches[i] && kv_snapshots[i].k_snapshot) { size_t sz = sys_prompt_len * kv_dim * sizeof(float); memcpy(kv_caches[i]->k_cache, kv_snapshots[i].k_snapshot, sz); memcpy(kv_caches[i]->v_cache, kv_snapshots[i].v_snapshot, sz); kv_caches[i]->len = kv_snapshots[i].len; - // Also restore GPU KV mirror if (g_metal) { int fa_idx = (i + 1) / FULL_ATTN_INTERVAL - 1; if (fa_idx >= 0 && fa_idx < NUM_FULL_ATTN_LAYERS) { @@ -6044,7 +6870,6 @@ static void serve_loop( memset(s->ssm_state, 0, ssm_state_size); } } - // Restore GPU delta-net state if (g_metal && g_metal->delta_net_step) { for (int i = 0; i < NUM_LINEAR_LAYERS; i++) { if (gpu_delta_snapshots[i] && g_metal->buf_delta_state[i]) @@ -6057,8 +6882,7 @@ static void serve_loop( } else { reset_delta_net_state(); } - pos = sys_prompt_len; // start after cached system prompt - // Update active session + pos = sys_prompt_len; if (has_session) { strncpy(active_session_id, req_session_id, sizeof(active_session_id) - 1); active_session_id[sizeof(active_session_id) - 1] = '\0'; @@ -6068,12 +6892,29 @@ static void serve_loop( } if (g_cache_telemetry_enabled) cache_telemetry_reset(); - // ---- Send SSE headers ---- - http_write_str(client_fd, SSE_HEADERS); + // ---- Send response headers ---- + if (stream) { + http_write_str(client_fd, SSE_HEADERS); + } - // ---- Batch prefill ---- + // ---- Update dashboard stats for this request ---- + g_total_requests++; + g_serve_state = 1; // prefilling + g_prefill_total = pt->count; + g_prefill_done = 0; + g_gen_tokens = 0; + g_gen_max = max_gen; + g_tok_per_sec = 0; + g_elapsed_ms = 0; + g_ttft_ms = 0; + g_think_tokens_serve = 0; + strncpy(g_current_request_id, request_id, sizeof(g_current_request_id) - 1); + g_current_request_id[sizeof(g_current_request_id) - 1] = '\0'; + write_stats_file(); + + // ---- Prefill: token-at-a-time with deferred GPU pipeline ---- double t_prefill = now_ms(); - // Pre-embed all request tokens + // Pre-embed all tokens 
float *serve_embed_batch = NULL; if (pt->count > 1) { serve_embed_batch = malloc((size_t)pt->count * HIDDEN_DIM * sizeof(float)); @@ -6081,9 +6922,9 @@ static void serve_loop( embed_lookup(wf, pt->ids[i], serve_embed_batch + (size_t)i * HIDDEN_DIM); } } - // Intermediate prefill tokens: discard last-layer expert output - for (int i = 0; i < pt->count - 1; i++) { - cache_telemetry_note_token(); + // Prefill tokens one at a time through all 60 layers + // (preserves CMD3→CMD1 deferred GPU pipeline overlap for ~15 tok/s) + for (int i = 0; i < pt->count; i++) { if (serve_embed_batch) { memcpy(hidden, serve_embed_batch + (size_t)i * HIDDEN_DIM, HIDDEN_DIM * sizeof(float)); @@ -6099,35 +6940,31 @@ static void serve_loop( layer_mmaps[layer] != MAP_FAILED ? layer_mmaps[layer] : NULL, K, layer_fds[layer]); } - discard_deferred_experts(); - pos++; - } - // Last prefill token: full completion (need hidden for logits) - { - cache_telemetry_note_token(); - if (serve_embed_batch) { - memcpy(hidden, serve_embed_batch + (size_t)(pt->count - 1) * HIDDEN_DIM, - HIDDEN_DIM * sizeof(float)); + if (i < pt->count - 1) { + discard_deferred_experts(); } else { - embed_lookup(wf, pt->ids[0], hidden); - } - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); - fused_layer_forward(wf, layer, hidden, - is_full ? kv_caches[layer] : NULL, - is_full ? NULL : layer_states[layer], - pos, - layer_mmaps[layer] != MAP_FAILED ? 
layer_mmaps[layer] : NULL, - K, layer_fds[layer]); + complete_deferred_experts(); } - complete_deferred_experts(); pos++; + + // Update dashboard stats + g_prefill_done = i + 1; + if ((i + 1) % 10 == 0 || i == pt->count - 1) { + g_elapsed_ms = now_ms() - t_prefill; + write_stats_file(); + } } if (serve_embed_batch) { free(serve_embed_batch); serve_embed_batch = NULL; } double prefill_ms = now_ms() - t_prefill; fprintf(stderr, "[serve] %s prefill=%d tokens in %.0fms\n", request_id, pt->count, prefill_ms); + // Update dashboard: prefill complete, switch to generating + g_prefill_done = pt->count; + g_ttft_ms = prefill_ms; + g_serve_state = 2; // generating + write_stats_file(); + // ---- Final norm + LM head for first token ---- if (final_norm_w) { float *normed = malloc(HIDDEN_DIM * sizeof(float)); @@ -6138,13 +6975,14 @@ static void serve_loop( lm_head_forward(wf, hidden, logits); int next_token = cpu_argmax(logits, VOCAB_SIZE); - // ---- Auto-regressive generation with SSE streaming ---- + // ---- Auto-regressive generation ---- double t_gen = now_ms(); int gen_count = 0; int in_think = 0; int think_tokens = 0; - // Accumulate response for session persistence - char *gen_response = calloc(1, 256 * 1024); + // Accumulate full response (needed for non-streaming + tool call detection) + size_t gen_resp_cap = 256 * 1024; + char *gen_response = calloc(1, gen_resp_cap); int gen_resp_len = 0; for (int gen = 0; gen < max_gen; gen++) { @@ -6172,25 +7010,43 @@ static void serve_loop( if (in_think) { think_tokens++; if (g_think_budget > 0 && think_tokens >= g_think_budget) { - next_token = THINK_END_TOKEN; // force end thinking + next_token = THINK_END_TOKEN; in_think = 0; } } const char *tok_str = decode_token(vocab, next_token); - // Accumulate non-thinking response for session persistence - if (!in_think && tok_str && gen_resp_len + (int)strlen(tok_str) < 256*1024 - 1) { + // Accumulate response text (always, for tool call detection) + if (tok_str) { int tlen = 
(int)strlen(tok_str); + // Grow buffer if needed + while (gen_resp_len + tlen + 1 > (int)gen_resp_cap) { + gen_resp_cap *= 2; + gen_response = realloc(gen_response, gen_resp_cap); + } memcpy(gen_response + gen_resp_len, tok_str, tlen); gen_resp_len += tlen; gen_response[gen_resp_len] = 0; } - if (sse_send_delta(client_fd, request_id, tok_str) < 0) { - fprintf(stderr, "[serve] %s client disconnected, stopping generation\n", request_id); - break; + // Stream token if in streaming mode + if (stream) { + if (sse_send_delta(client_fd, request_id, tok_str) < 0) { + fprintf(stderr, "[serve] %s client disconnected, stopping generation\n", request_id); + break; + } } gen_count++; + // Update dashboard stats + g_gen_tokens = gen_count; + g_think_tokens_serve = think_tokens; + g_elapsed_ms = now_ms() - t_prefill; + if (gen_count > 0) { + double gen_elapsed = now_ms() - t_gen; + g_tok_per_sec = gen_count * 1000.0 / gen_elapsed; + } + write_stats_file(); + // Generate next cache_telemetry_note_token(); embed_lookup(wf, next_token, hidden); @@ -6216,12 +7072,74 @@ static void serve_loop( next_token = cpu_argmax(logits, VOCAB_SIZE); } - sse_send_done(client_fd, request_id); + // ---- Detect tool calls in generated output ---- + char *tool_calls_json = NULL; + int num_tool_calls = parse_tool_calls(gen_response, &tool_calls_json); + const char *finish_reason = num_tool_calls > 0 ? 
"tool_calls" : "stop"; + + if (stream) { + // ---- Streaming response: send finish event ---- + if (num_tool_calls > 0) { + sse_send_done_tool_calls(client_fd, request_id); + } else { + sse_send_done(client_fd, request_id); + } + } else { + // ---- Non-streaming response: send complete JSON ---- + // JSON-escape the full response content + size_t esc_cap = (size_t)gen_resp_len * 2 + 64; + char *escaped_content = malloc(esc_cap); + int esc_len = json_escape_string(gen_response, gen_resp_len, escaped_content, (int)esc_cap); + + // Build the response JSON + size_t resp_cap = esc_len + 1024; + if (tool_calls_json) resp_cap += strlen(tool_calls_json) + 256; + char *resp_body = malloc(resp_cap); + + int resp_len; + if (num_tool_calls > 0) { + resp_len = snprintf(resp_body, resp_cap, + "{\"id\":\"%s\",\"object\":\"chat.completion\"," + "\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\"," + "\"content\":\"%s\",\"tool_calls\":%s}," + "\"finish_reason\":\"tool_calls\"}]," + "\"usage\":{\"prompt_tokens\":%d,\"completion_tokens\":%d,\"total_tokens\":%d}}", + request_id, escaped_content, tool_calls_json, + pt->count, gen_count, pt->count + gen_count); + } else { + resp_len = snprintf(resp_body, resp_cap, + "{\"id\":\"%s\",\"object\":\"chat.completion\"," + "\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\"," + "\"content\":\"%s\"},\"finish_reason\":\"stop\"}]," + "\"usage\":{\"prompt_tokens\":%d,\"completion_tokens\":%d,\"total_tokens\":%d}}", + request_id, escaped_content, + pt->count, gen_count, pt->count + gen_count); + } + + // Send HTTP response with proper headers + char hdr[512]; + int hdr_len = snprintf(hdr, sizeof(hdr), + "HTTP/1.1 200 OK\r\n" + "Content-Type: application/json\r\n" + "Content-Length: %d\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Connection: close\r\n" + "\r\n", resp_len); + http_write(client_fd, hdr, hdr_len); + http_write(client_fd, resp_body, resp_len); + + free(escaped_content); + free(resp_body); + } + + if 
(tool_calls_json) free(tool_calls_json); + + fprintf(stderr, "[serve] %s finish_reason=%s%s\n", + request_id, finish_reason, + num_tool_calls > 0 ? "" : ""); // ---- Save session state ---- free(gen_response); - // The KV caches + linear attention state already contain this conversation. - // Just record the position so the next request can continue from here. session_pos = pos; fprintf(stderr, "[serve] %s session_pos=%d (session=%s)\n", request_id, session_pos, @@ -6237,6 +7155,13 @@ static void serve_loop( cache_telemetry_print(g_malloc_cache->hits, g_malloc_cache->misses); } + // Update dashboard: back to idle + g_serve_state = 0; + g_elapsed_ms = now_ms() - t_prefill; + write_stats_file(); + + if (full_prompt) free(full_prompt); + if (tools_section) free(tools_section); free(pt->ids); free(pt); free(reqbuf); @@ -6586,78 +7511,65 @@ int main(int argc, char **argv) { printf("--- Generating %d tokens ---\n", max_tokens); int pos = 0; // position counter for RoPE - // ---- Batch prefill: pre-embed all prompt tokens ---- - // Embedding all tokens upfront into a batch buffer avoids interleaving - // embed_lookup with GPU work, and enables the optimized prefill loop below. - float *embed_batch = NULL; + // ---- Batch prefill (batched projections + per-layer expert caching) ---- if (pt->count > 1) { - embed_batch = malloc((size_t)pt->count * HIDDEN_DIM * sizeof(float)); + float *embed_batch = malloc((size_t)pt->count * HIDDEN_DIM * sizeof(float)); double t_embed = now_ms(); for (int i = 0; i < pt->count; i++) { embed_lookup(wf, pt->ids[i], embed_batch + (size_t)i * HIDDEN_DIM); } double embed_ms = now_ms() - t_embed; printf(" [prefill] batch embed %d tokens: %.1f ms\n", pt->count, embed_ms); - } - - // ---- Batch prefill loop ---- - // Process all prompt tokens through the model. For intermediate tokens - // (not the last), we use discard_deferred_experts() which waits for the GPU - // but skips the CPU readback/combine of the last layer's expert outputs. 
- // This is safe because the hidden state from intermediate prefill tokens - // is immediately overwritten by the next token's embedding — the recurrent - // state (KV cache, delta-net state) is already updated inside fused_layer_forward. - if (pt->count > 1) { - double t_prefill_batch = now_ms(); - double first_tok_ms = 0; - for (int token_idx = 0; token_idx < pt->count - 1; token_idx++) { - double t_tok = now_ms(); + // Ensure no deferred state is active + if (g_deferred.active) { + wait_deferred_experts_gpu(); + g_deferred.active = 0; + g_deferred.gpu_combined = 0; + g_deferred.cmd_experts = nil; + } - // Load pre-embedded token from batch buffer - cache_telemetry_note_token(); - memcpy(hidden, embed_batch + (size_t)token_idx * HIDDEN_DIM, + double t_prefill_batch = now_ms(); + int N = pt->count; + + // Token-at-a-time prefill with deferred GPU pipeline. + // Each token passes through all 60 layers, benefiting from: + // - CMD3->CMD1 deferred overlap (GPU experts run while next layer starts) + // - CMD1+CMD2 merge for linear attention layers (eliminates one dispatch) + for (int i = 0; i < N; i++) { + memcpy(hidden, embed_batch + (size_t)i * HIDDEN_DIM, HIDDEN_DIM * sizeof(float)); - // Run through all 60 transformer layers for (int layer = 0; layer < NUM_LAYERS; layer++) { int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? NULL : layer_states[layer], - pos, + pos + i, layer_mmaps[layer] != MAP_FAILED ? layer_mmaps[layer] : NULL, K, layer_fds[layer]); } - - // Discard last layer's expert output — hidden will be overwritten - // by the next token's embedding. Only wait for GPU (buffer safety). 
- discard_deferred_experts(); - pos++; - - if (token_idx == 0) { - first_tok_ms = now_ms() - t_tok; + if (i < N - 1) { + discard_deferred_experts(); + } else { + complete_deferred_experts(); } - } - - double prefill_batch_ms = now_ms() - t_prefill_batch; - double avg_ms = (pt->count > 2) ? - (prefill_batch_ms - first_tok_ms) / (pt->count - 2) : first_tok_ms; - printf(" [prefill] %d/%d tokens: %.0f ms (first: %.0f ms, rest avg: %.0f ms)\n", - pt->count - 1, pt->count, prefill_batch_ms, first_tok_ms, avg_ms); - } - // ---- Last prefill token (or single-token prompt) ---- - // This one needs full completion since we need hidden state for logits. - { - cache_telemetry_note_token(); - if (embed_batch) { - memcpy(hidden, embed_batch + (size_t)(pt->count - 1) * HIDDEN_DIM, + memcpy(embed_batch + (size_t)i * HIDDEN_DIM, hidden, HIDDEN_DIM * sizeof(float)); - } else { - embed_lookup(wf, pt->ids[0], hidden); } + memcpy(hidden, embed_batch + (size_t)(N - 1) * HIDDEN_DIM, + HIDDEN_DIM * sizeof(float)); + pos += N; + + double prefill_batch_ms = now_ms() - t_prefill_batch; + printf(" [prefill] %d tokens: %.0f ms (%.1f ms/tok)\n", + N, prefill_batch_ms, prefill_batch_ms / N); + free(embed_batch); + } else { + // ---- Single-token prompt ---- + embed_lookup(wf, pt->ids[0], hidden); for (int layer = 0; layer < NUM_LAYERS; layer++) { int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); fused_layer_forward(wf, layer, hidden, @@ -6667,13 +7579,10 @@ int main(int argc, char **argv) { layer_mmaps[layer] != MAP_FAILED ? layer_mmaps[layer] : NULL, K, layer_fds[layer]); } - // Full completion — need hidden state for final norm + lm_head complete_deferred_experts(); pos++; } - if (embed_batch) { free(embed_batch); embed_batch = NULL; } - // ---- Final norm ---- if (final_norm_w) { float *normed = malloc(HIDDEN_DIM * sizeof(float));