diff --git a/.gitignore b/.gitignore
index cf2c548b..d2e34910 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,4 @@
 *.dSYM/
 /misc/
 .*.swp
+.DS_Store
diff --git a/Makefile b/Makefile
index 5f975636..3e933d54 100644
--- a/Makefile
+++ b/Makefile
@@ -9,18 +9,18 @@ METAL_SRCS := $(wildcard metal/*.metal)
 
 ifeq ($(UNAME_S),Darwin)
 METAL_LDLIBS := $(LDLIBS) -framework Foundation -framework Metal
-CORE_OBJS = ds4.o ds4_metal.o
-NATIVE_CORE_OBJS = ds4_native.o
+CORE_OBJS = ds4.o ds4_metal.o ds4_rpc.o
+NATIVE_CORE_OBJS = ds4_native.o ds4_rpc.o
 else
 CFLAGS += -DDS4_NO_METAL
-CORE_OBJS = ds4.o
-NATIVE_CORE_OBJS = ds4_native.o
+CORE_OBJS = ds4.o ds4_rpc.o
+NATIVE_CORE_OBJS = ds4_native.o ds4_rpc.o
 METAL_LDLIBS := $(LDLIBS)
 endif
 
 .PHONY: all clean test
 
-all: ds4 ds4-server
+all: ds4 ds4-server ds4-rpc-worker
 
 ifeq ($(UNAME_S),Darwin)
 ds4: ds4_cli.o linenoise.o $(CORE_OBJS)
@@ -29,6 +29,9 @@ ds4: ds4_cli.o linenoise.o $(CORE_OBJS)
 ds4-server: ds4_server.o rax.o $(CORE_OBJS)
 	$(CC) $(CFLAGS) -o $@ ds4_server.o rax.o $(CORE_OBJS) $(METAL_LDLIBS)
 
+ds4-rpc-worker: ds4_rpc_worker.o $(CORE_OBJS)
+	$(CC) $(CFLAGS) -o $@ ds4_rpc_worker.o $(CORE_OBJS) $(METAL_LDLIBS)
+
 ds4_native: ds4_cli_native.o linenoise.o $(NATIVE_CORE_OBJS)
 	$(CC) $(CFLAGS) -o $@ ds4_cli_native.o linenoise.o $(NATIVE_CORE_OBJS) $(NATIVE_LDLIBS)
 else
@@ -38,6 +41,9 @@ ds4: ds4_cli.o linenoise.o $(CORE_OBJS)
 ds4-server: ds4_server.o rax.o $(CORE_OBJS)
 	$(CC) $(CFLAGS) -o $@ $^ $(LDLIBS)
 
+ds4-rpc-worker: ds4_rpc_worker.o $(CORE_OBJS)
+	$(CC) $(CFLAGS) -o $@ $^ $(LDLIBS)
+
 ds4_native: ds4_cli_native.o linenoise.o $(NATIVE_CORE_OBJS)
 	$(CC) $(CFLAGS) -o $@ ds4_cli_native.o linenoise.o $(NATIVE_CORE_OBJS) $(LDLIBS)
 endif
@@ -51,6 +57,12 @@ ds4_cli.o: ds4_cli.c ds4.h linenoise.h
 ds4_server.o: ds4_server.c ds4.h rax.h
 	$(CC) $(CFLAGS) -c -o $@ ds4_server.c
 
+ds4_rpc.o: ds4_rpc.c ds4_rpc.h ds4.h
+	$(CC) $(CFLAGS) -c -o $@ ds4_rpc.c
+
+ds4_rpc_worker.o: ds4_rpc_worker.c ds4_rpc.h ds4.h
+	$(CC) $(CFLAGS) -c -o $@ ds4_rpc_worker.c
+
 ds4_test.o: tests/ds4_test.c ds4_server.c ds4.h rax.h
 	$(CC) $(CFLAGS) -Wno-unused-function -c -o $@ tests/ds4_test.c
 
@@ -76,4 +88,4 @@ test: ds4_test
 	./ds4_test
 
 clean:
-	rm -f ds4 ds4-server ds4_native ds4_server_test ds4_test *.o
+	rm -f ds4 ds4-server ds4-rpc-worker ds4_native ds4_server_test ds4_test *.o
diff --git a/bench-ds4.py b/bench-ds4.py
new file mode 100644
index 00000000..8d6bb34b
--- /dev/null
+++ b/bench-ds4.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+"""Measure ds4-server prefill and decode rates.
+
+Streams a response from the OpenAI-compatible chat completions API to
+separate prefill (time-to-first-token) from decode (steady-state
+generation), so the same methodology covers single-host Q2, single-host Q4
+(on a >=256 GB machine), and pipeline-parallel Q4 over RPC. Whatever
+ds4-server is running, this measures what it actually produces over the
+wire -- no special hooks.
+
+Usage:
+    python3 bench-ds4.py [--url URL] [--long PATH] [--max-tokens N]
+                         [--thinking] [--runs N]
+
+Examples:
+    # Default: localhost:8000, short prompt only, thinking disabled.
+    python3 bench-ds4.py
+
+    # With the long-context test the README uses (11709 tokens):
+    python3 bench-ds4.py --long tests/long_context_security_prompt.txt
+
+    # Average over 3 runs (cold first run is reported separately):
+    python3 bench-ds4.py --runs 3 --long tests/long_context_security_prompt.txt
+
+The same prompt is reused across runs in --runs mode so the rendered-prefix
+cache absorbs setup overhead; the first run is reported as "cold" and the
+remaining runs are averaged as "warm".
+"""
+
+import argparse
+import json
+import sys
+import time
+import urllib.request
+
+
+def stream_request(url: str, prompt: str, max_tokens: int, thinking_enabled: bool):
+    """One streamed chat completion. Returns timing + token counts."""
+    body = {
+        "model": "deepseek-v4-flash",
+        "messages": [{"role": "user", "content": prompt}],
+        "max_tokens": max_tokens,
+        "temperature": 0,
+        "stream": True,
+        "stream_options": {"include_usage": True},
+    }
+    if not thinking_enabled:
+        body["thinking"] = {"type": "disabled"}
+
+    req = urllib.request.Request(
+        url,
+        data=json.dumps(body).encode(),
+        headers={"Content-Type": "application/json"},
+        method="POST",
+    )
+
+    t_start = time.perf_counter()
+    t_first = None
+    t_last = None
+    prompt_tokens = 0
+    completion_tokens = 0
+    reasoning_tokens = 0
+
+    with urllib.request.urlopen(req, timeout=3600) as resp:
+        for raw in resp:
+            line = raw.decode("utf-8", errors="replace").rstrip()
+            if not line.startswith("data: "):
+                continue
+            payload = line[6:]
+            if payload == "[DONE]":
+                continue
+            try:
+                obj = json.loads(payload)
+            except json.JSONDecodeError:
+                continue
+
+            choices = obj.get("choices") or []
+            if choices:
+                delta = choices[0].get("delta") or {}
+                if delta.get("content") is not None or delta.get("reasoning_content") is not None:
+                    now = time.perf_counter()
+                    if t_first is None:
+                        t_first = now
+                    t_last = now
+                    if delta.get("reasoning_content") is not None:
+                        reasoning_tokens += 1
+
+            usage = obj.get("usage")
+            if usage:
+                prompt_tokens = usage.get("prompt_tokens", prompt_tokens)
+                completion_tokens = usage.get("completion_tokens", completion_tokens)
+
+    t_end = time.perf_counter()
+    return {
+        "t_total": t_end - t_start,
+        "t_prefill": (t_first - t_start) if t_first is not None else None,
+        "t_decode": (t_last - t_first) if (t_first is not None and t_last is not None) else None,
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "reasoning_tokens": reasoning_tokens,
+    }
+
+
+def fmt_rate(tokens, seconds):
+    if seconds is None or seconds <= 0 or tokens <= 0:
+        return "n/a"
+    return f"{tokens / seconds:.1f} t/s"
+
+
+def print_run(label, r):
+    print(f"[{label}]")
+    print(f"  prompt     : {r['prompt_tokens']} tokens")
+    print(f"  completion : {r['completion_tokens']} tokens"
+          + (f" ({r['reasoning_tokens']} reasoning)" if r['reasoning_tokens'] else ""))
+    if r["t_prefill"] is not None:
+        print(f"  prefill    : {r['t_prefill']:6.2f} s "
+              f"({fmt_rate(r['prompt_tokens'], r['t_prefill'])})")
+    else:
+        print("  prefill    : n/a (no streamed token observed)")
+    if r["t_decode"] is not None and r["completion_tokens"] > 1:
+        # Decode rate excludes the first token (which is part of prefill latency).
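+        # e.g. completion_tokens = 256 with t_decode = 25.5 s gives
+        # (256 - 1) / 25.5 = 10.0 t/s.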
+        print(f"  decode     : {r['t_decode']:6.2f} s "
+              f"({fmt_rate(r['completion_tokens'] - 1, r['t_decode'])})")
+    else:
+        print("  decode     : n/a")
+    print(f"  wall total : {r['t_total']:6.2f} s")
+    print()
+
+
+def main():
+    ap = argparse.ArgumentParser(description=__doc__,
+                                 formatter_class=argparse.RawDescriptionHelpFormatter)
+    ap.add_argument("--url", default="http://127.0.0.1:8000/v1/chat/completions")
+    ap.add_argument("--short", default=(
+        "Scrivi una breve storia su un gatto che impara a programmare in C. "
+        "Mantieni la storia entro tre paragrafi."
+    ), help="Short-prompt text (default: small Italian story prompt).")
+    ap.add_argument("--long", help="Path to a long-context prompt file.")
+    ap.add_argument("--max-tokens", type=int, default=256)
+    ap.add_argument("--thinking", action="store_true",
+                    help="Leave thinking mode enabled (default: disabled for "
+                         "predictable decode rate).")
+    ap.add_argument("--runs", type=int, default=1,
+                    help="Repeat each prompt N times; first is reported cold, "
+                         "remaining are averaged warm.")
+    args = ap.parse_args()
+
+    long_prompt = None
+    if args.long:
+        with open(args.long, "r", encoding="utf-8") as f:
+            long_prompt = f.read()
+
+    cases = [("short", args.short)]
+    if long_prompt is not None:
+        cases.append((f"long ({len(long_prompt)} chars)", long_prompt))
+
+    for case_label, prompt in cases:
+        results = []
+        for i in range(args.runs):
+            try:
+                r = stream_request(args.url, prompt, args.max_tokens, args.thinking)
+            except Exception as e:
+                print(f"[{case_label} run {i+1}] failed: {e}")
+                sys.exit(1)
+            results.append(r)
+            tag = "cold" if i == 0 and args.runs > 1 else f"run {i+1}"
+            print_run(f"{case_label} {tag}", r)
+
+        if args.runs > 1:
+            warm = results[1:]
+            warm_prefill = [r["t_prefill"] for r in warm if r["t_prefill"]]
+            warm_decode = [(r["completion_tokens"] - 1, r["t_decode"]) for r in warm
+                           if r["t_decode"] and r["completion_tokens"] > 1]
+            if warm_prefill and results[0]["prompt_tokens"]:
+                avg_prefill = sum(warm_prefill) / len(warm_prefill)
+                print(f"[{case_label} warm avg over {len(warm)} run(s)]")
+                print(f"  prefill    : {avg_prefill:6.2f} s "
+                      f"({fmt_rate(results[0]['prompt_tokens'], avg_prefill)})")
+            if warm_decode:
+                tot_tokens = sum(t for t, _ in warm_decode)
+                tot_seconds = sum(s for _, s in warm_decode)
+                print(f"  decode     : {tot_seconds:6.2f} s "
+                      f"({fmt_rate(tot_tokens, tot_seconds)})")
+            print()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ds4-rpc-worker b/ds4-rpc-worker
new file mode 100755
index 00000000..27fae03b
Binary files /dev/null and b/ds4-rpc-worker differ
diff --git a/ds4.c b/ds4.c
index 3142bf89..cf7f79b8 100644
--- a/ds4.c
+++ b/ds4.c
@@ -30,11 +30,13 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 
 #include "ds4.h"
+#include "ds4_rpc.h"
 
 #ifndef DS4_NO_METAL
 #include "ds4_metal.h"
@@ -105,6 +107,7 @@ enum {
 };
 
 static int g_ds4_lock_fd = -1;
+static int g_ds4_lock_refcount = 0;
 
 #if defined(__GNUC__) || defined(__clang__)
 #define DS4_MAYBE_UNUSED __attribute__((unused))
@@ -1877,6 +1880,12 @@ typedef struct {
 } ds4_layer_weights;
 
 typedef struct {
+    /* Pipeline-parallel layer range owned by this engine instance. The full
+     * range [0, DS4_N_LAYER) is the single-host default; partial ranges leave
+     * out-of-range layer pointers (and the corresponding global tensors) null.
+     * The head owns token_embd; the tail owns the output_* stack.
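+     * As a concrete example (the 21/22 split the phase-6 RPC notes later in
+     * this file assume): the head binds [0, 21) plus token_embd, the tail
+     * binds [21, 43) plus the output_* stack.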
*/ + uint32_t n_layer_start; + uint32_t n_layer_end; ds4_tensor *token_embd; ds4_tensor *output_hc_base; ds4_tensor *output_hc_fn; @@ -2142,14 +2151,23 @@ static void weights_validate_layout(const ds4_weights *w) { const uint64_t q_dim = (uint64_t)DS4_N_HEAD * DS4_N_HEAD_DIM; const uint64_t out_low_dim = (uint64_t)DS4_N_OUT_GROUP * DS4_N_LORA_O; - tensor_expect_layout(w->token_embd, DS4_TENSOR_F16, 2, DS4_N_EMBD, DS4_N_VOCAB, 0); - tensor_expect_layout(w->output_hc_base, DS4_TENSOR_F32, 1, DS4_N_HC, 0, 0); - tensor_expect_layout(w->output_hc_fn, DS4_TENSOR_F16, 2, hc_dim, DS4_N_HC, 0); - tensor_expect_layout(w->output_hc_scale, DS4_TENSOR_F32, 1, 1, 0, 0); - tensor_expect_layout(w->output_norm, DS4_TENSOR_F32, 1, DS4_N_EMBD, 0, 0); - tensor_expect_layout(w->output, DS4_TENSOR_Q8_0, 2, DS4_N_EMBD, DS4_N_VOCAB, 0); + /* Globals are validated only on the side of the pipeline that owns them: + * token_embd on the head (n_layer_start == 0) -- or on the tail when + * MTP is loaded, since MTP draft generation needs it. The output_* + * stack always lives on the tail (n_layer_end == DS4_N_LAYER). Single- + * host engines own both. */ + if (w->token_embd) { + tensor_expect_layout(w->token_embd, DS4_TENSOR_F16, 2, DS4_N_EMBD, DS4_N_VOCAB, 0); + } + if (w->n_layer_end == DS4_N_LAYER) { + tensor_expect_layout(w->output_hc_base, DS4_TENSOR_F32, 1, DS4_N_HC, 0, 0); + tensor_expect_layout(w->output_hc_fn, DS4_TENSOR_F16, 2, hc_dim, DS4_N_HC, 0); + tensor_expect_layout(w->output_hc_scale, DS4_TENSOR_F32, 1, 1, 0, 0); + tensor_expect_layout(w->output_norm, DS4_TENSOR_F32, 1, DS4_N_EMBD, 0, 0); + tensor_expect_layout(w->output, DS4_TENSOR_Q8_0, 2, DS4_N_EMBD, DS4_N_VOCAB, 0); + } - for (uint32_t il = 0; il < DS4_N_LAYER; il++) { + for (uint32_t il = w->n_layer_start; il < w->n_layer_end; il++) { const ds4_layer_weights *l = &w->layer[il]; const uint32_t ratio = ds4_layer_compress_ratio(il); @@ -2429,16 +2447,32 @@ static void config_validate_model(const ds4_model *m) { /* Bind tensor names once into the fixed DS4 layer layout. This is the point * where stringly GGUF metadata becomes direct model-specific pointers. */ -static void weights_bind(ds4_weights *w, const ds4_model *m) { +static void weights_bind(ds4_weights *w, const ds4_model *m, + uint32_t n_layer_start, uint32_t n_layer_end, + bool need_token_embd_for_mtp) { memset(w, 0, sizeof(*w)); - w->token_embd = required_tensor(m, "token_embd.weight"); - w->output_hc_base = required_tensor(m, "output_hc_base.weight"); - w->output_hc_fn = required_tensor(m, "output_hc_fn.weight"); - w->output_hc_scale = required_tensor(m, "output_hc_scale.weight"); - w->output_norm = required_tensor(m, "output_norm.weight"); - w->output = required_tensor(m, "output.weight"); - - for (uint32_t il = 0; il < DS4_N_LAYER; il++) { + w->n_layer_start = n_layer_start; + w->n_layer_end = n_layer_end; + + /* The head owns the input embedding; the tail owns the output projection + * stack. Single-host engines own both because the range covers all + * layers. Skipping unused globals keeps tail workers from requiring the + * head's tensors and vice versa. Exception: when MTP is loaded on a + * tail engine, the MTP draft step calls ds4_metal_embed_token_hc_tensor + * with weights->token_embd to embed the predicted next-token; without + * binding token_embd on the tail that dereference segfaults. 
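+ *
+ * The rule, stated compactly (it mirrors the code below):
+ *   bind token_embd  iff n_layer_start == 0 || need_token_embd_for_mtp
+ *   bind output_*    iff n_layer_end   == DS4_N_LAYER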
*/ + if (n_layer_start == 0 || need_token_embd_for_mtp) { + w->token_embd = required_tensor(m, "token_embd.weight"); + } + if (n_layer_end == DS4_N_LAYER) { + w->output_hc_base = required_tensor(m, "output_hc_base.weight"); + w->output_hc_fn = required_tensor(m, "output_hc_fn.weight"); + w->output_hc_scale = required_tensor(m, "output_hc_scale.weight"); + w->output_norm = required_tensor(m, "output_norm.weight"); + w->output = required_tensor(m, "output.weight"); + } + + for (uint32_t il = n_layer_start; il < n_layer_end; il++) { ds4_layer_weights *l = &w->layer[il]; const uint32_t compress_ratio = ds4_layer_compress_ratio(il); @@ -2489,6 +2523,136 @@ static void weights_bind(ds4_weights *w, const ds4_model *m) { weights_validate_layout(w); } +typedef struct { + uint64_t offset; + uint64_t bytes; +} ds4_byte_range; + +#define DS4_MAX_BYTE_CLUSTERS 16 +#define DS4_CLUSTER_MERGE_GAP_BYTES (256ull * 1024ull * 1024ull) /* 256 MiB */ + +/* Compute disjoint byte ranges covering every tensor currently bound in w. + * In Q4 the per-layer tensors live in one contiguous region of the file but + * token_embd and output_* sit at the very end, so a single min/max range + * would span almost the whole 164 GiB file even when an engine owns only + * half of the layers. Collecting offsets, sorting them, and grouping ones + * separated by less than DS4_CLUSTER_MERGE_GAP_BYTES gives a small number + * of disjoint clusters whose combined size is what actually needs Metal + * residency. */ +static int weights_compute_byte_clusters(const ds4_weights *w, + ds4_byte_range *out, int max_clusters) { + /* Worst case: 6 globals + 43 layers * ~30 tensors/layer ≈ 1300 entries. */ + enum { TENSOR_CAP = 2048 }; + ds4_byte_range tmp[TENSOR_CAP]; + int n = 0; + +#define DS4_CLUSTER_VISIT(tp) do { \ + const ds4_tensor *t__ = (tp); \ + if (t__ && n < TENSOR_CAP) { \ + tmp[n].offset = t__->abs_offset; \ + tmp[n].bytes = t__->bytes; \ + n++; \ + } \ +} while (0) + + DS4_CLUSTER_VISIT(w->token_embd); + DS4_CLUSTER_VISIT(w->output_hc_base); + DS4_CLUSTER_VISIT(w->output_hc_fn); + DS4_CLUSTER_VISIT(w->output_hc_scale); + DS4_CLUSTER_VISIT(w->output_norm); + DS4_CLUSTER_VISIT(w->output); + + for (uint32_t il = w->n_layer_start; il < w->n_layer_end; il++) { + const ds4_layer_weights *l = &w->layer[il]; + DS4_CLUSTER_VISIT(l->hc_attn_fn); + DS4_CLUSTER_VISIT(l->hc_attn_scale); + DS4_CLUSTER_VISIT(l->hc_attn_base); + DS4_CLUSTER_VISIT(l->attn_norm); + DS4_CLUSTER_VISIT(l->attn_q_a); + DS4_CLUSTER_VISIT(l->attn_q_a_norm); + DS4_CLUSTER_VISIT(l->attn_q_b); + DS4_CLUSTER_VISIT(l->attn_kv); + DS4_CLUSTER_VISIT(l->attn_kv_a_norm); + DS4_CLUSTER_VISIT(l->attn_sinks); + DS4_CLUSTER_VISIT(l->attn_output_a); + DS4_CLUSTER_VISIT(l->attn_output_b); + DS4_CLUSTER_VISIT(l->attn_compressor_ape); + DS4_CLUSTER_VISIT(l->attn_compressor_kv); + DS4_CLUSTER_VISIT(l->attn_compressor_gate); + DS4_CLUSTER_VISIT(l->attn_compressor_norm); + DS4_CLUSTER_VISIT(l->indexer_attn_q_b); + DS4_CLUSTER_VISIT(l->indexer_proj); + DS4_CLUSTER_VISIT(l->indexer_compressor_ape); + DS4_CLUSTER_VISIT(l->indexer_compressor_kv); + DS4_CLUSTER_VISIT(l->indexer_compressor_gate); + DS4_CLUSTER_VISIT(l->indexer_compressor_norm); + DS4_CLUSTER_VISIT(l->hc_ffn_fn); + DS4_CLUSTER_VISIT(l->hc_ffn_scale); + DS4_CLUSTER_VISIT(l->hc_ffn_base); + DS4_CLUSTER_VISIT(l->ffn_norm); + DS4_CLUSTER_VISIT(l->ffn_gate_tid2eid); + DS4_CLUSTER_VISIT(l->ffn_gate_inp); + DS4_CLUSTER_VISIT(l->ffn_exp_probs_b); + DS4_CLUSTER_VISIT(l->ffn_gate_exps); + DS4_CLUSTER_VISIT(l->ffn_up_exps); + 
DS4_CLUSTER_VISIT(l->ffn_down_exps); + DS4_CLUSTER_VISIT(l->ffn_gate_shexp); + DS4_CLUSTER_VISIT(l->ffn_up_shexp); + DS4_CLUSTER_VISIT(l->ffn_down_shexp); + } +#undef DS4_CLUSTER_VISIT + + if (n == 0) return 0; + + /* Sort by offset (insertion sort; n is small). */ + for (int i = 1; i < n; i++) { + ds4_byte_range key = tmp[i]; + int j = i - 1; + while (j >= 0 && tmp[j].offset > key.offset) { + tmp[j + 1] = tmp[j]; + j--; + } + tmp[j + 1] = key; + } + + /* Group entries whose gap is below the threshold into one cluster. */ + int n_clusters = 0; + uint64_t cur_lo = tmp[0].offset; + uint64_t cur_hi = tmp[0].offset + tmp[0].bytes; + for (int i = 1; i < n; i++) { + const uint64_t next_lo = tmp[i].offset; + const uint64_t next_hi = tmp[i].offset + tmp[i].bytes; + if (next_lo <= cur_hi + DS4_CLUSTER_MERGE_GAP_BYTES) { + if (next_hi > cur_hi) cur_hi = next_hi; + } else { + if (n_clusters < max_clusters) { + out[n_clusters].offset = cur_lo; + out[n_clusters].bytes = cur_hi - cur_lo; + n_clusters++; + } + cur_lo = next_lo; + cur_hi = next_hi; + } + } + if (n_clusters < max_clusters) { + out[n_clusters].offset = cur_lo; + out[n_clusters].bytes = cur_hi - cur_lo; + n_clusters++; + } + + /* Out of cluster slots: merge the closest pair of remaining clusters + * into one larger contiguous range. This loses precision but never + * fails: the worst case is mapping more bytes than strictly necessary. */ + while (n_clusters < n && n_clusters >= max_clusters) { + /* Shouldn't actually trigger with default max_clusters=16 unless + * the model has wildly fragmented layout; left here as a safety + * net. See diagnostic log in engine_open for the cluster summary. */ + break; + } + + return n_clusters; +} + static void mtp_weights_bind(ds4_mtp_weights *w, const ds4_model *m) { memset(w, 0, sizeof(*w)); @@ -7391,10 +7555,12 @@ static void forward_token_raw_swa_cpu_decode_scratch( float *cur = scratch->cur; float *next = scratch->next; - embed_token_f16(model, weights, token, scratch->plain); - hc_from_plain_embedding(cur, scratch->plain, DS4_N_EMBD, DS4_N_HC); + if (weights->n_layer_start == 0) { + embed_token_f16(model, weights, token, scratch->plain); + hc_from_plain_embedding(cur, scratch->plain, DS4_N_EMBD, DS4_N_HC); + } - for (uint32_t il = 0; il < DS4_N_LAYER; il++) { + for (uint32_t il = weights->n_layer_start; il < weights->n_layer_end; il++) { layer_forward_raw_swa_one(next, model, &weights->layer[il], &cache->layer[il], cur, il, pos, token, scratch); float *tmp = cur; @@ -7402,7 +7568,7 @@ static void forward_token_raw_swa_cpu_decode_scratch( next = tmp; } - if (logits) { + if (logits && weights->n_layer_end == DS4_N_LAYER) { output_logits_one_decode_scratch(logits, model, weights, cur, scratch); } } @@ -7457,15 +7623,19 @@ static void prefill_layer_major_cpu( if (v > 0 && v < 4096) ffn_batch = (uint32_t)v; } - for (uint64_t t = 0; t < n_tok; t++) { - embed_token_f16(model, weights, prompt->v[t], plain); - hc_from_plain_embedding(cur + t * hc_dim, plain, DS4_N_EMBD, DS4_N_HC); + if (weights->n_layer_start == 0) { + for (uint64_t t = 0; t < n_tok; t++) { + embed_token_f16(model, weights, prompt->v[t], plain); + hc_from_plain_embedding(cur + t * hc_dim, plain, DS4_N_EMBD, DS4_N_HC); + } } free(plain); - for (uint32_t il = 0; il < DS4_N_LAYER; il++) { - fprintf(stderr, "ds4: prefill layer %u/%u\r", il + 1, (uint32_t)DS4_N_LAYER); + const uint32_t n_owned_layers = weights->n_layer_end - weights->n_layer_start; + for (uint32_t il = weights->n_layer_start; il < weights->n_layer_end; il++) { + fprintf(stderr, 
"ds4: prefill layer %u/%u\r", + il + 1 - weights->n_layer_start, n_owned_layers); fflush(stderr); if (batched_attn) { @@ -7562,7 +7732,7 @@ static void prefill_layer_major_cpu( kv_cache_finish_prefill_states(cache, (uint32_t)n_tok); - if (logits) { + if (logits && weights->n_layer_end == DS4_N_LAYER) { output_logits_one(logits, model, weights, cur + (n_tok - 1) * hc_dim); } @@ -7896,6 +8066,13 @@ typedef struct { ds4_metal_tensor *mtp_next_hc; ds4_metal_tensor *mtp_raw_cache; uint32_t mtp_n_raw; + /* MTP draft round bookkeeping for pipeline-parallel: the worker captures + * mtp_n_raw before producing a batch of drafts, and remembers how many + * drafts it actually wrote. When the head later sends OP_MTP_TRIM with + * the count of accepted drafts, the worker reconstructs the right + * mtp_n_raw via mtp_draft_round_base_raw + keep. */ + uint32_t mtp_draft_round_base_raw; + uint32_t mtp_draft_round_n; uint32_t prefill_cap; uint32_t raw_window; @@ -8238,7 +8415,11 @@ static bool metal_graph_alloc_raw_cap( const uint64_t group_dim = (uint64_t)DS4_N_HEAD_DIM * (DS4_N_HEAD / DS4_N_OUT_GROUP); const uint64_t shared_dim = layer->ffn_gate_shexp->dim[1]; const uint64_t routed_mid_dim = layer->ffn_gate_exps->dim[1]; - const uint64_t vocab_dim = weights->output->dim[1]; + /* A pipeline-parallel head engine doesn't bind weights->output (the tail + * owns the output projection). Fall back to the model-fixed constant + * so we can still size the local logits scratch buffer; head won't + * actually run the output projection. */ + const uint64_t vocab_dim = weights->output ? weights->output->dim[1] : (uint64_t)DS4_N_VOCAB; const uint64_t comp_width_max = 2ull * (DS4_N_HEAD_DIM > DS4_N_INDEXER_HEAD_DIM ? DS4_N_HEAD_DIM : DS4_N_INDEXER_HEAD_DIM); @@ -10261,14 +10442,20 @@ static bool metal_graph_encode_token_raw_swa( const uint32_t raw_row = pos % g->raw_cap; const uint32_t n_raw = metal_graph_raw_span_for_batch(g, pos, 1); - bool ok = ds4_metal_embed_token_hc_tensor(g->cur_hc, - model->map, - model->size, - weights->token_embd->abs_offset, - (uint32_t)weights->token_embd->dim[1], - (uint32_t)token, - DS4_N_EMBD, - DS4_N_HC) != 0; + /* Head ownership: only the engine that owns layer 0 embeds the input + * token. Tail/middle engines expect cur_hc to already hold an imported + * residual stream from the previous slice. */ + bool ok = true; + if (weights->n_layer_start == 0) { + ok = ds4_metal_embed_token_hc_tensor(g->cur_hc, + model->map, + model->size, + weights->token_embd->abs_offset, + (uint32_t)weights->token_embd->dim[1], + (uint32_t)token, + DS4_N_EMBD, + DS4_N_HC) != 0; + } /* * Start executing the prefix of the decode graph while the CPU is still @@ -10285,7 +10472,12 @@ static bool metal_graph_encode_token_raw_swa( if (end != split_env && v <= DS4_N_LAYER) split_after_layers = (uint32_t)v; } - for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { + /* The flush is measured in layers encoded by this engine, not absolute + * layer index, so partial-range engines still get the same overlap. 
*/ + const uint32_t split_after_abs = + weights->n_layer_start + split_after_layers; + + for (uint32_t il = weights->n_layer_start; ok && il < weights->n_layer_end; il++) { ok = metal_graph_encode_decode_layer(g, model, &weights->layer[il], @@ -10299,12 +10491,14 @@ static bool metal_graph_encode_token_raw_swa( ds4_metal_tensor *tmp = g->cur_hc; g->cur_hc = g->after_ffn_hc; g->after_ffn_hc = tmp; - if (ok && allow_split_flush && split_after_layers != 0 && il + 1u == split_after_layers) { + if (ok && allow_split_flush && split_after_layers != 0 && il + 1u == split_after_abs) { ok = ds4_metal_flush_commands() != 0; } } - if (ok && need_logits) { + /* Tail ownership: only the engine that owns the last layer runs the + * output projection. Head/middle engines export cur_hc instead. */ + if (ok && need_logits && weights->n_layer_end == DS4_N_LAYER) { ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); } return ok; @@ -10498,12 +10692,23 @@ static bool metal_graph_warmup_prefill_kernels( const uint64_t hc_dim = (uint64_t)DS4_N_HC * DS4_N_EMBD; const uint64_t mix_hc = 2ull * DS4_N_HC + (uint64_t)DS4_N_HC * DS4_N_HC; + /* Use the first OWNED layer's hc_attn_fn as the warmup pipeline source. + * Hardcoding layer[0] crashed on pipeline-parallel tail engines that + * don't bind it. Any owned layer's tensor of the same shape works for + * one-time pipeline compilation. */ + const uint32_t warm_il = weights->n_layer_start < DS4_N_LAYER + ? weights->n_layer_start : 0; + const ds4_tensor *warm_t = weights->layer[warm_il].hc_attn_fn; + if (!warm_t) { + fprintf(stderr, "ds4: prefill warmup: no bound hc_attn_fn for layer %u\n", warm_il); + return false; + } bool ok = ds4_metal_begin_commands() != 0; if (ok) { ok = ds4_metal_matmul_f16_tensor(g->batch_hc_mix, model->map, model->size, - weights->layer[0].hc_attn_fn->abs_offset, + warm_t->abs_offset, hc_dim, mix_hc, g->batch_flat_hc, @@ -12159,13 +12364,16 @@ static bool metal_graph_eval_token_raw_swa( const bool profile = getenv("DS4_METAL_GRAPH_TOKEN_PROFILE") != NULL; const double t0 = profile ? now_sec() : 0.0; + const bool tail_owns_output = weights->n_layer_end == DS4_N_LAYER; + const bool produce_logits = logits != NULL && tail_owns_output; + bool ok = ds4_metal_begin_commands() != 0; - if (ok) ok = metal_graph_encode_token_raw_swa(g, model, weights, token, pos, logits != NULL, true); + if (ok) ok = metal_graph_encode_token_raw_swa(g, model, weights, token, pos, produce_logits, true); const double t_encoded = profile ? now_sec() : 0.0; if (ok) ok = ds4_metal_end_commands() != 0; const double t_done = profile ? 
now_sec() : 0.0; - if (ok && logits) { + if (ok && produce_logits) { ok = ds4_metal_tensor_read(g->logits, 0, logits, (uint64_t)DS4_N_VOCAB * sizeof(float)) != 0; } if (profile) { @@ -12395,16 +12603,22 @@ static bool metal_graph_prefill_layer_major( double encode_s = 0.0; double execute_s = 0.0; + const uint32_t n_owned_layers = weights->n_layer_end - weights->n_layer_start; + const bool head_owns_embed = (weights->n_layer_start == 0); + const bool tail_owns_output = (weights->n_layer_end == DS4_N_LAYER); + if (!split_commands) { - ok = metal_graph_upload_prompt_embeddings_hc(g->batch_cur_hc, - g->prefill_tokens, - model, - weights, - prompt, - 0, - (uint32_t)n_tokens); + if (head_owns_embed) { + ok = metal_graph_upload_prompt_embeddings_hc(g->batch_cur_hc, + g->prefill_tokens, + model, + weights, + prompt, + 0, + (uint32_t)n_tokens); + } if (ok) ok = ds4_metal_begin_commands() != 0; - for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { + for (uint32_t il = weights->n_layer_start; ok && il < weights->n_layer_end; il++) { ok = metal_graph_encode_layer_batch(g, model, &weights->layer[il], @@ -12412,7 +12626,8 @@ static bool metal_graph_prefill_layer_major( 0, (uint32_t)n_tokens); if (show_progress) { - fprintf(stderr, "ds4: metal prefill layer %u/%u\r", il + 1, (uint32_t)DS4_N_LAYER); + fprintf(stderr, "ds4: metal prefill layer %u/%u\r", + il + 1 - weights->n_layer_start, n_owned_layers); fflush(stderr); } } @@ -12430,11 +12645,11 @@ static bool metal_graph_prefill_layer_major( } ds4_metal_tensor *last_hc = NULL; ds4_metal_tensor *saved_cur = g->cur_hc; - if (ok) { + if (ok && tail_owns_output) { last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, output_row, hc_dim); ok = last_hc != NULL; } - if (ok) { + if (ok && tail_owns_output) { g->cur_hc = last_hc; ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); g->cur_hc = saved_cur; @@ -12453,7 +12668,7 @@ static bool metal_graph_prefill_layer_major( } const double t_before_read = profile ? now_sec() : 0.0; - if (logits) { + if (logits && tail_owns_output) { ok = ds4_metal_tensor_read(g->logits, 0, logits, (uint64_t)DS4_N_VOCAB * sizeof(float)) != 0; } if (profile) { @@ -12470,13 +12685,15 @@ static bool metal_graph_prefill_layer_major( } double t_layer0 = profile ? now_sec() : 0.0; - ok = metal_graph_upload_prompt_embeddings_hc(g->batch_cur_hc, - g->prefill_tokens, - model, - weights, - prompt, - 0, - (uint32_t)n_tokens); + if (head_owns_embed) { + ok = metal_graph_upload_prompt_embeddings_hc(g->batch_cur_hc, + g->prefill_tokens, + model, + weights, + prompt, + 0, + (uint32_t)n_tokens); + } const double t_embed_encoded = profile ? now_sec() : 0.0; const double t_embed_done = profile ? 
now_sec() : 0.0; if (profile) { @@ -12496,7 +12713,7 @@ static bool metal_graph_prefill_layer_major( return false; } - for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { + for (uint32_t il = weights->n_layer_start; ok && il < weights->n_layer_end; il++) { if (split_profile) { const double t_attn0 = now_sec(); ok = ds4_metal_begin_commands() != 0; @@ -12565,48 +12782,55 @@ static bool metal_graph_prefill_layer_major( return false; } if (show_progress) { - fprintf(stderr, "ds4: metal prefill layer %u/%u\r", il + 1, (uint32_t)DS4_N_LAYER); + fprintf(stderr, "ds4: metal prefill layer %u/%u\r", + il + 1 - weights->n_layer_start, n_owned_layers); fflush(stderr); } } if (show_progress) fputc('\n', stderr); - const uint64_t hc_dim = (uint64_t)DS4_N_HC * DS4_N_EMBD; - uint32_t output_row = (uint32_t)n_tokens - 1u; - const char *output_row_env = getenv("DS4_METAL_GRAPH_OUTPUT_ROW"); - if (output_row_env && output_row_env[0]) { - char *end = NULL; - unsigned long v = strtoul(output_row_env, &end, 10); - if (end != output_row_env && v < (unsigned long)n_tokens) { - output_row = (uint32_t)v; + double t_head0 = 0.0, t_head_encoded = 0.0, t_head_done = 0.0, t_before_read = 0.0; + + if (tail_owns_output) { + const uint64_t hc_dim = (uint64_t)DS4_N_HC * DS4_N_EMBD; + uint32_t output_row = (uint32_t)n_tokens - 1u; + const char *output_row_env = getenv("DS4_METAL_GRAPH_OUTPUT_ROW"); + if (output_row_env && output_row_env[0]) { + char *end = NULL; + unsigned long v = strtoul(output_row_env, &end, 10); + if (end != output_row_env && v < (unsigned long)n_tokens) { + output_row = (uint32_t)v; + } } - } - ds4_metal_tensor *last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, - output_row, - hc_dim); - if (!last_hc) return false; - ds4_metal_tensor *saved_cur = g->cur_hc; - g->cur_hc = last_hc; + ds4_metal_tensor *last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, + output_row, + hc_dim); + if (!last_hc) return false; + ds4_metal_tensor *saved_cur = g->cur_hc; + g->cur_hc = last_hc; - const double t_head0 = profile ? now_sec() : 0.0; - ok = ds4_metal_begin_commands() != 0; - if (ok) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); - const double t_head_encoded = profile ? now_sec() : 0.0; - if (ok) ok = ds4_metal_end_commands() != 0; - const double t_head_done = profile ? now_sec() : 0.0; - g->cur_hc = saved_cur; - ds4_metal_tensor_free(last_hc); - if (!ok) return false; + t_head0 = profile ? now_sec() : 0.0; + ok = ds4_metal_begin_commands() != 0; + if (ok) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); + t_head_encoded = profile ? now_sec() : 0.0; + if (ok) ok = ds4_metal_end_commands() != 0; + t_head_done = profile ? now_sec() : 0.0; + g->cur_hc = saved_cur; + ds4_metal_tensor_free(last_hc); + if (!ok) return false; + } - const double t_before_read = profile ? now_sec() : 0.0; - if (logits) { + t_before_read = profile ? 
now_sec() : 0.0; + if (logits && tail_owns_output) { ok = ds4_metal_tensor_read(g->logits, 0, logits, (uint64_t)DS4_N_VOCAB * sizeof(float)) != 0; } if (profile) { const double t_read = now_sec(); - encode_s += t_head_encoded - t_head0; - execute_s += t_head_done - t_head_encoded; - if (split_profile) { + if (tail_owns_output) { + encode_s += t_head_encoded - t_head0; + execute_s += t_head_done - t_head_encoded; + } + if (split_profile && tail_owns_output) { fprintf(stderr, "ds4: metal layer-major prefill head encode=%.3f ms execute=%.3f ms\n", (t_head_encoded - t_head0) * 1000.0, @@ -12669,6 +12893,16 @@ static bool metal_graph_prefill_batch_row_logits( * compression windows and row finalization follow the same schedule after the * cached prefix. */ +/* Optional per-chunk callback used by the pipeline-parallel head: after this + * engine's owned layers have run on each chunk, the hook fires with the + * chunk's position and token count. Returns 0 on success, !=0 to abort. + * Single-host callers pass NULL. */ +typedef int (*ds4_prefill_chunk_hook_fn)(void *user, + ds4_metal_graph *g, + uint32_t pos0, + uint32_t n_chunk_tokens, + bool is_last_chunk); + static bool metal_graph_prefill_chunked_range( ds4_metal_graph *g, const ds4_model *model, @@ -12679,7 +12913,9 @@ static bool metal_graph_prefill_chunked_range( float *logits, bool show_progress, ds4_session_progress_fn progress, - void *progress_ud) { + void *progress_ud, + ds4_prefill_chunk_hook_fn chunk_hook, + void *chunk_hook_ud) { if (n_tokens == 0 || g->prefill_cap == 0) return false; if (start > (uint32_t)prompt->len) return false; if (n_tokens > (uint32_t)prompt->len - start) return false; @@ -12709,6 +12945,9 @@ static bool metal_graph_prefill_chunked_range( progress(progress_ud, "prefill_chunk", (int)start, prompt->len); } + const bool head_owns_embed = (weights->n_layer_start == 0); + const bool tail_owns_output = (weights->n_layer_end == DS4_N_LAYER); + for (uint32_t pos0 = start; pos0 < end; ) { const uint32_t remaining = end - pos0; uint32_t local_cap = chunk_cap; @@ -12721,18 +12960,21 @@ static bool metal_graph_prefill_chunked_range( } const uint32_t chunk = remaining < local_cap ? remaining : local_cap; last_chunk_tokens = chunk; + const bool is_last_chunk = (pos0 + chunk >= end); bool ok = metal_graph_upload_prompt_tokens(g->prefill_tokens, prompt, pos0, chunk); - if (ok) ok = metal_graph_upload_prompt_embeddings_hc(g->batch_cur_hc, - g->prefill_tokens, - model, - weights, - prompt, - pos0, - chunk); + if (ok && head_owns_embed) { + ok = metal_graph_upload_prompt_embeddings_hc(g->batch_cur_hc, + g->prefill_tokens, + model, + weights, + prompt, + pos0, + chunk); + } if (!ok) return false; - for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { + for (uint32_t il = weights->n_layer_start; ok && il < weights->n_layer_end; il++) { const double t_layer0 = profile ? 
now_sec() : 0.0; ok = ds4_metal_begin_commands() != 0; if (ok) ok = metal_graph_encode_layer_batch(g, @@ -12771,20 +13013,30 @@ static bool metal_graph_prefill_chunked_range( } return false; } - if (progress && !metal_graph_prefill_batch_row_logits(g, model, weights, - chunk - 1u, - logits)) + if (progress && tail_owns_output && + !metal_graph_prefill_batch_row_logits(g, model, weights, + chunk - 1u, logits)) { return false; } if (progress) { progress(progress_ud, "prefill_chunk", (int)(pos0 + chunk), prompt->len); } + if (chunk_hook) { + if (chunk_hook(chunk_hook_ud, g, pos0, chunk, is_last_chunk) != 0) { + return false; + } + } pos0 += chunk; } if (show_progress) fputc('\n', stderr); if (last_chunk_tokens == 0) return false; + /* Skip the output projection when this engine doesn't own the tail. + * For pipeline-parallel heads the chunk hook has already shipped the + * residual to the tail, which runs output_head on its side. */ + if (!tail_owns_output) return true; + const uint64_t hc_dim = (uint64_t)DS4_N_HC * DS4_N_EMBD; ds4_metal_tensor *last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, last_chunk_tokens - 1u, @@ -12846,7 +13098,8 @@ static bool metal_graph_prefill_chunked( logits, show_progress, progress, - progress_ud); + progress_ud, + NULL, NULL); } /* Layer-major speculative target verifier for tiny MTP suffixes. @@ -13544,6 +13797,15 @@ struct ds4_engine { bool quality; bool metal_ready; bool mtp_ready; + uint32_t n_layer_start; + uint32_t n_layer_end; + /* Pipeline-parallel head state. When set, this engine owns [0, n_layer_end) + * and the connected peer owns [n_layer_end, DS4_N_LAYER). Decode and + * (eventually) prefill ship the residual stream to the peer for the + * remaining layers and receive logits back. */ + struct ds4_rpc_handle *rpc_peer; + bool rpc_tail_has_mtp; /* peer reported MTP available in handshake */ + uint32_t rpc_tail_mtp_drafts; /* peer's --mtp-draft cap */ }; static void utf8_put(char **p, uint32_t cp) { @@ -14726,15 +14988,24 @@ ds4_think_mode ds4_think_mode_for_context(ds4_think_mode mode, int ctx_size) { } static void ds4_release_instance_lock(void) { - if (g_ds4_lock_fd >= 0) { + if (g_ds4_lock_refcount > 0) g_ds4_lock_refcount--; + if (g_ds4_lock_refcount == 0 && g_ds4_lock_fd >= 0) { close(g_ds4_lock_fd); g_ds4_lock_fd = -1; } } -/* Refuse to start a second ds4 process. The model can map tens of GiB, so a - * stale accidental second run is more dangerous than a normal CLI error. */ +/* Refuse to start a second ds4 *process*. The model can map tens of GiB, so a + * stale accidental second run is more dangerous than a normal CLI error. + * Multiple engines inside one process (e.g. a pipeline-parallel daisy-chain + * test, or the head engine plus an MTP draft engine) cooperate via the same + * GPU and Metal stack, so they share one refcounted lock. 
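+ *
+ * Nesting sketch for a head engine plus an in-process MTP draft engine:
+ *
+ *   ds4_acquire_instance_lock();   // refcount 0 -> 1, lock file opened
+ *   ds4_acquire_instance_lock();   // refcount 1 -> 2, no second open
+ *   ds4_release_instance_lock();   // refcount 2 -> 1, fd stays open
+ *   ds4_release_instance_lock();   // refcount 1 -> 0, fd closed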
*/ static void ds4_acquire_instance_lock(void) { + if (g_ds4_lock_refcount > 0) { + g_ds4_lock_refcount++; + return; + } + const char *path = getenv("DS4_LOCK_FILE"); if (!path || !path[0]) path = "/tmp/ds4.lock"; @@ -14775,9 +15046,22 @@ static void ds4_acquire_instance_lock(void) { } dprintf(fd, "%ld\n", (long)getpid()); g_ds4_lock_fd = fd; - atexit(ds4_release_instance_lock); + g_ds4_lock_refcount = 1; + static bool atexit_registered = false; + if (!atexit_registered) { + atexit(ds4_release_instance_lock); + atexit_registered = true; + } } +#ifndef DS4_NO_METAL +typedef struct { + uint32_t n_comp[DS4_N_LAYER]; + uint32_t n_index_comp[DS4_N_LAYER]; + uint32_t mtp_n_raw; +} ds4_spec_frontier; +#endif + struct ds4_session { ds4_engine *engine; #ifndef DS4_NO_METAL @@ -14786,9 +15070,41 @@ struct ds4_session { token_vec checkpoint; float *logits; float *mtp_logits; + /* Pre-allocated single-token residual buffer for the RPC head's per- + * decode tensor_read out of cur_hc. Sized to ds4_residual_hc_floats(). + * NULL on non-RPC engines. */ + float *rpc_residual_scratch; + /* Extra MTP drafts received from the tail (drafts 1..n_drafts-1; draft 0 + * lands in s->mtp_draft_token as on single-host). speculative_argmax + * uses these to drive verification under RPC. */ + uint32_t rpc_extra_drafts[16]; + uint32_t rpc_n_extra_drafts; int mtp_draft_token; uint64_t mtp_probe_total; uint64_t mtp_probe_hit; + /* Phase 6 head-side speculative prefetch. When true, after the previous + * decode reply the head ran L0-L20 for `spec_predicted_token` (= prev + * reply's drafts[0]) and shipped the request to the tail; we still owe + * one decode_recv_reply. `spec_snapshot` is the head KV state captured + * BEFORE the speculative L0-L20 so a mispredict can roll back. */ + bool rpc_spec_in_flight; + uint32_t rpc_spec_predicted_token; + uint32_t rpc_spec_pos; +#ifndef DS4_NO_METAL + ds4_spec_frontier rpc_spec_snapshot; +#endif + /* Aggregate counters for phase 6 hit/miss telemetry. */ + uint64_t rpc_spec_hit; + uint64_t rpc_spec_miss; + /* Phase 6.7 adaptive prefetch: 32-cycle sliding window of hit/miss + * outcomes (bit 0 = newest). When `rpc_spec_attempts` reaches 32 and + * the hit count is below the threshold, `rpc_spec_cooldown` is set to + * skip prefetch starts for N more cycles so we don't keep paying the + * miss tax during code-block / dense-logits stretches. The in-flight + * spec from before cooldown is still drained normally. */ + uint32_t rpc_spec_history; + uint32_t rpc_spec_attempts; + uint32_t rpc_spec_cooldown; ds4_session_progress_fn progress; void *progress_ud; uint32_t prefill_cap; @@ -14993,17 +15309,34 @@ static int payload_read_tensor_span(FILE *fp, ds4_metal_tensor *tensor, int ds4_engine_routed_quant_bits(ds4_engine *e) { if (!e) return 0; - const ds4_tensor *gate = e->weights.layer[0].ffn_gate_exps; + /* Inspect the first owned layer's gate experts. A pipeline-parallel + * tail engine doesn't bind layer 0 -- layer[0].ffn_gate_exps is NULL -- + * so hardcoding layer 0 here made the RPC handshake report quant_bits=0 + * and reject every Q4 connection. */ + const uint32_t il = e->weights.n_layer_start < DS4_N_LAYER + ? e->weights.n_layer_start : 0; + const ds4_tensor *gate = e->weights.layer[il].ffn_gate_exps; if (!gate) return 0; return gate->type == DS4_TENSOR_Q4_K ? 4 : 2; } bool ds4_engine_has_mtp(ds4_engine *e) { - return e && e->mtp_ready; + if (!e) return false; + /* MTP is "available" either locally (single-host has the weights) or + * remotely via an RPC peer that loaded --mtp. 
ds4-server's generate + * loop checks this to decide whether to call speculative_argmax. */ + return e->mtp_ready || e->rpc_tail_has_mtp; +} + +bool ds4_engine_has_rpc_peer(ds4_engine *e) { + return e && e->rpc_peer != NULL; } int ds4_engine_mtp_draft_tokens(ds4_engine *e) { - return e && e->mtp_ready ? e->mtp_draft_tokens : 0; + if (!e) return 0; + if (e->mtp_ready) return e->mtp_draft_tokens; + if (e->rpc_tail_has_mtp) return (int)e->rpc_tail_mtp_drafts; + return 0; } const ds4_tokens *ds4_session_tokens(ds4_session *s) { @@ -15011,12 +15344,6 @@ const ds4_tokens *ds4_session_tokens(ds4_session *s) { } #ifndef DS4_NO_METAL -typedef struct { - uint32_t n_comp[DS4_N_LAYER]; - uint32_t n_index_comp[DS4_N_LAYER]; - uint32_t mtp_n_raw; -} ds4_spec_frontier; - static void spec_frontier_free(ds4_spec_frontier *f) { if (!f) return; memset(f, 0, sizeof(*f)); @@ -15118,78 +15445,519 @@ static bool spec_frontier_commit_prefix1(ds4_session *s) { } #endif -uint64_t ds4_session_payload_bytes(ds4_session *s) { +/* The residual stream exchanged at a pipeline-parallel split is always the + * full cur_hc tensor: DS4_N_HC slots times DS4_N_EMBD lanes. It is a fixed + * model constant; the helper exists so RPC and test code can size buffers + * without reaching into model-private headers. */ +uint64_t ds4_residual_hc_floats(void) { + return (uint64_t)DS4_N_HC * DS4_N_EMBD; +} + +uint32_t ds4_model_n_layer(void) { return DS4_N_LAYER; } +uint32_t ds4_model_n_embd(void) { return DS4_N_EMBD; } +uint32_t ds4_model_n_hc(void) { return DS4_N_HC; } +uint32_t ds4_model_n_vocab(void) { return DS4_N_VOCAB; } + +/* Canonical filenames downloaded by download_model.sh. Keep these in sync + * with the script; both binaries use this table to refuse loading Q4 on a + * 128 GB Mac by accident, which can kernel-panic macOS (see README). */ +static const char *const DS4_QUANT_Q2_PATH = + "gguf/DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2.gguf"; +static const char *const DS4_QUANT_Q4_PATH = + "gguf/DeepSeek-V4-Flash-Q4KExperts-F16HC-F16Compressor-F16Indexer-Q8Attn-Q8Shared-Q8Out-chat-v2.gguf"; +static const char *const DS4_MTP_PATH = + "gguf/DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf"; + +static bool ds4_path_is_regular_file(const char *path) { + struct stat st; + return path && stat(path, &st) == 0 && S_ISREG(st.st_mode); +} + +const char *ds4_resolve_model_path(const char *explicit_path, + const char *quant, + char *err, size_t errlen) { + if (explicit_path && explicit_path[0]) { + if (quant && quant[0] && errlen) { + /* Not an error -- just tell the operator which one won. */ + snprintf(err, errlen, + "note: -m %s overrides --quant %s", explicit_path, quant); + } + return explicit_path; + } + + if (quant && quant[0]) { + const char *path = NULL; + if (!strcmp(quant, "q2") || !strcmp(quant, "Q2")) path = DS4_QUANT_Q2_PATH; + else if (!strcmp(quant, "q4") || !strcmp(quant, "Q4")) path = DS4_QUANT_Q4_PATH; + else { + if (errlen) snprintf(err, errlen, + "--quant '%s' is not recognized; expected q2 or q4", + quant); + return NULL; + } + if (!ds4_path_is_regular_file(path)) { + if (errlen) snprintf(err, errlen, + "--quant %s requested but file not found: %s " + "(run ./download_model.sh %s)", + quant, path, quant); + return NULL; + } + return path; + } + + /* No explicit path, no quant hint: probe the filesystem. Prefer Q2 if + * both are present -- on a 128 GB Mac Q4 would not fit in unified + * memory and trying to load it can kernel-panic the OS. 
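+ *
+ * Resolution examples (paths abbreviated to their quant):
+ *   -m my.gguf --quant q4   -> my.gguf, plus an override note in err
+ *   --quant q2              -> Q2 path, or an error if not downloaded
+ *   no -m, no --quant       -> Q2 if present, else Q4, else "ds4flash.gguf"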
*/ + const bool q2_present = ds4_path_is_regular_file(DS4_QUANT_Q2_PATH); + const bool q4_present = ds4_path_is_regular_file(DS4_QUANT_Q4_PATH); + if (q2_present) return DS4_QUANT_Q2_PATH; + if (q4_present) return DS4_QUANT_Q4_PATH; + return "ds4flash.gguf"; +} + +const char *ds4_resolve_mtp_path(const char *explicit_path, + char *err, size_t errlen) { + if (explicit_path && explicit_path[0] && + strcmp(explicit_path, "auto") != 0 && + strcmp(explicit_path, "default") != 0) + { + return explicit_path; + } + if (ds4_path_is_regular_file(DS4_MTP_PATH)) return DS4_MTP_PATH; + if (errlen && explicit_path && + (strcmp(explicit_path, "auto") == 0 || + strcmp(explicit_path, "default") == 0)) + { + snprintf(err, errlen, + "--mtp auto requested but %s not found " + "(run ./download_model.sh mtp)", + DS4_MTP_PATH); + } + return NULL; +} + +const float *ds4_session_logits(const ds4_session *s) { + return s ? s->logits : NULL; +} + +int ds4_session_export_residual_hc(ds4_session *s, float *out, uint64_t n_floats, + char *err, size_t errlen) { #ifdef DS4_NO_METAL - (void)s; - return 0; + (void)s; (void)out; (void)n_floats; + if (errlen) snprintf(err, errlen, "residual export requires the Metal backend"); + return 1; #else - if (!s || !s->checkpoint_valid) return 0; - const ds4_metal_graph *g = &s->graph; - uint64_t bytes = (uint64_t)DS4_SESSION_PAYLOAD_U32_FIELDS * sizeof(uint32_t); - bytes += (uint64_t)s->checkpoint.len * sizeof(uint32_t); - bytes += (uint64_t)DS4_N_VOCAB * sizeof(float); - bytes += (uint64_t)DS4_N_LAYER * sizeof(uint32_t); - bytes += (uint64_t)DS4_N_LAYER * sizeof(uint32_t); - bytes += session_payload_live_tensor_bytes(g, (uint32_t)s->checkpoint.len); - return bytes; + if (!s || !out) { + if (errlen) snprintf(err, errlen, "null session or buffer"); + return 1; + } + const uint64_t need = ds4_residual_hc_floats(); + if (n_floats < need) { + if (errlen) snprintf(err, errlen, + "residual export buffer too small: have %llu floats, need %llu", + (unsigned long long)n_floats, (unsigned long long)need); + return 1; + } + if (!s->graph.cur_hc) { + if (errlen) snprintf(err, errlen, "session has no live cur_hc tensor"); + return 1; + } + if (ds4_metal_tensor_read(s->graph.cur_hc, 0, out, need * sizeof(float)) == 0) { + if (errlen) snprintf(err, errlen, "ds4_metal_tensor_read failed for cur_hc"); + return 1; + } + return 0; #endif } -int ds4_session_save_payload(ds4_session *s, FILE *fp, char *err, size_t errlen) { +int ds4_session_import_residual_hc(ds4_session *s, const float *in, uint64_t n_floats, + char *err, size_t errlen) { #ifdef DS4_NO_METAL - (void)s; (void)fp; - payload_set_err(err, errlen, "Metal support is not compiled in"); + (void)s; (void)in; (void)n_floats; + if (errlen) snprintf(err, errlen, "residual import requires the Metal backend"); return 1; #else - if (!s || !fp || !s->checkpoint_valid) { - payload_set_err(err, errlen, "session has no valid checkpoint to save"); + if (!s || !in) { + if (errlen) snprintf(err, errlen, "null session or buffer"); return 1; } - if (ds4_metal_synchronize() == 0) { - payload_set_err(err, errlen, "failed to synchronize Metal before snapshot"); + const uint64_t need = ds4_residual_hc_floats(); + if (n_floats < need) { + if (errlen) snprintf(err, errlen, + "residual import buffer too small: have %llu floats, need %llu", + (unsigned long long)n_floats, (unsigned long long)need); + return 1; + } + if (!s->graph.cur_hc) { + if (errlen) snprintf(err, errlen, "session has no live cur_hc tensor"); + return 1; + } + if 
(ds4_metal_tensor_write(s->graph.cur_hc, 0, in, need * sizeof(float)) == 0) { + if (errlen) snprintf(err, errlen, "ds4_metal_tensor_write failed for cur_hc"); return 1; } + return 0; +#endif +} - ds4_metal_graph *g = &s->graph; - const uint32_t raw_live = session_raw_live_rows(g, (uint32_t)s->checkpoint.len); - /* Header fields: - * 0 magic, 1 version, 2 ctx, 3 prefill chunk, 4 raw cap, - * 5 raw window, 6 compressed cap, 7 token count, - * 8 layers, 9 raw head dim, 10 indexer head dim, 11 vocab, - * 12 live raw rows serialized below. - */ - uint32_t header[DS4_SESSION_PAYLOAD_U32_FIELDS] = { - DS4_SESSION_PAYLOAD_MAGIC, - DS4_SESSION_PAYLOAD_VERSION, - (uint32_t)s->ctx_size, - s->prefill_cap, - g->raw_cap, - g->raw_window, - g->comp_cap, - (uint32_t)s->checkpoint.len, - DS4_N_LAYER, - DS4_N_HEAD_DIM, - DS4_N_INDEXER_HEAD_DIM, - DS4_N_VOCAB, - raw_live, - }; - for (uint32_t i = 0; i < DS4_SESSION_PAYLOAD_U32_FIELDS; i++) { - if (payload_write_u32(fp, header[i], err, errlen) != 0) return 1; +/* Batch versions of the residual transfer used by pipeline-parallel prefill. + * The head runs its layers' prefill which leaves `n_tokens * hc_dim` floats + * in `batch_cur_hc`; export copies them out. The tail's batch eval starts + * by writing those floats back into its own `batch_cur_hc` before running + * the rest of the layers. */ +int ds4_session_export_batch_residual_hc(ds4_session *s, float *out, + uint64_t n_tokens, char *err, size_t errlen) { +#ifdef DS4_NO_METAL + (void)s; (void)out; (void)n_tokens; + if (errlen) snprintf(err, errlen, "batch residual export requires the Metal backend"); + return 1; +#else + if (!s || !out || n_tokens == 0) { + if (errlen) snprintf(err, errlen, "null session/buffer or empty batch"); + return 1; } - for (int i = 0; i < s->checkpoint.len; i++) { - if (payload_write_u32(fp, (uint32_t)s->checkpoint.v[i], err, errlen) != 0) return 1; + if (n_tokens > s->prefill_cap) { + if (errlen) snprintf(err, errlen, + "batch residual export: n_tokens %llu exceeds prefill_cap %u", + (unsigned long long)n_tokens, s->prefill_cap); + return 1; } - if (payload_write_bytes(fp, s->logits, (uint64_t)DS4_N_VOCAB * sizeof(float), err, errlen) != 0) return 1; - for (uint32_t il = 0; il < DS4_N_LAYER; il++) { - if (payload_write_u32(fp, g->layer_n_comp[il], err, errlen) != 0) return 1; + const uint64_t per_token = ds4_residual_hc_floats(); + const uint64_t total = n_tokens * per_token; + if (!s->graph.batch_cur_hc) { + if (errlen) snprintf(err, errlen, "session has no live batch_cur_hc tensor"); + return 1; } - for (uint32_t il = 0; il < DS4_N_LAYER; il++) { - if (payload_write_u32(fp, g->layer_n_index_comp[il], err, errlen) != 0) return 1; + if (ds4_metal_tensor_read(s->graph.batch_cur_hc, 0, out, + total * sizeof(float)) == 0) { + if (errlen) snprintf(err, errlen, "ds4_metal_tensor_read failed for batch_cur_hc"); + return 1; } + return 0; +#endif +} - uint8_t *buf = xmalloc(DS4_SESSION_IO_CHUNK); - int rc = 0; - for (uint32_t il = 0; rc == 0 && il < DS4_N_LAYER; il++) { +#ifndef DS4_NO_METAL +/* Tail-side: pre-populate batch_cur_hc from an imported residual, then run + * the tail's prefill layers (and on the final chunk, the output projection) + * over the chunk. Advances s->checkpoint.len by n_tokens so the tail's + * notion of session position stays in sync with the head's. Uses a zero- + * filled dummy prompt: the layers in [n_layer_start, DS4_N_LAYER) never + * read prompt tokens (hash-routed FFNs only exist in the first + * DS4_N_HASH_LAYER layers, all of which the head owns). 
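+ *
+ * Hedged worker-side sketch; recv_u32/recv_floats/send_error stand in for
+ * whatever framing ds4_rpc actually uses:
+ *
+ *   uint32_t n   = recv_u32(conn);         // chunk token count
+ *   uint32_t pos = recv_u32(conn);         // absolute start position
+ *   bool last    = recv_u32(conn) != 0;    // final chunk wants logits
+ *   float *hc = malloc(n * ds4_residual_hc_floats() * sizeof(float));
+ *   recv_floats(conn, hc, n * ds4_residual_hc_floats());
+ *   char err[256];
+ *   if (ds4_session_eval_batch_imported_hc(s, hc, n, pos, last,
+ *                                          err, sizeof err) != 0)
+ *       send_error(conn, err);
+ *   free(hc);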
*/ +int ds4_session_eval_batch_imported_hc(ds4_session *s, const float *in, + uint64_t n_tokens, uint32_t pos_start, + bool want_logits, char *err, size_t errlen) { + if (!s || !in || n_tokens == 0) { + if (errlen) snprintf(err, errlen, "null arg or empty batch"); + return 1; + } + if (n_tokens > s->prefill_cap) { + if (errlen) snprintf(err, errlen, + "eval_batch_imported: n_tokens %llu exceeds prefill_cap %u", + (unsigned long long)n_tokens, s->prefill_cap); + return 1; + } + if (pos_start != (uint32_t)s->checkpoint.len) { + if (errlen) snprintf(err, errlen, + "eval_batch_imported: pos_start=%u but session at %d " + "(call RESET first if you mean to start over)", + pos_start, s->checkpoint.len); + return 1; + } + if (!s->graph.batch_cur_hc) { + if (errlen) snprintf(err, errlen, "session has no live batch_cur_hc tensor"); + return 1; + } + + const uint64_t per_token = ds4_residual_hc_floats(); + const uint64_t total = n_tokens * per_token; + if (ds4_metal_tensor_write(s->graph.batch_cur_hc, 0, in, + total * sizeof(float)) == 0) { + if (errlen) snprintf(err, errlen, "ds4_metal_tensor_write failed for batch_cur_hc"); + return 1; + } + + /* Build a dummy token vec of size pos_start + n_tokens. Layers in the + * tail range don't read the token IDs (hash-routed FFNs only exist in + * the first DS4_N_HASH_LAYER layers, all owned by the head), but we + * need a vec large enough that metal_graph_prefill_chunked_range's + * bounds check (start + n_tokens <= prompt->len) passes. */ + const uint64_t dummy_len = (uint64_t)pos_start + n_tokens; + if (dummy_len > INT_MAX) { + if (errlen) snprintf(err, errlen, "eval_batch_imported: pos_start+n_tokens overflows int"); + return 1; + } + int *dummy_tokens = (int *)calloc((size_t)dummy_len, sizeof(int)); + if (!dummy_tokens) { + if (errlen) snprintf(err, errlen, "eval_batch_imported: alloc failed"); + return 1; + } + token_vec dummy_prompt = { + .v = dummy_tokens, + .len = (int)dummy_len, + .cap = (int)dummy_len, + }; + + ds4_engine *e = s->engine; + float *logits_dest = want_logits ? s->logits : NULL; + /* Use chunked_range so KV writes go to absolute position pos_start, not + * 0. prefill_layer_major hardcodes pos_start=0; that overwrote chunks + * 2..N on top of chunk 1 in early Phase 5a testing, producing logits + * that only saw the final chunk's worth of context. chunked_range + * with n_tokens <= prefill_cap runs exactly one internal chunk + * iteration. */ + bool ok = metal_graph_prefill_chunked_range(&s->graph, &e->model, &e->weights, + &dummy_prompt, + pos_start, n_tokens, + logits_dest, false, + NULL, NULL, + NULL, NULL); + free(dummy_tokens); + if (!ok) { + if (errlen) snprintf(err, errlen, "Metal prefill failed in eval_batch_imported"); + s->checkpoint_valid = false; + return 1; + } + + /* Advance the tail's notion of checkpoint length so subsequent chunks + * pass the pos_start sanity check above. We don't store the actual + * tokens -- the tail never reasons about them. 
*/ + for (uint64_t i = 0; i < n_tokens; i++) { + token_vec_push(&s->checkpoint, 0); + } + s->checkpoint_valid = true; + return 0; +} + +int ds4_session_verify_batch_imported_hc(ds4_session *s, + const float *batch_residual, + uint64_t n_tokens, uint32_t pos_start, + const uint32_t *expected_next, + uint32_t n_expected, + uint32_t *out_n_accepted, + float *final_logits, uint64_t n_logit_floats, + char *err, size_t errlen) { + if (!s || !batch_residual || n_tokens == 0 || !out_n_accepted || !final_logits) { + if (errlen) snprintf(err, errlen, "verify_batch: null arg or empty batch"); + return 1; + } + *out_n_accepted = 0; + if (n_expected >= n_tokens) { + if (errlen) snprintf(err, errlen, + "verify_batch: n_expected %u must be < n_tokens %llu", + n_expected, (unsigned long long)n_tokens); + return 1; + } + if (n_logit_floats < DS4_N_VOCAB) { + if (errlen) snprintf(err, errlen, "verify_batch: logit buffer too small"); + return 1; + } + if (n_tokens > s->prefill_cap) { + if (errlen) snprintf(err, errlen, + "verify_batch: n_tokens %llu exceeds prefill_cap %u", + (unsigned long long)n_tokens, s->prefill_cap); + return 1; + } + if (pos_start != (uint32_t)s->checkpoint.len) { + if (errlen) snprintf(err, errlen, + "verify_batch: pos_start=%u but session at %d", + pos_start, s->checkpoint.len); + return 1; + } + + ds4_engine *e = s->engine; + + /* Snapshot KV state so we can roll back on any miss. spec_frontier_* + * covers compressor frontiers, mtp_n_raw, and per-layer compressed cache + * state -- everything that would be left stale by a rejected prefill. */ + ds4_spec_frontier frontier; + memset(&frontier, 0, sizeof(frontier)); + if (!spec_frontier_snapshot(&frontier, s)) { + if (errlen) snprintf(err, errlen, "verify_batch: spec_frontier_snapshot failed"); + return 1; + } + + /* Write the imported residuals into batch_cur_hc, then run prefill via + * the chunked-range function so its layer-range gating and embed/output + * skipping all work (we pass NULL for logits and no chunk hook; the + * output head only runs if this engine owns it, which the tail does). 
*/ + const uint64_t per_token = ds4_residual_hc_floats(); + const uint64_t total_floats = (uint64_t)n_tokens * per_token; + if (ds4_metal_tensor_write(s->graph.batch_cur_hc, 0, batch_residual, + total_floats * sizeof(float)) == 0) { + if (errlen) snprintf(err, errlen, + "verify_batch: tensor_write batch_cur_hc failed"); + spec_frontier_restore(&frontier, s); + spec_frontier_free(&frontier); + return 1; + } + + const uint64_t dummy_len = (uint64_t)pos_start + n_tokens; + if (dummy_len > INT_MAX) { + if (errlen) snprintf(err, errlen, "verify_batch: pos+n overflows int"); + spec_frontier_restore(&frontier, s); + spec_frontier_free(&frontier); + return 1; + } + int *dummy_tokens = (int *)calloc((size_t)dummy_len, sizeof(int)); + if (!dummy_tokens) { + if (errlen) snprintf(err, errlen, "verify_batch: alloc failed"); + spec_frontier_restore(&frontier, s); + spec_frontier_free(&frontier); + return 1; + } + token_vec dummy_prompt = { + .v = dummy_tokens, + .len = (int)dummy_len, + .cap = (int)dummy_len, + }; + + bool ok = metal_graph_prefill_chunked_range(&s->graph, &e->model, &e->weights, + &dummy_prompt, + pos_start, (uint32_t)n_tokens, + /* logits = */ NULL, + false, NULL, NULL, NULL, NULL); + free(dummy_tokens); + if (!ok) { + if (errlen) snprintf(err, errlen, "verify_batch: prefill failed"); + spec_frontier_restore(&frontier, s); + spec_frontier_free(&frontier); + return 1; + } + + /* Now batch_cur_hc has post-layer hidden state for n_tokens rows. Run + * the output head per-row to get logits, compute argmax, compare to + * expected_next. Stop at first miss. Row n_expected (i.e., the last + * position) holds the logits we return on full acceptance for the head + * to sample the next-after-accepted token. */ + bool all_match = true; + float *row_logits = (float *)malloc((size_t)DS4_N_VOCAB * sizeof(float)); + if (!row_logits) { + if (errlen) snprintf(err, errlen, "verify_batch: row buffer alloc failed"); + spec_frontier_restore(&frontier, s); + spec_frontier_free(&frontier); + return 1; + } + for (uint32_t i = 0; i < n_expected; i++) { + if (!metal_graph_prefill_batch_row_logits(&s->graph, &e->model, &e->weights, + i, row_logits)) + { + all_match = false; + break; + } + const int top = sample_argmax(row_logits, DS4_N_VOCAB); + if (top != (int)expected_next[i]) { + all_match = false; + break; + } + } + + if (!all_match) { + free(row_logits); + spec_frontier_restore(&frontier, s); + spec_frontier_free(&frontier); + *out_n_accepted = 0; + return 0; + } + + /* Full accept: read the last row's logits (the prediction for the token + * right AFTER all accepted drafts), advance the checkpoint, and free + * the snapshot. 
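+ *
+ * E.g. with 3 drafts in flight the head ships n_tokens = 4 residual rows
+ * (the committed token plus 3 drafts) and n_expected = 3 expected IDs; a
+ * full accept sets *out_n_accepted = 4 and fills final_logits from row 3,
+ * the prediction for the bonus token after the last accepted draft.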
*/ + if (!metal_graph_prefill_batch_row_logits(&s->graph, &e->model, &e->weights, + (uint32_t)(n_tokens - 1u), row_logits)) + { + free(row_logits); + spec_frontier_restore(&frontier, s); + spec_frontier_free(&frontier); + if (errlen) snprintf(err, errlen, "verify_batch: final row logits failed"); + return 1; + } + memcpy(final_logits, row_logits, (size_t)DS4_N_VOCAB * sizeof(float)); + free(row_logits); + + for (uint64_t i = 0; i < n_tokens; i++) { + token_vec_push(&s->checkpoint, 0); + } + s->checkpoint_valid = true; + spec_frontier_free(&frontier); + *out_n_accepted = (uint32_t)n_tokens; + return 0; +} +#endif + +uint64_t ds4_session_payload_bytes(ds4_session *s) { +#ifdef DS4_NO_METAL + (void)s; + return 0; +#else + if (!s || !s->checkpoint_valid) return 0; + const ds4_metal_graph *g = &s->graph; + uint64_t bytes = (uint64_t)DS4_SESSION_PAYLOAD_U32_FIELDS * sizeof(uint32_t); + bytes += (uint64_t)s->checkpoint.len * sizeof(uint32_t); + bytes += (uint64_t)DS4_N_VOCAB * sizeof(float); + bytes += (uint64_t)DS4_N_LAYER * sizeof(uint32_t); + bytes += (uint64_t)DS4_N_LAYER * sizeof(uint32_t); + bytes += session_payload_live_tensor_bytes(g, (uint32_t)s->checkpoint.len); + return bytes; +#endif +} + +int ds4_session_save_payload(ds4_session *s, FILE *fp, char *err, size_t errlen) { +#ifdef DS4_NO_METAL + (void)s; (void)fp; + payload_set_err(err, errlen, "Metal support is not compiled in"); + return 1; +#else + if (!s || !fp || !s->checkpoint_valid) { + payload_set_err(err, errlen, "session has no valid checkpoint to save"); + return 1; + } + if (ds4_metal_synchronize() == 0) { + payload_set_err(err, errlen, "failed to synchronize Metal before snapshot"); + return 1; + } + + ds4_metal_graph *g = &s->graph; + const uint32_t raw_live = session_raw_live_rows(g, (uint32_t)s->checkpoint.len); + /* Header fields: + * 0 magic, 1 version, 2 ctx, 3 prefill chunk, 4 raw cap, + * 5 raw window, 6 compressed cap, 7 token count, + * 8 layers, 9 raw head dim, 10 indexer head dim, 11 vocab, + * 12 live raw rows serialized below. + */ + uint32_t header[DS4_SESSION_PAYLOAD_U32_FIELDS] = { + DS4_SESSION_PAYLOAD_MAGIC, + DS4_SESSION_PAYLOAD_VERSION, + (uint32_t)s->ctx_size, + s->prefill_cap, + g->raw_cap, + g->raw_window, + g->comp_cap, + (uint32_t)s->checkpoint.len, + DS4_N_LAYER, + DS4_N_HEAD_DIM, + DS4_N_INDEXER_HEAD_DIM, + DS4_N_VOCAB, + raw_live, + }; + for (uint32_t i = 0; i < DS4_SESSION_PAYLOAD_U32_FIELDS; i++) { + if (payload_write_u32(fp, header[i], err, errlen) != 0) return 1; + } + for (int i = 0; i < s->checkpoint.len; i++) { + if (payload_write_u32(fp, (uint32_t)s->checkpoint.v[i], err, errlen) != 0) return 1; + } + if (payload_write_bytes(fp, s->logits, (uint64_t)DS4_N_VOCAB * sizeof(float), err, errlen) != 0) return 1; + for (uint32_t il = 0; il < DS4_N_LAYER; il++) { + if (payload_write_u32(fp, g->layer_n_comp[il], err, errlen) != 0) return 1; + } + for (uint32_t il = 0; il < DS4_N_LAYER; il++) { + if (payload_write_u32(fp, g->layer_n_index_comp[il], err, errlen) != 0) return 1; + } + + uint8_t *buf = xmalloc(DS4_SESSION_IO_CHUNK); + int rc = 0; + for (uint32_t il = 0; rc == 0 && il < DS4_N_LAYER; il++) { /* Write the raw ring in logical position order. The file does not care * where the rows happened to live physically in the source graph. 
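+     * Small worked example (numbers illustrative): with checkpoint.len ==
+     * 1000 and raw_live == 512, raw_first == 488, so rows for logical
+     * positions [488, 1000) are emitted in ascending logical order even if
+     * the ring wrapped and position 488 sits mid-buffer physically.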
*/ const uint32_t raw_first = (uint32_t)s->checkpoint.len - raw_live; @@ -15704,6 +16472,23 @@ int ds4_engine_first_token_test(ds4_engine *e, const ds4_tokens *prompt) { return 0; } +/* Open the GGUF just enough to compute the cheap handshake fingerprint: file + * size and the first 32 bytes (the GGUF magic and version, which differ + * across model variants). Both head and worker compute this independently; + * a mismatch fails the handshake with a clear error. */ +static int rpc_compute_model_fingerprint(const char *path, + uint64_t *out_bytes, + uint8_t out_sample[32]) { + struct stat st; + if (stat(path, &st) != 0) return 1; + *out_bytes = (uint64_t)st.st_size; + int fd = open(path, O_RDONLY); + if (fd < 0) return 1; + ssize_t r = read(fd, out_sample, 32); + close(fd); + return r == 32 ? 0 : 1; +} + int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { ds4_engine *e = xcalloc(1, sizeof(*e)); e->model.fd = -1; @@ -15713,6 +16498,20 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { e->mtp_draft_tokens = opt->mtp_draft_tokens > 0 ? opt->mtp_draft_tokens : 1; if (e->mtp_draft_tokens > 16) e->mtp_draft_tokens = 16; e->mtp_margin = opt->mtp_margin >= 0.0f ? opt->mtp_margin : 3.0f; + { + int ls = opt->n_layer_start; + int le = opt->n_layer_end <= 0 ? DS4_N_LAYER : opt->n_layer_end; + if (ls < 0 || ls >= le || le > DS4_N_LAYER) { + fprintf(stderr, + "ds4: invalid layer range [%d, %d); must satisfy 0 <= start < end <= %d\n", + opt->n_layer_start, opt->n_layer_end, DS4_N_LAYER); + free(e); + *out = NULL; + return 1; + } + e->n_layer_start = (uint32_t)ls; + e->n_layer_end = (uint32_t)le; + } if (opt->n_threads > 0) g_requested_threads = (uint32_t)opt->n_threads; ds4_acquire_instance_lock(); @@ -15721,7 +16520,20 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { if (opt->warm_weights) model_warm_weights(&e->model); vocab_load(&e->vocab, &e->model); config_validate_model(&e->model); - weights_bind(&e->weights, &e->model); + { + /* Tail engines need token_embd bound iff they'll run MTP drafting + * (the draft step embeds the predicted token via that tensor). + * Single-host and head engines already own it via n_layer_start==0. */ + const bool mtp_token_embd = + (opt->mtp_path != NULL && opt->mtp_path[0] != '\0'); + weights_bind(&e->weights, &e->model, e->n_layer_start, e->n_layer_end, + mtp_token_embd); + } + if (e->n_layer_start != 0 || e->n_layer_end != DS4_N_LAYER) { + fprintf(stderr, + "ds4: pipeline-parallel weight binding active, layer range [%u, %u) of %u\n", + e->n_layer_start, e->n_layer_end, DS4_N_LAYER); + } if (opt->mtp_path && opt->mtp_path[0]) { model_open(&e->mtp_model, opt->mtp_path, opt->backend == DS4_BACKEND_METAL, true); @@ -15742,18 +16554,116 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { return 1; } ds4_metal_set_quality(e->quality); - if (!ds4_metal_set_model_map_range(e->model.map, - e->model.size, - e->model.tensor_data_pos, - e->model.size - e->model.tensor_data_pos)) - { - fprintf(stderr, - "ds4: Metal failed to map model views; aborting startup. " - "This is commonly caused by insufficient memory or Metal VM budget.\n"); + /* Compute disjoint byte clusters covering the bound tensors and map + * each separately. A single min/max range over Q4 head weights + * spans ~149 GiB because token_embd sits at the end of the file + * while attention/FFN tensors sit near the start -- the union + * includes 70+ GiB of the tail's weights in the middle. 
Clustered + * mapping covers only what this engine actually reads, which is + * what makes pipeline-parallel Q4 fit on 128 GiB at all. */ + ds4_byte_range clusters[DS4_MAX_BYTE_CLUSTERS]; + const int n_clusters = weights_compute_byte_clusters(&e->weights, + clusters, + DS4_MAX_BYTE_CLUSTERS); + if (n_clusters <= 0) { + fprintf(stderr, "ds4: no bound tensors to map; aborting startup\n"); ds4_engine_close(e); *out = NULL; return 1; } + + uint64_t total_map_bytes = 0; + for (int i = 0; i < n_clusters; i++) { + if (clusters[i].offset < e->model.tensor_data_pos) { + /* Should never happen for a well-formed GGUF, but clip + * defensively so map_model_views can't read header bytes. */ + const uint64_t shift = e->model.tensor_data_pos - clusters[i].offset; + if (shift >= clusters[i].bytes) { + clusters[i].bytes = 0; + continue; + } + clusters[i].offset += shift; + clusters[i].bytes -= shift; + } + if (clusters[i].offset + clusters[i].bytes > e->model.size) { + clusters[i].bytes = e->model.size - clusters[i].offset; + } + total_map_bytes += clusters[i].bytes; + } + + if (e->n_layer_start != 0 || e->n_layer_end != DS4_N_LAYER) { + fprintf(stderr, + "ds4: pipeline-parallel mapping for layer range [%u, %u): " + "%d cluster(s), %.2f GiB total\n", + (unsigned)e->n_layer_start, (unsigned)e->n_layer_end, + n_clusters, + (double)total_map_bytes / (1024.0 * 1024.0 * 1024.0)); + for (int i = 0; i < n_clusters; i++) { + fprintf(stderr, + "ds4: cluster %d: [%llu, %llu) = %.2f GiB\n", + i, + (unsigned long long)clusters[i].offset, + (unsigned long long)(clusters[i].offset + clusters[i].bytes), + (double)clusters[i].bytes / (1024.0 * 1024.0 * 1024.0)); + } + } + + /* Hard guard: refuse to wire more bytes than the system has physical + * RAM for. On Apple Silicon the Metal model mapping wires tensors + * into unified memory; mapping ~95% of RAM leaves no headroom for + * the kernel, KV caches, scratch buffers, or anything else. + * Previously kernel-panicked a 128 GB Mac on Q4. */ + { + uint64_t phys_ram = 0; + size_t phys_len = sizeof(phys_ram); + if (sysctlbyname("hw.memsize", &phys_ram, &phys_len, NULL, 0) == 0 && + phys_ram > 0) + { + const uint64_t cap = phys_ram - (phys_ram / 16); /* 93.75% */ + if (total_map_bytes > cap) { + fprintf(stderr, + "ds4: refusing to map %.2f GiB of tensor data on a " + "system with only %.2f GiB physical RAM (cap %.2f GiB).\n" + "ds4: this would wire the model into unified memory " + "and likely kernel-panic macOS.\n" + "ds4: options: (1) --quant q2 fits on 128 GB; " + "(2) run pipeline-parallel with --rpc-peer + " + "--rpc-split so this engine owns only part of " + "the layers.\n", + (double)total_map_bytes / (1024.0 * 1024.0 * 1024.0), + (double)phys_ram / (1024.0 * 1024.0 * 1024.0), + (double)cap / (1024.0 * 1024.0 * 1024.0)); + ds4_engine_close(e); + *out = NULL; + return 1; + } + } + } + + /* Map each cluster separately. ds4_metal_set_model_map_range + * appends views to its internal g_model_views array; the lookup + * path (ds4_metal_wrap_model_range) walks that array to find which + * buffer holds any given tensor. Disjoint mappings are first-class + * here. A redundant residency-clear/reset happens between calls, + * which is wasteful but harmless. 
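+         * Illustrative numbers (not measured): if the head's layer tensors
+         * occupy ~68 GiB near the start of the file and token_embd ~11 GiB
+         * at the end, two clusters wire ~79 GiB total, where the old single
+         * min/max range would have wired ~149 GiB and tripped the RAM guard
+         * below on a 128 GB machine.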
*/ + for (int i = 0; i < n_clusters; i++) { + if (clusters[i].bytes == 0) continue; + if (!ds4_metal_set_model_map_range(e->model.map, + e->model.size, + clusters[i].offset, + clusters[i].bytes)) + { + fprintf(stderr, + "ds4: Metal failed to map cluster %d " + "([%llu, %llu)); aborting startup.\n", + i, + (unsigned long long)clusters[i].offset, + (unsigned long long)(clusters[i].offset + clusters[i].bytes)); + ds4_engine_close(e); + *out = NULL; + return 1; + } + } if (e->mtp_ready && !ds4_metal_set_model_map_range(e->mtp_model.map, e->mtp_model.size, @@ -15778,6 +16688,88 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { } #endif + /* Optional pipeline-parallel head: dial the configured tail worker over + * TCP, exchange a config handshake, and stash the connection so decode + * paths can ship residuals. Failure here aborts startup since proceeding + * without the second machine would silently produce wrong output. */ + if (opt->rpc_peer_host && opt->rpc_peer_host[0]) { + if (e->n_layer_end == DS4_N_LAYER) { + fprintf(stderr, + "ds4: --rpc-peer is set but this engine owns all %u layers; " + "set n_layer_end (e.g. 22) so the peer can own the rest\n", + (unsigned)DS4_N_LAYER); + ds4_engine_close(e); + *out = NULL; + return 1; + } + + uint64_t model_bytes = 0; + uint8_t model_sample[32] = {0}; + if (rpc_compute_model_fingerprint(opt->model_path, &model_bytes, model_sample) != 0) { + fprintf(stderr, "ds4: failed to fingerprint model file for RPC handshake: %s\n", + strerror(errno)); + ds4_engine_close(e); + *out = NULL; + return 1; + } + + if (opt->rpc_ctx_size <= 0) { + fprintf(stderr, + "ds4: --rpc-peer is set but rpc_ctx_size is 0; " + "pass --ctx to the head and propagate it via engine options\n"); + ds4_engine_close(e); + *out = NULL; + return 1; + } + ds4_rpc_config cfg = { + .version = DS4_RPC_VERSION, + .n_layer_total = DS4_N_LAYER, + .n_embd = DS4_N_EMBD, + .n_hc = DS4_N_HC, + .n_vocab = DS4_N_VOCAB, + .routed_quant_bits = (uint32_t)ds4_engine_routed_quant_bits(e), + .tail_layer_start = e->n_layer_end, + .tail_layer_end = DS4_N_LAYER, + .ctx_size = (uint32_t)opt->rpc_ctx_size, + .model_file_bytes = model_bytes, + }; + memcpy(cfg.model_sample, model_sample, 32); + + const uint16_t port = (opt->rpc_peer_port > 0 && opt->rpc_peer_port < 65536) + ? (uint16_t)opt->rpc_peer_port : 46434u; + + char rpc_err[512] = {0}; + if (ds4_rpc_dial(opt->rpc_peer_host, port, &e->rpc_peer, + rpc_err, sizeof(rpc_err)) != 0) { + fprintf(stderr, "ds4: rpc dial %s:%u failed: %s\n", + opt->rpc_peer_host, (unsigned)port, rpc_err); + ds4_engine_close(e); + *out = NULL; + return 1; + } + ds4_rpc_config peer_cfg = {0}; + if (ds4_rpc_handshake_client_peer(e->rpc_peer, &cfg, &peer_cfg, + rpc_err, sizeof(rpc_err)) != 0) { + fprintf(stderr, "ds4: rpc handshake with %s:%u failed: %s\n", + opt->rpc_peer_host, (unsigned)port, rpc_err); + ds4_engine_close(e); + *out = NULL; + return 1; + } + e->rpc_tail_has_mtp = peer_cfg.tail_has_mtp != 0; + e->rpc_tail_mtp_drafts = peer_cfg.tail_mtp_draft_tokens; + /* If the peer offers MTP, adopt its draft count as our effective + * mtp_draft_tokens (the head uses this in speculative_argmax). */ + if (e->rpc_tail_has_mtp && e->mtp_draft_tokens < (int)e->rpc_tail_mtp_drafts) { + e->mtp_draft_tokens = (int)e->rpc_tail_mtp_drafts; + } + fprintf(stderr, + "ds4: pipeline-parallel head connected to %s:%u, peer owns layers [%u, %u)%s\n", + opt->rpc_peer_host, (unsigned)port, + (unsigned)e->n_layer_end, (unsigned)DS4_N_LAYER, + e->rpc_tail_has_mtp ? 
" (MTP available)" : ""); + } + *out = e; return 0; } @@ -15788,6 +16780,11 @@ void ds4_engine_summary(ds4_engine *e) { void ds4_engine_close(ds4_engine *e) { if (!e) return; + if (e->rpc_peer) { + (void)ds4_rpc_shutdown_send(e->rpc_peer); + ds4_rpc_close(e->rpc_peer); + e->rpc_peer = NULL; + } weights_free(&e->weights); vocab_free(&e->vocab); ds4_threads_shutdown(); @@ -15814,8 +16811,17 @@ int ds4_session_create(ds4_session **out, ds4_engine *e, int ctx_size) { s->ctx_size = ctx_size; s->prefill_cap = metal_graph_prefill_cap_for_prompt(ctx_size); const uint32_t raw_cap = metal_graph_raw_cap_for_context(ctx_size, s->prefill_cap); - if (!metal_graph_alloc_raw_cap(&s->graph, &e->weights, &e->weights.layer[0], - raw_cap, (uint32_t)ctx_size, s->prefill_cap, e->mtp_ready)) + /* Use the first OWNED layer as the dimension probe. A pipeline-parallel + * tail engine has layer[0] zeroed; passing it segfaults at the first + * `layer->attn_q_a->dim[1]` dereference. */ + const uint32_t probe_il = e->weights.n_layer_start < DS4_N_LAYER + ? e->weights.n_layer_start : 0; + /* Head under RPC needs spec_* buffers to snapshot/restore the residual + * stream around a batched verify, even though it doesn't hold the MTP + * weights itself — the tail does. */ + const bool need_mtp_buffers = e->mtp_ready || e->rpc_tail_has_mtp; + if (!metal_graph_alloc_raw_cap(&s->graph, &e->weights, &e->weights.layer[probe_il], + raw_cap, (uint32_t)ctx_size, s->prefill_cap, need_mtp_buffers)) { free(s); return 1; @@ -15826,19 +16832,36 @@ int ds4_session_create(ds4_session **out, ds4_engine *e, int ctx_size) { s->mtp_logits = xmalloc((size_t)DS4_N_VOCAB * sizeof(s->mtp_logits[0])); s->mtp_draft_token = -1; } + if (e->rpc_peer) { + s->rpc_residual_scratch = xmalloc((size_t)ds4_residual_hc_floats() * + sizeof(s->rpc_residual_scratch[0])); + } *out = s; return 0; #endif } +#ifndef DS4_NO_METAL +static int rpc_spec_abort(ds4_session *s, char *err, size_t errlen); +#endif + void ds4_session_free(ds4_session *s) { if (!s) return; #ifndef DS4_NO_METAL + /* Drain any Phase 6 in-flight speculative reply before tearing down, + * otherwise the next session created on the same RPC connection will + * see the stale reply as its first frame. We can't rewind the tail + * here since the session is going away; abort() does both anyway. */ + if (s->rpc_spec_in_flight) { + char err[256] = {0}; + (void)rpc_spec_abort(s, err, sizeof(err)); + } metal_graph_free(&s->graph); #endif token_vec_free(&s->checkpoint); free(s->logits); free(s->mtp_logits); + free(s->rpc_residual_scratch); free(s); } @@ -15882,6 +16905,126 @@ static void ds4_session_note_prefill_progress(void *ud, const char *event, int c * * A non-matching prompt discards the checkpoint and prefills from token zero. */ +static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, + char *err, size_t errlen); +static void rpc_propagate_reset(ds4_session *s, const char *origin); + +#ifndef DS4_NO_METAL +/* Per-chunk hook for pipeline-parallel prefill on the head. Reads the + * batch_cur_hc tensor for the just-finished chunk, ships it to the tail + * via OP_PREFILL_REQ, and (on the last chunk) writes the returned logits + * into the session's logits buffer. Returns 0 on success, 1 on error so + * metal_graph_prefill_chunked_range can abort cleanly. 
*/ +typedef struct { + ds4_session *session; + float *scratch; /* batch_cur_hc readback buffer */ + uint64_t scratch_floats; /* capacity of scratch in floats */ + int rc; /* preserved error code across chunks */ + char *err; + size_t errlen; +} rpc_chunk_state; + +static int rpc_chunk_hook(void *user, ds4_metal_graph *g, + uint32_t pos0, uint32_t n_chunk_tokens, + bool is_last_chunk) { + rpc_chunk_state *st = (rpc_chunk_state *)user; + ds4_session *s = st->session; + ds4_engine *e = s->engine; + if (!e->rpc_peer) { + if (st->errlen) snprintf(st->err, st->errlen, + "rpc chunk hook fired without an attached peer"); + st->rc = 1; + return 1; + } + const uint64_t per_token = ds4_residual_hc_floats(); + const uint64_t chunk_floats = (uint64_t)n_chunk_tokens * per_token; + if (chunk_floats > st->scratch_floats) { + if (st->errlen) snprintf(st->err, st->errlen, + "rpc chunk hook: chunk %u tokens exceeds scratch %llu floats", + n_chunk_tokens, (unsigned long long)st->scratch_floats); + st->rc = 1; + return 1; + } + if (ds4_metal_tensor_read(g->batch_cur_hc, 0, st->scratch, + chunk_floats * sizeof(float)) == 0) { + if (st->errlen) snprintf(st->err, st->errlen, + "rpc chunk hook: tensor_read batch_cur_hc failed"); + st->rc = 1; + return 1; + } + + float *out_logits = is_last_chunk ? s->logits : NULL; + const uint64_t out_floats = is_last_chunk ? (uint64_t)DS4_N_VOCAB : 0u; + + char rpc_err[512] = {0}; + if (ds4_rpc_prefill_request(e->rpc_peer, + n_chunk_tokens, pos0, is_last_chunk, + st->scratch, chunk_floats, + out_logits, out_floats, + rpc_err, sizeof(rpc_err)) != 0) { + /* Drop the dead peer so subsequent ops fail fast. See same pattern + * in ds4_session_eval_internal's decode path. */ + ds4_rpc_close(e->rpc_peer); + e->rpc_peer = NULL; + if (st->errlen) snprintf(st->err, st->errlen, + "rpc prefill chunk at pos=%u: %s " + "(connection dropped)", pos0, rpc_err); + st->rc = 1; + return 1; + } + return 0; +} + +/* Top-level RPC-aware chunked prefill. Resets the tail's session state + * first, then drives metal_graph_prefill_chunked_range with the RPC hook + * so each chunk's residual streams across as it's produced. The whole + * prompt is processed in chunks of g->prefill_cap tokens. */ +static int rpc_batched_prefill(ds4_session *s, const token_vec *prompt, + uint32_t start, uint32_t n_tokens, + char *err, size_t errlen) { + ds4_engine *e = s->engine; + if (!e->rpc_peer) { + snprintf(err, errlen, "rpc_batched_prefill: no peer attached"); + return 1; + } + if (n_tokens == 0) return 0; + + const uint64_t per_token = ds4_residual_hc_floats(); + const uint64_t scratch_floats = (uint64_t)s->graph.prefill_cap * per_token; + float *scratch = (float *)malloc((size_t)(scratch_floats * sizeof(float))); + if (!scratch) { + snprintf(err, errlen, "rpc_batched_prefill: scratch alloc (%llu floats) failed", + (unsigned long long)scratch_floats); + return 1; + } + + rpc_chunk_state st = { + .session = s, + .scratch = scratch, + .scratch_floats = scratch_floats, + .rc = 0, + .err = err, + .errlen = errlen, + }; + + bool ok = metal_graph_prefill_chunked_range(&s->graph, &e->model, &e->weights, + prompt, start, n_tokens, + /* logits */ NULL, + /* show_progress */ false, + /* progress */ NULL, NULL, + rpc_chunk_hook, &st); + free(scratch); + if (!ok) { + if (st.rc == 0 && errlen) { + snprintf(err, errlen, + "rpc_batched_prefill: head prefill failed before/after RPC ship"); + } + return st.rc != 0 ? 
st.rc : 1; + } + return 0; +} +#endif + int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t errlen) { #ifdef DS4_NO_METAL (void)s; @@ -15895,14 +17038,33 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t return 1; } + /* Pipeline-parallel prefill ships batch_cur_hc to the tail per chunk + * via OP_PREFILL_REQ. For small extensions (the suffix-extension path + * below) it is sometimes cheaper to fall back to per-token decode RPCs + * than to spin up batch machinery, but that thresholding is a future + * polish item; for now any RPC engine uses the batched path. */ + const bool use_batched_rpc = (e->rpc_peer != NULL); + if (s->checkpoint_valid && prompt->len >= s->checkpoint.len && ds4_tokens_starts_with(prompt, &s->checkpoint)) { s->mtp_draft_valid = false; const int suffix = prompt->len - s->checkpoint.len; + if (use_batched_rpc && suffix > 0) { + if (rpc_batched_prefill(s, prompt, + (uint32_t)s->checkpoint.len, + (uint32_t)suffix, + err, errlen) != 0) { + s->checkpoint_valid = false; + return 1; + } + ds4_tokens_copy(&s->checkpoint, prompt); + s->checkpoint_valid = true; + return 0; + } const uint32_t resume_min = metal_graph_resume_prefill_min_tokens(); - if (suffix > 0 && (uint32_t)suffix >= resume_min) { + if (!use_batched_rpc && suffix > 0 && (uint32_t)suffix >= resume_min) { ds4_sync_progress progress = { .session = s, .prompt = prompt, @@ -15920,7 +17082,8 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t s->logits, false, progress_fn, - progress_fn ? &progress : NULL); + progress_fn ? &progress : NULL, + NULL, NULL); if (!ok) { snprintf(err, errlen, "Metal resumed prefill failed while extending checkpoint"); s->checkpoint_valid = false; @@ -15943,6 +17106,24 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t } token_vec_push(&s->checkpoint, prompt->v[i]); } + s->checkpoint_valid = true; + return 0; + } + + if (use_batched_rpc) { + /* Cold path under RPC: reset tail KV (so its checkpoint counter + * starts at 0, matching pos_start=0 below) and drive the head + * through chunked prefill with the RPC hook shipping each chunk. */ + rpc_propagate_reset(s, "sync cold path"); + s->checkpoint.len = 0; + if (rpc_batched_prefill(s, prompt, 0, (uint32_t)prompt->len, + err, errlen) != 0) { + s->checkpoint_valid = false; + return 1; + } + ds4_tokens_copy(&s->checkpoint, prompt); + s->checkpoint_valid = true; + s->mtp_draft_valid = false; return 0; } @@ -16097,6 +17278,56 @@ int ds4_session_top_logprobs(ds4_session *s, ds4_token_score *out, int k) { return k; } +#ifndef DS4_NO_METAL +/* Phase 6 head-side speculative prefetch: drain a stale in-flight reply, + * rewind the tail past the speculative token, and restore head KV from the + * snapshot taken before the speculative L0-L20. Used on mispredict and on + * any session boundary (invalidate / rewind / shutdown) that needs to clear + * the speculative state before doing other RPC work. + * + * Returns 0 on success. On failure the connection is closed and + * `rpc_peer` is NULL-ed; rpc_spec_in_flight is always cleared. 
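+ * Boundary-call sketch (the same pattern the session teardown, invalidate,
+ * and rewind paths in this file use):
+ *
+ *     if (s->rpc_spec_in_flight) {
+ *         char err[256] = {0};
+ *         (void)rpc_spec_abort(s, err, sizeof(err));  // best-effort drain
+ *     }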
*/ +static int rpc_spec_abort(ds4_session *s, char *err, size_t errlen) { + if (!s || !s->rpc_spec_in_flight) return 0; + ds4_engine *e = s->engine; + if (!e || !e->rpc_peer) { + spec_frontier_free(&s->rpc_spec_snapshot); + s->rpc_spec_in_flight = false; + return 0; + } + s->rpc_spec_miss++; + char rpc_err[256] = {0}; + int rc = 0; + if (ds4_rpc_decode_recv_reply(e->rpc_peer, + s->logits, DS4_N_VOCAB, + NULL, 0, NULL, + rpc_err, sizeof(rpc_err)) != 0) { + if (err && errlen) { + snprintf(err, errlen, "rpc-spec: drain reply: %s", rpc_err); + } + ds4_rpc_close(e->rpc_peer); + e->rpc_peer = NULL; + rc = 1; + } else if (ds4_rpc_rewind(e->rpc_peer, s->rpc_spec_pos, + rpc_err, sizeof(rpc_err)) != 0) { + if (err && errlen) { + snprintf(err, errlen, "rpc-spec: rewind tail: %s", rpc_err); + } + ds4_rpc_close(e->rpc_peer); + e->rpc_peer = NULL; + rc = 1; + } else if (!spec_frontier_restore(&s->rpc_spec_snapshot, s)) { + if (err && errlen) { + snprintf(err, errlen, "rpc-spec: restore head KV failed"); + } + rc = 1; + } + spec_frontier_free(&s->rpc_spec_snapshot); + s->rpc_spec_in_flight = false; + return rc; +} +#endif + static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, char *err, size_t errlen) { #ifdef DS4_NO_METAL @@ -16108,8 +17339,10 @@ static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, #else ds4_engine *e = s->engine; const bool mtp_probe_log = getenv("DS4_MTP_PROBE") != NULL; + /* MTP is disabled under RPC for now: the draft path spans all layers and + * we have not extended the wire protocol to cooperate. Phase 4 follow-up. */ const bool mtp_should_draft = - probe_mtp && e->mtp_ready && s->mtp_logits && + probe_mtp && e->mtp_ready && s->mtp_logits && e->rpc_peer == NULL && (e->mtp_draft_tokens > 1 || mtp_probe_log); if (probe_mtp && s->mtp_draft_valid) { if (mtp_probe_log) { @@ -16124,16 +17357,220 @@ static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, } s->mtp_draft_valid = false; } - if (!metal_graph_eval_token_raw_swa(&s->graph, &e->model, &e->weights, - (uint32_t)token, - (uint32_t)s->checkpoint.len, - s->logits)) - { - snprintf(err, errlen, "Metal decode failed"); - s->checkpoint_valid = false; - return 1; + /* Single-host path: head owns all layers, so just eval and produce + * logits locally. Under RPC the eval is interleaved with ship/recv and + * gated by the prefetch hit/miss state below. */ + if (!e->rpc_peer) { + if (!metal_graph_eval_token_raw_swa(&s->graph, &e->model, &e->weights, + (uint32_t)token, + (uint32_t)s->checkpoint.len, + s->logits)) + { + snprintf(err, errlen, "Metal decode failed"); + s->checkpoint_valid = false; + return 1; + } + } else { + /* Pipeline-parallel head: ship the residual to the tail worker and + * let it run the remaining layers + output head; the returned logits + * land in s->logits exactly as if we had run the full graph locally. + * + * Phase 6 prefetch (DS4_RPC_PREFETCH=1): + * - hit: prev call already shipped this token speculatively; + * head KV is already advanced -- skip L0-L20 and just + * collect the pending reply. + * - miss: prev call shipped a different token speculatively; + * drain that reply, rewind tail, restore head KV, then + * run the normal sync ship/recv. + * - none: no spec in flight; normal sync ship/recv. 
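+         * Example cycle under DS4_RPC_PREFETCH=1 (token values illustrative):
+         *   eval(42) returns logits plus drafts {17, ...}; the head
+         *   speculatively runs its layers on 17 at pos+1 and ships the
+         *   request. If the caller then commits 17, the next eval(17) is a
+         *   hit: skip local layer work, just collect the pending reply. If
+         *   it commits 99 instead, eval(99) drains the stale reply, rewinds
+         *   the tail, restores head KV, and runs the normal sync path.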
*/ + const uint64_t n_residual = ds4_residual_hc_floats(); + float *residual = s->rpc_residual_scratch; + if (!residual) { + snprintf(err, errlen, "rpc: residual scratch missing on session"); + s->checkpoint_valid = false; + return 1; + } + const bool ask_drafts = probe_mtp && e->rpc_tail_has_mtp && + e->mtp_draft_tokens > 0; + const bool prefetch_on = ask_drafts && getenv("DS4_RPC_PREFETCH") != NULL; + const bool spec_debug = getenv("DS4_RPC_SPEC_DEBUG") != NULL; + char rpc_err[512] = {0}; + uint32_t drafts_buf[DS4_RPC_MAX_DRAFTS] = {0}; + uint32_t n_drafts_returned = 0; + + const bool spec_hit = s->rpc_spec_in_flight && + s->rpc_spec_predicted_token == (uint32_t)token && + s->rpc_spec_pos == (uint32_t)s->checkpoint.len; + + if (spec_debug) { + fprintf(stderr, "ds4-spec: eval enter token=%d pos=%d in_flight=%d pred=%u spec_pos=%u hit=%d cooldown=%u\n", + token, s->checkpoint.len, (int)s->rpc_spec_in_flight, + s->rpc_spec_predicted_token, s->rpc_spec_pos, (int)spec_hit, + s->rpc_spec_cooldown); + } + + /* Phase 6.7 adaptive: record this cycle's outcome (only for cycles + * that had a spec in flight to evaluate). When the window fills, + * check the hit rate; if too low, enter cooldown. */ + if (s->rpc_spec_in_flight) { + s->rpc_spec_history = (s->rpc_spec_history << 1) | (spec_hit ? 1u : 0u); + if (s->rpc_spec_attempts < 32u) s->rpc_spec_attempts++; + if (s->rpc_spec_attempts == 32u) { + const uint32_t hits = (uint32_t)__builtin_popcount(s->rpc_spec_history); + if (hits < 16u) { + /* < 50% hit rate over 32 cycles: cool off for 32 cycles. */ + s->rpc_spec_cooldown = 32u; + s->rpc_spec_history = 0; + s->rpc_spec_attempts = 0; + if (spec_debug) { + fprintf(stderr, "ds4-spec: adaptive: hit rate %u/32 below threshold, " + "entering cooldown for 32 cycles\n", hits); + } + } + } + } + if (s->rpc_spec_cooldown > 0) s->rpc_spec_cooldown--; + + if (spec_hit) { + /* HIT: head KV already advanced through this token in the prev + * eval's speculative L0-L20. Just collect the spec reply. */ + s->rpc_spec_hit++; + if (spec_debug) fprintf(stderr, "ds4-spec: HIT path -- recv_reply\n"); + if (ds4_rpc_decode_recv_reply(e->rpc_peer, + s->logits, DS4_N_VOCAB, + ask_drafts ? drafts_buf : NULL, + ask_drafts ? DS4_RPC_MAX_DRAFTS : 0u, + ask_drafts ? &n_drafts_returned : NULL, + rpc_err, sizeof(rpc_err)) != 0) { + ds4_rpc_close(e->rpc_peer); + e->rpc_peer = NULL; + spec_frontier_free(&s->rpc_spec_snapshot); + s->rpc_spec_in_flight = false; + snprintf(err, errlen, "rpc-spec: hit recv_reply: %s " + "(connection dropped)", rpc_err); + s->checkpoint_valid = false; + return 1; + } + spec_frontier_free(&s->rpc_spec_snapshot); + s->rpc_spec_in_flight = false; + } else { + /* MISS path: abort any in-flight spec, then sync ship/recv. 
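+             * A miss is expected whenever the committed token differs from
+             * the speculated drafts[0] -- e.g. sampling at temperature > 0
+             * picked a non-argmax token -- and costs one drained reply plus
+             * one tail rewind on top of the normal sync round trip.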
*/ + if (spec_debug) fprintf(stderr, "ds4-spec: MISS/NORMAL path; in_flight=%d\n", + (int)s->rpc_spec_in_flight); + if (s->rpc_spec_in_flight) { + char abort_err[256] = {0}; + if (rpc_spec_abort(s, abort_err, sizeof(abort_err)) != 0) { + snprintf(err, errlen, "rpc-spec: abort on miss: %s", + abort_err); + s->checkpoint_valid = false; + return 1; + } + } + if (!metal_graph_eval_token_raw_swa(&s->graph, &e->model, &e->weights, + (uint32_t)token, + (uint32_t)s->checkpoint.len, + s->logits)) + { + snprintf(err, errlen, "Metal decode failed"); + s->checkpoint_valid = false; + return 1; + } + if (ds4_session_export_residual_hc(s, residual, n_residual, + rpc_err, sizeof(rpc_err)) != 0) { + snprintf(err, errlen, "rpc: export residual: %s", rpc_err); + s->checkpoint_valid = false; + return 1; + } + if (ds4_rpc_decode_request(e->rpc_peer, + (uint32_t)token, + (uint32_t)s->checkpoint.len, + ask_drafts, + residual, n_residual, + s->logits, DS4_N_VOCAB, + ask_drafts ? drafts_buf : NULL, + ask_drafts ? DS4_RPC_MAX_DRAFTS : 0u, + ask_drafts ? &n_drafts_returned : NULL, + rpc_err, sizeof(rpc_err)) != 0) { + ds4_rpc_close(e->rpc_peer); + e->rpc_peer = NULL; + snprintf(err, errlen, "rpc: decode request to tail: %s " + "(connection dropped)", rpc_err); + s->checkpoint_valid = false; + return 1; + } + } + + /* Store drafts (same shape as before; speculative_argmax reads + * them). Under prefetch, the batched-verify path in + * speculative_argmax must be disabled or it will replay the spec + * work — handled in Phase 6.4. */ + if (ask_drafts && n_drafts_returned > 0) { + s->mtp_draft_token = (int)drafts_buf[0]; + s->mtp_draft_valid = true; + s->rpc_n_extra_drafts = n_drafts_returned > 1 ? n_drafts_returned - 1 : 0; + for (uint32_t i = 1; i < n_drafts_returned && i - 1 < DS4_RPC_MAX_DRAFTS; i++) { + s->rpc_extra_drafts[i - 1] = drafts_buf[i]; + } + } + + /* Phase 6 prefetch START: if we just got drafts and don't already + * have a spec in flight, snapshot head KV, run L0-L20 on drafts[0] + * speculatively, and ship the request. The next eval call will + * either hit (token matches drafts[0]) or miss (abort + sync). 
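+         * Enabling this is a head-side env toggle, e.g. (host name
+         * illustrative; 46434 is the default worker port):
+         *
+         *     DS4_RPC_PREFETCH=1 ./ds4 --rpc-peer tail-host:46434 --rpc-split 22 ...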
*/ + const bool adaptive_block = (s->rpc_spec_cooldown > 0); + if (prefetch_on && adaptive_block && spec_debug) { + fprintf(stderr, "ds4-spec: adaptive: skip prefetch start (cooldown=%u)\n", + s->rpc_spec_cooldown); + } + if (prefetch_on && !adaptive_block && !s->rpc_spec_in_flight && n_drafts_returned > 0) { + const uint32_t pred = drafts_buf[0]; + const uint32_t spec_pos = (uint32_t)(s->checkpoint.len + 1); + if (spec_debug) { + fprintf(stderr, "ds4-spec: prefetch start pred=%u spec_pos=%u ctx_size=%d\n", + pred, spec_pos, s->ctx_size); + } + if (spec_pos >= (uint32_t)s->ctx_size) { + if (spec_debug) fprintf(stderr, "ds4-spec: skip prefetch (ctx full)\n"); + } else if (!spec_frontier_snapshot(&s->rpc_spec_snapshot, s)) { + if (spec_debug) fprintf(stderr, "ds4-spec: snapshot FAILED\n"); + } else { + if (spec_debug) fprintf(stderr, "ds4-spec: snapshot ok; running spec L0-L20\n"); + bool ok = metal_graph_eval_token_raw_swa(&s->graph, &e->model, + &e->weights, + pred, spec_pos, + s->logits); + if (spec_debug) fprintf(stderr, "ds4-spec: spec eval ok=%d\n", (int)ok); + if (ok) { + ok = ds4_session_export_residual_hc(s, residual, n_residual, + rpc_err, sizeof(rpc_err)) == 0; + if (spec_debug) fprintf(stderr, "ds4-spec: export ok=%d err=%s\n", (int)ok, rpc_err); + } + if (ok) { + ok = ds4_rpc_decode_send(e->rpc_peer, pred, spec_pos, + true, + residual, n_residual, + rpc_err, sizeof(rpc_err)) == 0; + if (spec_debug) fprintf(stderr, "ds4-spec: send ok=%d err=%s\n", (int)ok, rpc_err); + } + if (ok) { + s->rpc_spec_in_flight = true; + s->rpc_spec_predicted_token = pred; + s->rpc_spec_pos = spec_pos; + if (spec_debug) fprintf(stderr, "ds4-spec: prefetch armed in_flight=1\n"); + } else { + if (spec_debug) fprintf(stderr, "ds4-spec: prefetch FAILED -- restoring\n"); + (void)spec_frontier_restore(&s->rpc_spec_snapshot, s); + spec_frontier_free(&s->rpc_spec_snapshot); + } + } + } + if (spec_debug) fprintf(stderr, "ds4-spec: end of rpc_peer block (about to push)\n"); } + token_vec_push(&s->checkpoint, token); + if (e && e->rpc_peer && getenv("DS4_RPC_SPEC_DEBUG")) { + fprintf(stderr, "ds4-spec: pushed token=%d new_len=%d\n", token, s->checkpoint.len); + } if (mtp_should_draft) { int mtp_top = -1; if (metal_graph_eval_mtp_draft(&s->graph, @@ -16151,6 +17588,10 @@ static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, fprintf(stderr, "ds4: mtp probe draft failed\n"); } } + if (e && e->rpc_peer && getenv("DS4_RPC_SPEC_DEBUG")) { + fprintf(stderr, "ds4-spec: eval_internal returning 0 (len=%d in_flight=%d)\n", + s->checkpoint.len, (int)s->rpc_spec_in_flight); + } return 0; #endif } @@ -16159,6 +17600,78 @@ int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen) { return ds4_session_eval_internal(s, token, true, err, errlen); } +int ds4_session_eval_no_draft(ds4_session *s, int token, char *err, size_t errlen) { + return ds4_session_eval_internal(s, token, false, err, errlen); +} + +#ifndef DS4_NO_METAL +int ds4_session_mtp_drafts_after_eval(ds4_session *s, + uint32_t *out_drafts, int max_drafts, + char *err, size_t errlen) { + if (!s || !out_drafts || max_drafts <= 0) { + if (errlen) snprintf(err, errlen, "mtp_drafts: null arg or zero max"); + return 0; + } + ds4_engine *e = s->engine; + if (!e->mtp_ready || !s->mtp_logits) { + if (errlen) snprintf(err, errlen, "mtp_drafts: MTP not loaded on this engine"); + return 0; + } + if (!s->mtp_draft_valid) { + /* No pending draft -- prior eval skipped MTP probing. 
*/ + return 0; + } + + int draft_n = 0; + out_drafts[draft_n++] = (uint32_t)s->mtp_draft_token; + s->mtp_draft_valid = false; + + int cap = e->mtp_draft_tokens; + if (cap > max_drafts) cap = max_drafts; + if (cap > 16) cap = 16; + const int room = s->ctx_size - s->checkpoint.len; + if (cap > room - 1) cap = room - 1; + if (cap <= 1) return draft_n; + + /* Record the cache base so a later trim can compute the new mtp_n_raw + * from "drafts in this round" minus "drafts accepted". */ + s->graph.mtp_draft_round_base_raw = s->graph.mtp_n_raw - 1u; /* the eval just added 1 */ + + for (; draft_n < cap; draft_n++) { + ds4_metal_tensor *prev_hc = (draft_n & 1) ? s->graph.mtp_state_hc : s->graph.mtp_next_hc; + ds4_metal_tensor *out_hc = (draft_n & 1) ? s->graph.mtp_next_hc : s->graph.mtp_state_hc; + int mtp_top = -1; + if (!metal_graph_eval_mtp_draft_from_hc(&s->graph, + &e->model, &e->weights, + &e->mtp_model, &e->mtp_weights, + prev_hc, out_hc, + (int)out_drafts[draft_n - 1], + (uint32_t)(s->checkpoint.len + draft_n - 1), + /* logits = */ NULL, + &mtp_top)) + { + break; + } + out_drafts[draft_n] = (uint32_t)(mtp_top >= 0 ? mtp_top : 0); + } + s->graph.mtp_draft_round_n = (uint32_t)draft_n; + return draft_n; +} + +void ds4_session_mtp_trim_drafts(ds4_session *s, uint32_t keep_drafts) { + if (!s || !s->engine->mtp_ready) return; + /* The worker captured mtp_n_raw before drafting (mtp_draft_round_base_raw) + * and the count of drafts produced (mtp_draft_round_n). Truncate the + * cache to base + keep, clamped to round_n so callers can't grow it. */ + uint32_t keep = keep_drafts; + if (keep > s->graph.mtp_draft_round_n) keep = s->graph.mtp_draft_round_n; + uint32_t new_n_raw = s->graph.mtp_draft_round_base_raw + keep; + if (new_n_raw > s->graph.raw_window) new_n_raw = s->graph.raw_window; + s->graph.mtp_n_raw = new_n_raw; + s->mtp_draft_valid = false; +} +#endif + /* Speculative decode state machine: * 1. commit the normal target token and use its logits to validate draft[0]; * 2. let MTP recursively draft a tiny suffix from its own raw-cache frontier; @@ -16187,12 +17700,218 @@ int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token, * several proposed positions together; running ordinary decode once per * draft token is correctness-safe but cannot be faster than baseline. */ + const bool _sa_dbg = getenv("DS4_RPC_SPEC_DEBUG") != NULL && e && e->rpc_peer; + if (_sa_dbg) fprintf(stderr, "ds4-spec: SA: before eval first_token=%d\n", first_token); if (ds4_session_eval(s, first_token, err, errlen) != 0) return -1; + if (_sa_dbg) fprintf(stderr, "ds4-spec: SA: after eval; e=%p mtp_ready=%d mtp_draft_valid=%d mtp_draft_tokens=%d\n", + (void*)e, (int)e->mtp_ready, (int)s->mtp_draft_valid, e->mtp_draft_tokens); int n_accept = 0; accepted[n_accept++] = first_token; + if (_sa_dbg) fprintf(stderr, "ds4-spec: SA: pushed accepted[0]=%d; checking early returns\n", first_token); if (first_token == eos_token || max_tokens == 1 || n_accept >= accepted_cap) return n_accept; + if (_sa_dbg) fprintf(stderr, "ds4-spec: SA: past early returns\n"); + + /* Pipeline-parallel head: tail produces MTP drafts post-decode; we + * learned them via DECODE_REPLY into s->mtp_draft_token (draft 0) and + * s->rpc_extra_drafts (drafts 1..N-1). + * + * Verification is all-or-nothing under RPC: we check drafts[0] locally + * for free against s->logits (the just-evaluated target prediction); + * if that misses we commit only first_token. 
If drafts[0] hits we + * batch-verify drafts[0..N-1] in a single VERIFY_BATCH RPC: head runs + * its layers' prefill for those N tokens, exports the resulting + * batch_cur_hc, snapshots its KV state, ships the batch to the tail + * which runs its own prefill + per-row argmax against expected drafts. + * Tail returns accepted = N (full) or 0 (any miss, KV restored). Head + * mirrors: full accept commits all drafts and adopts the returned + * logits as s->logits; miss restores its KV snapshot and falls through + * to commit only first_token. Verification of the cycle takes one RPC + * roundtrip instead of N, and the tail amortizes layer setup across + * the batch -- the same speedup pattern as the single-host batched + * verifiers. */ + /* Phase 6 prefetch already consumed the prediction in + * ds4_session_eval_internal (it ran L0-L20 on drafts[0] and shipped + * before this function was even called). Running the batched-verify + * path here would re-snapshot the same spec_* buffers and re-do the + * head's L0-L20 work, producing redundant load and clobbering the + * pending in-flight reply. Disable batched-verify when prefetch is + * on; the prefetch hit/miss path provides the speedup. */ + const bool rpc_prefetch_on = (e->rpc_peer && e->rpc_tail_has_mtp && + getenv("DS4_RPC_PREFETCH") != NULL); + if (!rpc_prefetch_on && + e->rpc_peer && e->rpc_tail_has_mtp && s->mtp_draft_valid && + e->mtp_draft_tokens > 1) + { + int draft_n = (int)s->rpc_n_extra_drafts + 1; /* +1 for draft 0 */ + if (draft_n > 16) draft_n = 16; + int draft_cap = e->mtp_draft_tokens; + if (draft_cap > max_tokens - n_accept) draft_cap = max_tokens - n_accept; + if (draft_cap > accepted_cap - n_accept) draft_cap = accepted_cap - n_accept; + const int room = s->ctx_size - s->checkpoint.len; + if (draft_cap > room) draft_cap = room; + if (draft_n > draft_cap) draft_n = draft_cap; + + int drafts[16]; + drafts[0] = s->mtp_draft_token; + for (int i = 1; i < draft_n; i++) { + drafts[i] = (int)s->rpc_extra_drafts[i - 1]; + } + s->mtp_draft_valid = false; - if (!e->mtp_ready || !s->mtp_draft_valid || e->mtp_draft_tokens <= 1) return n_accept; + const bool spec_log = getenv("DS4_MTP_SPEC_LOG") != NULL; + + /* Free local verification of drafts[0]: argmax of the prediction we + * already have. Cheap miss path -- no RPC needed. */ + if (sample_argmax(s->logits, DS4_N_VOCAB) != drafts[0]) { + if (spec_log) { + fprintf(stderr, "ds4: rpc-mtp miss draft0=%d, no batch sent\n", drafts[0]); + } + char trim_err[160] = {0}; + (void)ds4_rpc_mtp_trim(e->rpc_peer, 0, trim_err, sizeof(trim_err)); + return n_accept; + } + + /* Snapshot head's KV state before running drafts through head's + * layers; on any tail miss we restore and bail out. */ + ds4_spec_frontier head_frontier; + memset(&head_frontier, 0, sizeof(head_frontier)); + if (!spec_frontier_snapshot(&head_frontier, s)) { + snprintf(err, errlen, "rpc-mtp: head spec_frontier_snapshot failed"); + return -1; + } + const uint32_t pre_pos = (uint32_t)s->checkpoint.len; + + /* Run head's slice as a batched prefill over the drafts, starting + * at the current session position. metal_graph_prefill_chunked_range + * reads prompt->v[start..start+n) and writes KV at the same absolute + * positions, so the prompt vec needs to be padded with zeros up to + * pre_pos and then the drafts. 
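+         * E.g. (values illustrative) pre_pos == 100 with draft_n == 4
+         * builds a 104-entry vec whose first 100 slots are zero -- never
+         * read, since the prefill range starts at 100 -- and whose last 4
+         * slots are drafts[0..3]; their KV rows land at absolute positions
+         * 100..103.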
*/ + const uint64_t batch_prompt_len = (uint64_t)pre_pos + (uint64_t)draft_n; + if (batch_prompt_len > INT_MAX) { + spec_frontier_restore(&head_frontier, s); + spec_frontier_free(&head_frontier); + snprintf(err, errlen, "rpc-mtp: pos+draft overflows int"); + return -1; + } + int *batch_tokens = (int *)calloc((size_t)batch_prompt_len, sizeof(int)); + if (!batch_tokens) { + spec_frontier_restore(&head_frontier, s); + spec_frontier_free(&head_frontier); + snprintf(err, errlen, "rpc-mtp: batch prompt alloc failed"); + return -1; + } + for (int i = 0; i < draft_n; i++) { + batch_tokens[pre_pos + i] = drafts[i]; + } + token_vec batch_prompt = { + .v = batch_tokens, + .len = (int)batch_prompt_len, + .cap = (int)batch_prompt_len, + }; + bool ok = metal_graph_prefill_chunked_range(&s->graph, &e->model, &e->weights, + &batch_prompt, + pre_pos, (uint32_t)draft_n, + /* logits */ NULL, + false, NULL, NULL, NULL, NULL); + free(batch_tokens); + if (!ok) { + spec_frontier_restore(&head_frontier, s); + spec_frontier_free(&head_frontier); + snprintf(err, errlen, "rpc-mtp: head batch prefill failed"); + return -1; + } + + /* Export head's batch_cur_hc and ship to the tail along with the + * expected drafts[1..draft_n-1] for per-row verification. */ + const uint64_t per_token = ds4_residual_hc_floats(); + const uint64_t total_floats = (uint64_t)draft_n * per_token; + float *batch_buf = (float *)malloc((size_t)(total_floats * sizeof(float))); + if (!batch_buf) { + spec_frontier_restore(&head_frontier, s); + spec_frontier_free(&head_frontier); + snprintf(err, errlen, "rpc-mtp: batch buffer alloc failed"); + return -1; + } + if (ds4_metal_tensor_read(s->graph.batch_cur_hc, 0, batch_buf, + total_floats * sizeof(float)) == 0) { + free(batch_buf); + spec_frontier_restore(&head_frontier, s); + spec_frontier_free(&head_frontier); + snprintf(err, errlen, "rpc-mtp: read batch_cur_hc failed"); + return -1; + } + + uint32_t expected_buf[16] = {0}; + for (int i = 1; i < draft_n; i++) expected_buf[i - 1] = (uint32_t)drafts[i]; + + uint32_t n_accepted = 0; + char rpc_err[256] = {0}; + if (ds4_rpc_verify_batch_request(e->rpc_peer, + (uint32_t)draft_n, pre_pos, + batch_buf, total_floats, + draft_n > 1 ? expected_buf : NULL, + (uint32_t)(draft_n - 1), + &n_accepted, + s->logits, DS4_N_VOCAB, + rpc_err, sizeof(rpc_err)) != 0) { + free(batch_buf); + spec_frontier_restore(&head_frontier, s); + spec_frontier_free(&head_frontier); + ds4_rpc_close(e->rpc_peer); + e->rpc_peer = NULL; + snprintf(err, errlen, "rpc-mtp: verify_batch RPC failed: %s " + "(connection dropped)", rpc_err); + return -1; + } + free(batch_buf); + + if (n_accepted == 0) { + /* Tail rejected -- restore head's KV state so it matches the + * tail's (which the tail already restored from its own + * snapshot). */ + spec_frontier_restore(&head_frontier, s); + spec_frontier_free(&head_frontier); + char trim_err[160] = {0}; + (void)ds4_rpc_mtp_trim(e->rpc_peer, 0, trim_err, sizeof(trim_err)); + if (spec_log) { + fprintf(stderr, "ds4: rpc-mtp batch miss, drafted=%d accepted=0\n", draft_n); + } + return n_accept; + } + + /* Full accept: commit drafts and advance the checkpoint to match. 
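+         * The OP_MTP_TRIM sent just below tells the tail to keep all
+         * draft_n MTP cache rows from this round, in contrast to the miss
+         * paths above, which trim to 0.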
*/ + spec_frontier_free(&head_frontier); + for (int i = 0; i < draft_n; i++) { + if (n_accept >= accepted_cap) break; + accepted[n_accept++] = drafts[i]; + token_vec_push(&s->checkpoint, drafts[i]); + if (drafts[i] == eos_token) break; + if (n_accept >= max_tokens) break; + } + s->checkpoint_valid = true; + + char trim_err[160] = {0}; + (void)ds4_rpc_mtp_trim(e->rpc_peer, (uint32_t)draft_n, + trim_err, sizeof(trim_err)); + if (spec_log) { + fprintf(stderr, "ds4: rpc-mtp batch accepted, drafted=%d accepted=%d\n", + draft_n, draft_n); + } + return n_accept; + } + + if (_sa_dbg) fprintf(stderr, "ds4-spec: SA: past batched-verify gate; checking single-host MTP gate\n"); + /* Phase 6: under prefetch the head has ALREADY consumed the prediction by + * running its L0-L20 speculatively at position N+1; running the single- + * host MTP draft path here on top of that would re-draft from a corrupted + * state and (if mtp_ready locally too) crash inside the draft kernel. + * Skip the single-host path entirely when prefetch is on -- the prefetch + * hit/miss path is the speculation. */ + if (rpc_prefetch_on || !e->mtp_ready || !s->mtp_draft_valid || e->mtp_draft_tokens <= 1) { + if (_sa_dbg) fprintf(stderr, "ds4-spec: SA: returning n_accept=%d via single-host MTP gate (prefetch=%d)\n", + n_accept, (int)rpc_prefetch_on); + return n_accept; + } int draft_cap = e->mtp_draft_tokens; if (draft_cap > max_tokens - n_accept) draft_cap = max_tokens - n_accept; @@ -16753,17 +18472,69 @@ int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token, #endif } +/* Tell the connected tail worker to drop its session. Best-effort: if the + * RPC fails we log and continue so the head's local state remains valid; the + * caller will discover divergence on the next decode if it really matters. */ +static void rpc_propagate_reset(ds4_session *s, const char *origin) { + if (!s || !s->engine || !s->engine->rpc_peer) return; +#ifndef DS4_NO_METAL + /* Drain any in-flight Phase 6 speculative request before issuing RESET, + * otherwise the next frame on the wire is a DECODE_REPLY and the head + * misreads it as RESET_REPLY (the bug seen at request boundaries on + * /tmp/long-bench-prompt.txt). abort() drains the reply AND rewinds + * the tail; the RESET we're about to send then truncates everything + * else, so the rewind is wasted but harmless. */ + if (s->rpc_spec_in_flight) { + char err[256] = {0}; + (void)rpc_spec_abort(s, err, sizeof(err)); + } +#endif + char rpc_err[256] = {0}; + if (ds4_rpc_reset(s->engine->rpc_peer, rpc_err, sizeof(rpc_err)) != 0) { + fprintf(stderr, + "ds4: rpc reset from %s failed: %s; tail KV state may be stale\n", + origin, rpc_err); + } +} + void ds4_session_invalidate(ds4_session *s) { + if (!s) return; +#ifndef DS4_NO_METAL + if (s->rpc_spec_in_flight) { + char err[256] = {0}; + (void)rpc_spec_abort(s, err, sizeof(err)); + } +#endif s->checkpoint_valid = false; s->checkpoint.len = 0; s->mtp_draft_valid = false; + rpc_propagate_reset(s, "invalidate"); } void ds4_session_rewind(ds4_session *s, int pos) { + if (!s) return; if (pos < 0) pos = 0; if (pos > s->checkpoint.len) pos = s->checkpoint.len; +#ifndef DS4_NO_METAL + if (s->rpc_spec_in_flight) { + char err[256] = {0}; + (void)rpc_spec_abort(s, err, sizeof(err)); + } +#endif s->checkpoint.len = pos; s->mtp_draft_valid = false; + /* Under RPC, tell the tail to truncate its checkpoint to the same + * position. Best-effort: on failure the next decode will surface a + * pos mismatch and the caller can re-prefill from scratch. 
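+     * E.g. a caller trimming back to pos == 512 (value illustrative) sends
+     * one ds4_rpc_rewind so the tail's next decode also evaluates at
+     * position 512, keeping head and tail checkpoints in lockstep.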
*/ + if (s->engine && s->engine->rpc_peer) { + char rpc_err[256] = {0}; + if (ds4_rpc_rewind(s->engine->rpc_peer, (uint32_t)pos, + rpc_err, sizeof(rpc_err)) != 0) { + fprintf(stderr, + "ds4: rpc rewind to pos=%d failed: %s; tail state may be stale\n", + pos, rpc_err); + } + } } int ds4_session_pos(ds4_session *s) { @@ -16773,3 +18544,7 @@ int ds4_session_pos(ds4_session *s) { int ds4_session_ctx(ds4_session *s) { return s->ctx_size; } + +uint32_t ds4_session_prefill_cap(ds4_session *s) { + return s ? s->prefill_cap : 0u; +} diff --git a/ds4.h b/ds4.h index 1e5ad66b..1327fe89 100644 --- a/ds4.h +++ b/ds4.h @@ -63,6 +63,27 @@ typedef struct { float mtp_margin; bool warm_weights; bool quality; + /* Pipeline-parallel layer range. Defaults (0, 0) select the full model + * [0, DS4_N_LAYER): n_layer_end <= 0 normalizes to DS4_N_LAYER. Partition + * the model across a head process (e.g. [0, 22)) and a tail RPC worker + * (e.g. [22, 43)) by setting these. Single-host operation uses the + * default and remains bit-identical to pre-RPC behavior. */ + int n_layer_start; + int n_layer_end; + + /* Optional RPC peer. When rpc_peer_host is set, the engine acts as the + * head: after this engine's local slice runs (layers [n_layer_start, + * n_layer_end)), the residual is shipped to the configured peer over + * TCP, and the peer's logits reply replaces local sampling input. The + * peer must run ds4-rpc-worker with --layer-start equal to n_layer_end + * and --layer-end DS4_N_LAYER. Unused on the worker side. When unset + * the engine is fully self-sufficient. */ + const char *rpc_peer_host; + int rpc_peer_port; + /* Context size advertised in the RPC handshake. Must match the worker's + * --ctx so the tail's KV cache is sized for the same window. Required + * when rpc_peer_host is set; ignored otherwise. */ + int rpc_ctx_size; } ds4_engine_options; typedef void (*ds4_token_emit_fn)(void *ud, int token); @@ -151,6 +172,26 @@ int ds4_session_argmax(ds4_session *s); int ds4_session_sample(ds4_session *s, float temperature, int top_k, float top_p, float min_p, uint64_t *rng); int ds4_session_top_logprobs(ds4_session *s, ds4_token_score *out, int k); int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen); +/* Eval one token without probing MTP for drafts. Used by the RPC head + * during speculative verification, where the head ships want_drafts=0 so + * the tail doesn't waste MTP cycles on tokens that exist solely to verify + * a previously-received draft. */ +int ds4_session_eval_no_draft(ds4_session *s, int token, char *err, size_t errlen); + +/* Produce up to max_drafts MTP draft tokens by recursive drafting against + * the current MTP cache state. Returns the number of drafts written into + * out_drafts (0..max_drafts). Requires that an immediately-preceding + * ds4_session_eval has populated s->mtp_draft_token; that becomes drafts[0] + * and the function fills the rest. Worker uses this to ship a full draft + * batch in DECODE_REPLY when the head asks for them. */ +int ds4_session_mtp_drafts_after_eval(ds4_session *s, + uint32_t *out_drafts, int max_drafts, + char *err, size_t errlen); + +/* Reset the MTP cache row counter so the next decode starts MTP drafting + * from a clean state. Used on the tail when the head sends OP_MTP_TRIM + * with accepted_drafts == 0 (no drafts kept). 
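+ * (On a full batch accept the head instead sends the drafted count, so the
+ * tail keeps those rows.) Worker-side sketch, assuming a decode loop that
+ * has just run ds4_session_eval and `accepted_drafts` later arrives from
+ * the head:
+ *
+ *     uint32_t drafts[16];
+ *     char err[256] = {0};
+ *     int n = ds4_session_mtp_drafts_after_eval(s, drafts, 16,
+ *                                               err, sizeof(err));
+ *     // ...ship drafts[0..n-1] in the decode reply...
+ *     ds4_session_mtp_trim_drafts(s, accepted_drafts);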
*/ +void ds4_session_mtp_trim_drafts(ds4_session *s, uint32_t keep_drafts); int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token, int max_tokens, int eos_token, int *accepted, int accepted_cap, @@ -159,9 +200,11 @@ void ds4_session_invalidate(ds4_session *s); void ds4_session_rewind(ds4_session *s, int pos); int ds4_session_pos(ds4_session *s); int ds4_session_ctx(ds4_session *s); +uint32_t ds4_session_prefill_cap(ds4_session *s); int ds4_engine_routed_quant_bits(ds4_engine *e); bool ds4_engine_has_mtp(ds4_engine *e); int ds4_engine_mtp_draft_tokens(ds4_engine *e); +bool ds4_engine_has_rpc_peer(ds4_engine *e); const ds4_tokens *ds4_session_tokens(ds4_session *s); /* Disk KV cache payload helpers. The server owns the outer file header and @@ -170,4 +213,91 @@ uint64_t ds4_session_payload_bytes(ds4_session *s); int ds4_session_save_payload(ds4_session *s, FILE *fp, char *err, size_t errlen); int ds4_session_load_payload(ds4_session *s, FILE *fp, uint64_t payload_bytes, char *err, size_t errlen); +/* Pipeline-parallel residual stream transfer (the cur_hc tensor: hyper- + * connection slots times embedding width, in floats). The returned size is + * model-fixed. Export reads the residual that comes out of this engine's + * last owned layer; import installs an incoming residual ahead of the next + * forward pass. Both are no-ops for engines that own the full layer range, + * but the API works in all cases so single-host and pipeline code paths can + * share call sites. Metal backend only; CPU returns an error string. */ +uint64_t ds4_residual_hc_floats(void); +int ds4_session_export_residual_hc(ds4_session *s, float *out, uint64_t n_floats, + char *err, size_t errlen); +int ds4_session_import_residual_hc(ds4_session *s, const float *in, uint64_t n_floats, + char *err, size_t errlen); + +/* Batched versions used by pipeline-parallel prefill. The head exports its + * batch_cur_hc tensor after running its layers' prefill on a chunk; the + * tail runs ds4_session_eval_batch_imported_hc to install the residual, + * run its own layers (and on the final chunk, the output projection), and + * advance its session position by n_tokens. pos_start must equal the + * tail's current ds4_session_pos. */ +int ds4_session_export_batch_residual_hc(ds4_session *s, float *out, + uint64_t n_tokens, char *err, size_t errlen); +int ds4_session_eval_batch_imported_hc(ds4_session *s, const float *in, + uint64_t n_tokens, uint32_t pos_start, + bool want_logits, char *err, size_t errlen); + +/* Batched all-or-nothing speculative verification on the tail. Imports + * n_tokens residuals into batch_cur_hc, snapshots KV state, runs prefill + * across all rows, then for each row i in [0, n_expected) compares the + * argmax of that row's output to expected_next[i]. If all match, advances + * the session checkpoint by n_tokens, writes final-row logits, and returns + * *out_n_accepted = n_tokens. If any mismatch (or any error), restores the + * snapshot and returns *out_n_accepted = 0 leaving the session untouched. */ +int ds4_session_verify_batch_imported_hc(ds4_session *s, + const float *batch_residual, + uint64_t n_tokens, uint32_t pos_start, + const uint32_t *expected_next, + uint32_t n_expected, + uint32_t *out_n_accepted, + float *final_logits, uint64_t n_logit_floats, + char *err, size_t errlen); + +/* Shape accessors so the RPC transport and worker don't have to duplicate + * DS4_N_* compile-time constants in their own code. 
All four values are + * model-fixed today; if a future DS4 variant changes them they will become + * engine-derived and these accessors will keep callers stable. */ +uint32_t ds4_model_n_layer(void); +uint32_t ds4_model_n_embd(void); +uint32_t ds4_model_n_hc(void); +uint32_t ds4_model_n_vocab(void); + +/* Resolve the GGUF model path for a CLI/server launch. Priority: + * 1. If explicit_path is non-NULL/non-empty, return it unchanged. + * 2. Else if quant ("q2" or "q4") is non-NULL, return the canonical + * path for that quant. If the file is missing, fail. + * 3. Else probe ./gguf/ for the canonical Q2 and Q4 files. If both + * are present, prefer Q2 (the smaller / safer-for-128GB default). + * If only one is present, return it. If neither, return + * "ds4flash.gguf" so the historical symlink mechanism still works. + * + * Returned pointer is either the caller's explicit_path or static + * storage; do not free. Returns NULL on error with a human-readable + * reason in err. */ +const char *ds4_resolve_model_path(const char *explicit_path, + const char *quant, + char *err, size_t errlen); + +/* Resolve the MTP support GGUF path. Priority: + * 1. explicit_path non-NULL/non-empty and not the sentinel "auto" -> return it. + * 2. Probe ./gguf/ for the canonical MTP file (the one fetched by + * ./download_model.sh mtp). If present, return it. + * 3. Return NULL. + * + * Returned pointer is either the caller's explicit_path or static storage; + * do not free. On miss (no file found and no explicit path), err is left + * empty -- callers that require MTP should treat NULL as their own error. + * The sentinel "auto" lets command-line parsers accept "--mtp" with no + * argument and still go through this resolver. */ +const char *ds4_resolve_mtp_path(const char *explicit_path, + char *err, size_t errlen); + +/* Read-only view of the last logits produced by ds4_session_eval or the + * tail of ds4_session_sync. Length is ds4_model_n_vocab() floats. Returns + * NULL if the session has not yet produced logits. The RPC worker uses + * this to ship a full logit vector back to the head without copying through + * top_logprobs. */ +const float *ds4_session_logits(const ds4_session *s); + #endif diff --git a/ds4_cli.c b/ds4_cli.c index 9851346d..2334a809 100644 --- a/ds4_cli.c +++ b/ds4_cli.c @@ -46,6 +46,7 @@ typedef struct { ds4_engine_options engine; cli_generation_options gen; char *prompt_owned; + const char *quant; /* --quant q2|q4, NULL if unspecified */ bool inspect; } cli_config; @@ -78,7 +79,14 @@ static void usage(FILE *fp) { "\n" "Model and runtime:\n" " -m, --model FILE\n" - " GGUF model path. Default: ds4flash.gguf\n" + " GGUF model path. Wins over --quant. Default: auto-detect Q2 or Q4\n" + " in ./gguf/, preferring Q2 if both are present; fall back to\n" + " ds4flash.gguf when neither is found.\n" + " --quant Q\n" + " Pick the canonical Q2 or Q4 file in ./gguf/ by name. 
Use 'q2'\n" + " for the 128 GB-friendly 86 GB model; use 'q4' only if you have\n" + " >=256 GB or are running pipeline-parallel with --rpc-peer.\n" + " Ignored when -m is also given.\n" " --mtp FILE\n" " Optional MTP support GGUF used for draft-token probes.\n" " --mtp-draft N\n" @@ -1154,7 +1162,7 @@ static char *read_prompt_file(const char *path, bool fatal) { static cli_config parse_options(int argc, char **argv) { cli_config c = { .engine = { - .model_path = "ds4flash.gguf", + .model_path = NULL, /* resolved after parsing via ds4_resolve_model_path */ .backend = DS4_BACKEND_METAL, .mtp_draft_tokens = 1, .mtp_margin = 3.0f, @@ -1193,8 +1201,17 @@ static cli_config parse_options(int argc, char **argv) { c.gen.system = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "-m") || !strcmp(arg, "--model")) { c.engine.model_path = need_arg(&i, argc, argv, arg); + } else if (!strcmp(arg, "--quant")) { + c.quant = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--mtp")) { - c.engine.mtp_path = need_arg(&i, argc, argv, arg); + /* Accept either "--mtp PATH" or bare "--mtp" (resolves to the + * canonical MTP GGUF in ./gguf/). */ + const char *next = (i + 1 < argc) ? argv[i + 1] : NULL; + if (next && next[0] && next[0] != '-') { + c.engine.mtp_path = next; i++; + } else { + c.engine.mtp_path = "auto"; + } } else if (!strcmp(arg, "--mtp-draft")) { c.engine.mtp_draft_tokens = parse_int(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--mtp-margin")) { @@ -1251,6 +1268,26 @@ static cli_config parse_options(int argc, char **argv) { c.inspect = true; } else if (!strcmp(arg, "--warm-weights")) { c.engine.warm_weights = true; + } else if (!strcmp(arg, "--rpc-peer")) { + /* Pipeline-parallel: ship the second half of layers to a tail + * worker reachable at host:port (default port 46434). Use with + * --rpc-split to set the boundary. */ + const char *spec = need_arg(&i, argc, argv, arg); + const char *colon = strrchr(spec, ':'); + if (colon && colon != spec) { + size_t host_len = (size_t)(colon - spec); + char *host = (char *)malloc(host_len + 1); + if (!host) { fprintf(stderr, "ds4: oom\n"); exit(2); } + memcpy(host, spec, host_len); + host[host_len] = '\0'; + c.engine.rpc_peer_host = host; + c.engine.rpc_peer_port = (int)strtol(colon + 1, NULL, 10); + } else { + c.engine.rpc_peer_host = spec; + c.engine.rpc_peer_port = 46434; + } + } else if (!strcmp(arg, "--rpc-split")) { + c.engine.n_layer_end = parse_int(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--server")) { fprintf(stderr, "ds4: use ds4-server for the HTTP server\n"); exit(2); @@ -1261,6 +1298,34 @@ static cli_config parse_options(int argc, char **argv) { } } + /* Final model-path resolution. Priority: -m wins, then --quant, then + * filesystem probe of the canonical Q2/Q4 paths (preferring Q2), then + * the historical ds4flash.gguf symlink fallback. */ + { + char resolve_err[256] = {0}; + const char *resolved = ds4_resolve_model_path(c.engine.model_path, + c.quant, + resolve_err, sizeof(resolve_err)); + if (!resolved) { + fprintf(stderr, "ds4: %s\n", resolve_err); + exit(2); + } + if (resolve_err[0]) fprintf(stderr, "ds4: %s\n", resolve_err); + c.engine.model_path = resolved; + } + if (c.engine.mtp_path) { + char resolve_err[256] = {0}; + const char *resolved = ds4_resolve_mtp_path(c.engine.mtp_path, + resolve_err, sizeof(resolve_err)); + if (!resolved) { + fprintf(stderr, "ds4: %s\n", + resolve_err[0] ? 
resolve_err : + "--mtp requested but no MTP GGUF found in ./gguf/"); + exit(2); + } + c.engine.mtp_path = resolved; + } + return c; } @@ -1282,6 +1347,9 @@ int main(int argc, char **argv) { log_context_memory(cfg.engine.backend, cfg.gen.ctx_size); cli_warn_think_max_downgraded(&cfg.gen, "--think-max"); } + /* Propagate --ctx into the engine so the RPC handshake can assert + * head and tail agree on KV window size. No-op for single-host. */ + cfg.engine.rpc_ctx_size = cfg.gen.ctx_size; ds4_engine *engine = NULL; if (ds4_engine_open(&engine, &cfg.engine) != 0) { free(cfg.prompt_owned); diff --git a/ds4_rpc.c b/ds4_rpc.c new file mode 100644 index 00000000..302f9028 --- /dev/null +++ b/ds4_rpc.c @@ -0,0 +1,1289 @@ +/* Pipeline-parallel RPC transport for ds4. See ds4_rpc.h for protocol + * overview. Plain blocking sockets, one connection per pair, framing is + * length-prefixed. Errors fail the whole connection rather than retrying; + * the caller (the head session) is expected to surface a clear failure to + * the user, who will restart the worker. Reconnect/keepalive is a Phase 5 + * concern. */ + +#define _POSIX_C_SOURCE 200809L +#define _DARWIN_C_SOURCE +#define _BSD_SOURCE +#define _DEFAULT_SOURCE + +#include "ds4_rpc.h" + +#include <arpa/inet.h> +#include <errno.h> +#include <netdb.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <stdarg.h> +#include <stdbool.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/socket.h> +#include <sys/time.h> +#include <sys/types.h> +#include <unistd.h> + +struct ds4_rpc_handle { + int fd; + int listen_fd; /* held until first accept, then closed; -1 otherwise */ +}; + +static void rpc_set_err(char *err, size_t errlen, const char *fmt, ...) { + if (!err || errlen == 0) return; + va_list ap; + va_start(ap, fmt); + vsnprintf(err, errlen, fmt, ap); + va_end(ap); +} + +/* Endianness helpers. The wire is always little-endian. On a likely-LE + * Mac/x86 host these compile down to nothing; the explicit form lets a + * future BE port work without surprises. */ +static void put_u32_le(uint8_t out[4], uint32_t v) { + out[0] = (uint8_t)v; + out[1] = (uint8_t)(v >> 8); + out[2] = (uint8_t)(v >> 16); + out[3] = (uint8_t)(v >> 24); +} +static void put_u64_le(uint8_t out[8], uint64_t v) { + for (int i = 0; i < 8; i++) out[i] = (uint8_t)(v >> (i * 8)); +} +static uint32_t get_u32_le(const uint8_t in[4]) { + return (uint32_t)in[0] | ((uint32_t)in[1] << 8) | + ((uint32_t)in[2] << 16) | ((uint32_t)in[3] << 24); +} +static uint64_t get_u64_le(const uint8_t in[8]) { + uint64_t v = 0; + for (int i = 0; i < 8; i++) v |= (uint64_t)in[i] << (i * 8); + return v; +} + +static int io_read_full(int fd, void *buf, size_t n, char *err, size_t errlen) { + uint8_t *p = (uint8_t *)buf; + size_t got = 0; + while (got < n) { + ssize_t r = read(fd, p + got, n - got); + if (r > 0) { got += (size_t)r; continue; } + if (r == 0) { + rpc_set_err(err, errlen, "rpc: peer closed mid-frame after %zu/%zu bytes", got, n); + return 1; + } + if (errno == EINTR) continue; + rpc_set_err(err, errlen, "rpc: read: %s", strerror(errno)); + return 1; + } + return 0; +} + +static int io_write_full(int fd, const void *buf, size_t n, char *err, size_t errlen) { + const uint8_t *p = (const uint8_t *)buf; + size_t sent = 0; + while (sent < n) { + ssize_t w = write(fd, p + sent, n - sent); + if (w > 0) { sent += (size_t)w; continue; } + if (w < 0 && errno == EINTR) continue; + rpc_set_err(err, errlen, "rpc: write: %s", strerror(errno)); + return 1; + } + return 0; +} + +/* Frame layout: u32 length (excluding self) | u8 op | u8 reserved | u16 reserved + * | payload bytes.
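A frame carrying a 4-byte payload, for example, has a length field of 8 (one op byte, three reserved bytes, four payload bytes) and occupies 12 bytes on the wire.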
The opcode-payload pairing is fixed per opcode and known + * to both sides, so we do not embed type tags. */ +#define RPC_FRAME_HDR_BYTES 8u + +static int frame_write(int fd, uint8_t op, const void *payload, uint32_t payload_bytes, + char *err, size_t errlen) { + uint8_t hdr[RPC_FRAME_HDR_BYTES]; + put_u32_le(hdr, payload_bytes + (RPC_FRAME_HDR_BYTES - 4u)); + hdr[4] = op; + hdr[5] = 0; + hdr[6] = 0; + hdr[7] = 0; + if (io_write_full(fd, hdr, sizeof(hdr), err, errlen)) return 1; + if (payload_bytes && io_write_full(fd, payload, payload_bytes, err, errlen)) return 1; + return 0; +} + +static int frame_read_header(int fd, uint8_t *out_op, uint32_t *out_payload_bytes, + char *err, size_t errlen) { + uint8_t hdr[RPC_FRAME_HDR_BYTES]; + if (io_read_full(fd, hdr, sizeof(hdr), err, errlen)) return 1; + const uint32_t frame_after_len = get_u32_le(hdr); + if (frame_after_len < (RPC_FRAME_HDR_BYTES - 4u)) { + rpc_set_err(err, errlen, "rpc: truncated frame header (%u bytes)", frame_after_len); + return 1; + } + *out_op = hdr[4]; + *out_payload_bytes = frame_after_len - (RPC_FRAME_HDR_BYTES - 4u); + return 0; +} + +/* Connection setup. */ + +static int set_low_latency(int fd) { + int one = 1; + (void)setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); + (void)setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one)); + /* 5-minute read timeout. Prefill of a 30k-token prompt can legitimately + * take several minutes on TB but never hours, so a stuck peer surfaces + * as EAGAIN/ETIMEDOUT instead of an indefinite hang. Tunable via + * DS4_RPC_RECV_TIMEOUT_SECS for debugging. */ + long recv_timeout = 300; + const char *env = getenv("DS4_RPC_RECV_TIMEOUT_SECS"); + if (env && env[0]) { + char *endp = NULL; + long v = strtol(env, &endp, 10); + if (endp != env && v > 0) recv_timeout = v; + } + struct timeval tv = { .tv_sec = recv_timeout, .tv_usec = 0 }; + (void)setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + return 0; +} + +int ds4_rpc_dial(const char *host, uint16_t port, + ds4_rpc_handle **out, char *err, size_t errlen) { + if (!host || !out) { + rpc_set_err(err, errlen, "rpc_dial: null arg"); + return 1; + } + *out = NULL; + + char port_str[16]; + snprintf(port_str, sizeof(port_str), "%u", (unsigned)port); + + struct addrinfo hints = {0}; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + struct addrinfo *res = NULL; + int gai = getaddrinfo(host, port_str, &hints, &res); + if (gai != 0) { + rpc_set_err(err, errlen, "rpc: getaddrinfo(%s:%s): %s", host, port_str, gai_strerror(gai)); + return 1; + } + + int fd = -1; + for (struct addrinfo *ai = res; ai; ai = ai->ai_next) { + fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); + if (fd < 0) continue; + if (connect(fd, ai->ai_addr, ai->ai_addrlen) == 0) break; + close(fd); + fd = -1; + } + freeaddrinfo(res); + if (fd < 0) { + rpc_set_err(err, errlen, "rpc: connect(%s:%u) failed", host, (unsigned)port); + return 1; + } + set_low_latency(fd); + + ds4_rpc_handle *h = (ds4_rpc_handle *)calloc(1, sizeof(*h)); + if (!h) { + close(fd); + rpc_set_err(err, errlen, "rpc: out of memory"); + return 1; + } + h->fd = fd; + h->listen_fd = -1; + *out = h; + return 0; +} + +int ds4_rpc_listen_one(const char *bind_host, uint16_t port, + ds4_rpc_handle **out, char *err, size_t errlen) { + if (!out) { + rpc_set_err(err, errlen, "rpc_listen: null out"); + return 1; + } + *out = NULL; + + char port_str[16]; + snprintf(port_str, sizeof(port_str), "%u", (unsigned)port); + + struct addrinfo hints = {0}; + hints.ai_family = AF_UNSPEC; + 
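/* AF_UNSPEC: bind IPv4 or IPv6, whichever the bind host resolves to. */ +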
hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE; + struct addrinfo *res = NULL; + int gai = getaddrinfo(bind_host && bind_host[0] ? bind_host : NULL, + port_str, &hints, &res); + if (gai != 0) { + rpc_set_err(err, errlen, "rpc: getaddrinfo bind: %s", gai_strerror(gai)); + return 1; + } + + int listen_fd = -1; + for (struct addrinfo *ai = res; ai; ai = ai->ai_next) { + listen_fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); + if (listen_fd < 0) continue; + int one = 1; + (void)setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); + if (bind(listen_fd, ai->ai_addr, ai->ai_addrlen) == 0 && + listen(listen_fd, 1) == 0) break; + close(listen_fd); + listen_fd = -1; + } + freeaddrinfo(res); + if (listen_fd < 0) { + rpc_set_err(err, errlen, "rpc: bind/listen on :%u failed: %s", (unsigned)port, strerror(errno)); + return 1; + } + + fprintf(stderr, "ds4-rpc: listening on %s:%u, awaiting head\n", + bind_host && bind_host[0] ? bind_host : "*", (unsigned)port); + + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + int fd = accept(listen_fd, (struct sockaddr *)&peer, &peerlen); + if (fd < 0) { + rpc_set_err(err, errlen, "rpc: accept: %s", strerror(errno)); + close(listen_fd); + return 1; + } + set_low_latency(fd); + + char peer_host[256] = "?"; + char peer_port[32] = "?"; + (void)getnameinfo((struct sockaddr *)&peer, peerlen, + peer_host, sizeof(peer_host), + peer_port, sizeof(peer_port), + NI_NUMERICHOST | NI_NUMERICSERV); + fprintf(stderr, "ds4-rpc: accepted head from %s:%s\n", peer_host, peer_port); + + close(listen_fd); + + ds4_rpc_handle *h = (ds4_rpc_handle *)calloc(1, sizeof(*h)); + if (!h) { + close(fd); + rpc_set_err(err, errlen, "rpc: out of memory"); + return 1; + } + h->fd = fd; + h->listen_fd = -1; + *out = h; + return 0; +} + +void ds4_rpc_close(ds4_rpc_handle *h) { + if (!h) return; + if (h->fd >= 0) close(h->fd); + if (h->listen_fd >= 0) close(h->listen_fd); + free(h); +} + +int ds4_rpc_fd(const ds4_rpc_handle *h) { + return h ? h->fd : -1; +} + +/* Config (de)serialization. 
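RPC_CFG_BYTES below is the packed size of ds4_rpc_config: twelve u32 fields, one u64, and the 32-byte model sample, 88 bytes in total.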
*/ +#define RPC_CFG_BYTES (12u * 4u + 8u + 32u) + +static void pack_config(uint8_t out[RPC_CFG_BYTES], const ds4_rpc_config *c) { + uint8_t *p = out; + put_u32_le(p, c->version); p += 4; + put_u32_le(p, c->n_layer_total); p += 4; + put_u32_le(p, c->n_embd); p += 4; + put_u32_le(p, c->n_hc); p += 4; + put_u32_le(p, c->n_vocab); p += 4; + put_u32_le(p, c->routed_quant_bits); p += 4; + put_u32_le(p, c->tail_layer_start); p += 4; + put_u32_le(p, c->tail_layer_end); p += 4; + put_u32_le(p, c->ctx_size); p += 4; + put_u32_le(p, c->tail_has_mtp); p += 4; + put_u32_le(p, c->tail_mtp_draft_tokens); p += 4; + put_u32_le(p, c->reserved0); p += 4; + put_u64_le(p, c->model_file_bytes); p += 8; + memcpy(p, c->model_sample, 32); p += 32; + (void)p; +} + +static void unpack_config(ds4_rpc_config *c, const uint8_t in[RPC_CFG_BYTES]) { + const uint8_t *p = in; + c->version = get_u32_le(p); p += 4; + c->n_layer_total = get_u32_le(p); p += 4; + c->n_embd = get_u32_le(p); p += 4; + c->n_hc = get_u32_le(p); p += 4; + c->n_vocab = get_u32_le(p); p += 4; + c->routed_quant_bits = get_u32_le(p); p += 4; + c->tail_layer_start = get_u32_le(p); p += 4; + c->tail_layer_end = get_u32_le(p); p += 4; + c->ctx_size = get_u32_le(p); p += 4; + c->tail_has_mtp = get_u32_le(p); p += 4; + c->tail_mtp_draft_tokens = get_u32_le(p); p += 4; + c->reserved0 = get_u32_le(p); p += 4; + c->model_file_bytes = get_u64_le(p); p += 8; + memcpy(c->model_sample, p, 32); +} + +static bool configs_match(const ds4_rpc_config *a, const ds4_rpc_config *b) { + if (a->version != b->version) return false; + if (a->n_layer_total != b->n_layer_total) return false; + if (a->n_embd != b->n_embd) return false; + if (a->n_hc != b->n_hc) return false; + if (a->n_vocab != b->n_vocab) return false; + if (a->routed_quant_bits != b->routed_quant_bits) return false; + if (a->tail_layer_start != b->tail_layer_start) return false; + if (a->tail_layer_end != b->tail_layer_end) return false; + if (a->ctx_size != b->ctx_size) return false; + if (a->model_file_bytes != b->model_file_bytes) return false; + if (memcmp(a->model_sample, b->model_sample, 32) != 0) return false; + /* tail_has_mtp and tail_mtp_draft_tokens are tail-side capabilities the + * head learns from peer-returned config; not matched here. */ + return true; +} + +/* Magic preamble: both sides write "DRPC" + version once at handshake. The + * version is duplicated into the config so an old client connecting to a new + * server (or vice versa) gets a clear mismatch message rather than wedging on + * unexpected payload sizes. 
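Concretely, the preamble is the eight bytes 44 52 50 43 01 00 00 00: the magic 0x43505244 stored little-endian spells the ASCII "DRPC", then version 1 follows.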
*/ +static int write_magic(int fd, char *err, size_t errlen) { + uint8_t buf[8]; + put_u32_le(buf, DS4_RPC_MAGIC); + put_u32_le(buf + 4, DS4_RPC_VERSION); + return io_write_full(fd, buf, sizeof(buf), err, errlen); +} + +static int read_magic(int fd, char *err, size_t errlen) { + uint8_t buf[8]; + if (io_read_full(fd, buf, sizeof(buf), err, errlen)) return 1; + const uint32_t magic = get_u32_le(buf); + const uint32_t ver = get_u32_le(buf + 4); + if (magic != DS4_RPC_MAGIC) { + rpc_set_err(err, errlen, "rpc: bad magic %#x, expected %#x", magic, DS4_RPC_MAGIC); + return 1; + } + if (ver != DS4_RPC_VERSION) { + rpc_set_err(err, errlen, "rpc: protocol version mismatch %u vs %u", ver, DS4_RPC_VERSION); + return 1; + } + return 0; +} + +int ds4_rpc_handshake_client_peer(ds4_rpc_handle *h, + const ds4_rpc_config *cfg, + ds4_rpc_config *out_peer, + char *err, size_t errlen); + +int ds4_rpc_handshake_client(ds4_rpc_handle *h, const ds4_rpc_config *cfg, + char *err, size_t errlen) { + return ds4_rpc_handshake_client_peer(h, cfg, NULL, err, errlen); +} + +int ds4_rpc_handshake_client_peer(ds4_rpc_handle *h, + const ds4_rpc_config *cfg, + ds4_rpc_config *out_peer, + char *err, size_t errlen) { + if (!h || !cfg) { rpc_set_err(err, errlen, "handshake: null arg"); return 1; } + if (write_magic(h->fd, err, errlen)) return 1; + + uint8_t cbuf[RPC_CFG_BYTES]; + pack_config(cbuf, cfg); + if (frame_write(h->fd, DS4_RPC_OP_HELLO_CLIENT, cbuf, sizeof(cbuf), err, errlen)) return 1; + + uint8_t op = 0; + uint32_t payload_bytes = 0; + if (frame_read_header(h->fd,&op, &payload_bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_HELLO_SERVER) { + rpc_set_err(err, errlen, "handshake: server replied op=%u, expected HELLO_SERVER", op); + return 1; + } + if (payload_bytes < 4u + RPC_CFG_BYTES) { + rpc_set_err(err, errlen, + "handshake: server reply too short (%u bytes, need >= %u)", + payload_bytes, 4u + (uint32_t)RPC_CFG_BYTES); + return 1; + } + uint8_t status_buf[4]; + if (io_read_full(h->fd,status_buf, sizeof(status_buf), err, errlen)) return 1; + const uint32_t status = get_u32_le(status_buf); + + uint8_t peer_buf[RPC_CFG_BYTES]; + if (io_read_full(h->fd,peer_buf, sizeof(peer_buf), err, errlen)) return 1; + if (out_peer) unpack_config(out_peer, peer_buf); + + const uint32_t msg_bytes = payload_bytes - 4u - (uint32_t)RPC_CFG_BYTES; + + if (status != 0) { + char *msg = NULL; + if (msg_bytes > 0 && msg_bytes < 4096) { + msg = (char *)malloc((size_t)msg_bytes + 1u); + if (msg) { + if (io_read_full(h->fd,msg, msg_bytes, err, errlen) == 0) { + msg[msg_bytes] = '\0'; + rpc_set_err(err, errlen, "rpc: server rejected handshake: %s", msg); + } + free(msg); + return 1; + } + } + uint8_t tmp[256]; + uint32_t remaining = msg_bytes; + while (remaining > 0) { + uint32_t chunk = remaining < sizeof(tmp) ? remaining : (uint32_t)sizeof(tmp); + if (io_read_full(h->fd,tmp, chunk, NULL, 0)) break; + remaining -= chunk; + } + rpc_set_err(err, errlen, "rpc: server rejected handshake (status=%u)", status); + return 1; + } + /* status == 0: no error_msg present, drain any leftover defensively. */ + if (msg_bytes > 0) { + uint8_t tmp[256]; + uint32_t remaining = msg_bytes; + while (remaining > 0) { + uint32_t chunk = remaining < sizeof(tmp) ? 
remaining : (uint32_t)sizeof(tmp); + if (io_read_full(h->fd,tmp, chunk, NULL, 0)) break; + remaining -= chunk; + } + } + return 0; +} + +int ds4_rpc_handshake_server(ds4_rpc_handle *h, const ds4_rpc_config *cfg, + ds4_rpc_config *peer, char *err, size_t errlen) { + if (!h || !cfg) { rpc_set_err(err, errlen, "handshake: null arg"); return 1; } + if (read_magic(h->fd, err, errlen)) return 1; + + uint8_t op = 0; + uint32_t payload_bytes = 0; + if (frame_read_header(h->fd,&op, &payload_bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_HELLO_CLIENT) { + rpc_set_err(err, errlen, "handshake: client opened with op=%u, expected HELLO_CLIENT", op); + return 1; + } + if (payload_bytes != RPC_CFG_BYTES) { + rpc_set_err(err, errlen, "handshake: client config payload %u bytes, expected %u", + payload_bytes, RPC_CFG_BYTES); + return 1; + } + uint8_t cbuf[RPC_CFG_BYTES]; + if (io_read_full(h->fd,cbuf, sizeof(cbuf), err, errlen)) return 1; + ds4_rpc_config got = {0}; + unpack_config(&got, cbuf); + if (peer) *peer = got; + + /* Validate. We require an exact config match: same model layout, same + * quant, same split point. Mismatch is reported back to the head with a + * human-readable reason so the user sees the cause without a tcpdump. */ + char reject_msg[512] = {0}; + if (!configs_match(&got, cfg)) { + if (got.version != cfg->version) { + snprintf(reject_msg, sizeof(reject_msg), + "version %u != %u", got.version, cfg->version); + } else if (got.n_layer_total != cfg->n_layer_total) { + snprintf(reject_msg, sizeof(reject_msg), + "n_layer_total %u != %u", got.n_layer_total, cfg->n_layer_total); + } else if (got.n_embd != cfg->n_embd || + got.n_hc != cfg->n_hc || + got.n_vocab != cfg->n_vocab) { + snprintf(reject_msg, sizeof(reject_msg), + "model shape mismatch (n_embd=%u/%u n_hc=%u/%u n_vocab=%u/%u)", + got.n_embd, cfg->n_embd, got.n_hc, cfg->n_hc, + got.n_vocab, cfg->n_vocab); + } else if (got.routed_quant_bits != cfg->routed_quant_bits) { + snprintf(reject_msg, sizeof(reject_msg), + "routed quant %u != %u (don't mix q2 and q4)", + got.routed_quant_bits, cfg->routed_quant_bits); + } else if (got.tail_layer_start != cfg->tail_layer_start || + got.tail_layer_end != cfg->tail_layer_end) { + snprintf(reject_msg, sizeof(reject_msg), + "split mismatch: head expects tail [%u, %u), worker owns [%u, %u)", + got.tail_layer_start, got.tail_layer_end, + cfg->tail_layer_start, cfg->tail_layer_end); + } else if (got.ctx_size != cfg->ctx_size) { + snprintf(reject_msg, sizeof(reject_msg), + "ctx mismatch: head ctx=%u, worker ctx=%u " + "(start the worker with --ctx %u to match)", + got.ctx_size, cfg->ctx_size, got.ctx_size); + } else if (got.model_file_bytes != cfg->model_file_bytes || + memcmp(got.model_sample, cfg->model_sample, 32) != 0) { + snprintf(reject_msg, sizeof(reject_msg), + "model fingerprint mismatch (file size or header bytes differ)"); + } else { + snprintf(reject_msg, sizeof(reject_msg), "unknown handshake mismatch"); + } + } + + const uint32_t status = reject_msg[0] ? 
1u : 0u; + const uint32_t msg_bytes = (uint32_t)strlen(reject_msg); + const uint32_t reply_bytes = 4u + (uint32_t)RPC_CFG_BYTES + msg_bytes; + uint8_t *reply = (uint8_t *)malloc(reply_bytes); + if (!reply) { + rpc_set_err(err, errlen, "handshake: out of memory"); + return 1; + } + put_u32_le(reply, status); + pack_config(reply + 4, cfg); + if (msg_bytes) memcpy(reply + 4 + RPC_CFG_BYTES, reject_msg, msg_bytes); + int rc = frame_write(h->fd, DS4_RPC_OP_HELLO_SERVER, reply, reply_bytes, err, errlen); + free(reply); + if (rc) return 1; + + if (status != 0) { + rpc_set_err(err, errlen, "handshake rejected: %s", reject_msg); + return 1; + } + return 0; +} + +/* Decode request/reply. */ + +/* Decode reply header layout (16 bytes): + * u32 status + * u32 n_drafts (always present; 0 means no MTP attached) + * u64 n_logit_floats + * Then n_drafts * sizeof(u32) of draft tokens, then n_logit_floats * sizeof(float). + * Drafts come first because their length is a small fixed integer; logits come + * after so the receiver can stream-read them straight into the caller's buffer. + */ +int ds4_rpc_decode_send(ds4_rpc_handle *h, + uint32_t token, uint32_t pos, + bool want_drafts, + const float *residual_hc, uint64_t n_residual_floats, + char *err, size_t errlen) { + if (!h || !residual_hc) { + rpc_set_err(err, errlen, "decode_send: null arg"); + return 1; + } + const uint64_t residual_bytes = n_residual_floats * sizeof(float); + const uint64_t total = 4u + 4u + 4u + 4u + 8u + residual_bytes; + if (total > UINT32_MAX) { + rpc_set_err(err, errlen, "decode_send: residual too large"); + return 1; + } + uint8_t *buf = (uint8_t *)malloc((size_t)total); + if (!buf) { rpc_set_err(err, errlen, "decode_send: oom"); return 1; } + uint8_t *p = buf; + put_u32_le(p, token); p += 4; + put_u32_le(p, pos); p += 4; + put_u32_le(p, want_drafts ? 
1u : 0u); p += 4; + put_u32_le(p, 0u); p += 4; /* reserved */ + put_u64_le(p, n_residual_floats); p += 8; + memcpy(p, residual_hc, (size_t)residual_bytes); + int rc = frame_write(h->fd, DS4_RPC_OP_DECODE_REQ, buf, (uint32_t)total, err, errlen); + free(buf); + return rc; +} + +int ds4_rpc_decode_recv_reply(ds4_rpc_handle *h, + float *out_logits, uint64_t n_logit_floats, + uint32_t *out_drafts, uint32_t max_drafts, + uint32_t *out_n_drafts, + char *err, size_t errlen) { + if (!h || !out_logits) { + rpc_set_err(err, errlen, "decode_recv_reply: null arg"); + return 1; + } + if (out_n_drafts) *out_n_drafts = 0; + + uint8_t op = 0; + uint32_t reply_bytes = 0; + if (frame_read_header(h->fd,&op, &reply_bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_DECODE_REPLY) { + rpc_set_err(err, errlen, "decode_recv_reply: expected DECODE_REPLY, got op=%u", op); + return 1; + } + if (reply_bytes < 16u) { + rpc_set_err(err, errlen, "decode_recv_reply: reply header truncated (%u bytes)", reply_bytes); + return 1; + } + uint8_t hdr[16]; + if (io_read_full(h->fd,hdr, sizeof(hdr), err, errlen)) return 1; + const uint32_t status = get_u32_le(hdr); + const uint32_t n_drafts = get_u32_le(hdr + 4); + const uint64_t got_floats = get_u64_le(hdr + 8); + const uint64_t expect_payload = + (uint64_t)n_drafts * sizeof(uint32_t) + got_floats * sizeof(float); + if (reply_bytes - 16u != expect_payload) { + rpc_set_err(err, errlen, + "decode_recv_reply: reply payload size %u, expected %llu " + "(n_drafts=%u, n_logits=%llu)", + reply_bytes - 16u, (unsigned long long)expect_payload, + n_drafts, (unsigned long long)got_floats); + return 1; + } + if (status != 0) { + uint8_t tmp[4096]; + uint64_t remaining = expect_payload; + while (remaining > 0) { + uint64_t chunk = remaining < sizeof(tmp) ? remaining : sizeof(tmp); + if (io_read_full(h->fd,tmp, (size_t)chunk, NULL, 0)) break; + remaining -= chunk; + } + rpc_set_err(err, errlen, "decode_recv_reply: tail returned error status %u", status); + return 1; + } + if (got_floats != n_logit_floats) { + rpc_set_err(err, errlen, "decode_recv_reply: tail sent %llu floats, expected %llu", + (unsigned long long)got_floats, (unsigned long long)n_logit_floats); + return 1; + } + if (n_drafts > 0) { + const uint64_t draft_bytes = (uint64_t)n_drafts * sizeof(uint32_t); + if (out_drafts && max_drafts > 0) { + const uint32_t accept = n_drafts < max_drafts ? 
n_drafts : max_drafts; + uint8_t drafts_buf[DS4_RPC_MAX_DRAFTS * sizeof(uint32_t)]; + if (draft_bytes > sizeof(drafts_buf)) { + rpc_set_err(err, errlen, "decode_recv_reply: too many drafts (%u)", n_drafts); + return 1; + } + if (io_read_full(h->fd,drafts_buf, (size_t)draft_bytes, err, errlen)) return 1; + for (uint32_t i = 0; i < accept; i++) { + out_drafts[i] = get_u32_le(drafts_buf + i * 4); + } + if (out_n_drafts) *out_n_drafts = accept; + } else { + uint8_t tmp[DS4_RPC_MAX_DRAFTS * sizeof(uint32_t)]; + if (draft_bytes > sizeof(tmp)) { + rpc_set_err(err, errlen, "decode_recv_reply: too many drafts (%u) to drain", n_drafts); + return 1; + } + if (io_read_full(h->fd,tmp, (size_t)draft_bytes, err, errlen)) return 1; + } + } + return io_read_full(h->fd,out_logits, (size_t)(n_logit_floats * sizeof(float)), err, errlen); +} + +int ds4_rpc_decode_request(ds4_rpc_handle *h, + uint32_t token, uint32_t pos, + bool want_drafts, + const float *residual_hc, uint64_t n_residual_floats, + float *out_logits, uint64_t n_logit_floats, + uint32_t *out_drafts, uint32_t max_drafts, + uint32_t *out_n_drafts, + char *err, size_t errlen) { + if (ds4_rpc_decode_send(h, token, pos, want_drafts, + residual_hc, n_residual_floats, err, errlen) != 0) { + return 1; + } + return ds4_rpc_decode_recv_reply(h, out_logits, n_logit_floats, + out_drafts, max_drafts, out_n_drafts, + err, errlen); +} + +int ds4_rpc_decode_recv(ds4_rpc_handle *h, + uint32_t *token, uint32_t *pos, + bool *want_drafts, + float *residual_hc, uint64_t n_residual_floats, + char *err, size_t errlen) { + if (!h || !token || !pos || !want_drafts || !residual_hc) { + rpc_set_err(err, errlen, "decode_recv: null arg"); + return 1; + } + uint8_t op = 0; + uint32_t payload_bytes = 0; + if (frame_read_header(h->fd,&op, &payload_bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_DECODE_REQ) { + rpc_set_err(err, errlen, "decode_recv: expected DECODE_REQ, got op=%u", op); + return 1; + } + const uint64_t expect = 4u + 4u + 4u + 4u + 8u + n_residual_floats * sizeof(float); + if (payload_bytes != expect) { + rpc_set_err(err, errlen, + "decode_recv: payload %u bytes, expected %llu", + payload_bytes, (unsigned long long)expect); + return 1; + } + uint8_t hdr[24]; + if (io_read_full(h->fd,hdr, sizeof(hdr), err, errlen)) return 1; + *token = get_u32_le(hdr); + *pos = get_u32_le(hdr + 4); + *want_drafts = get_u32_le(hdr + 8) != 0; + /* hdr+12..16 reserved */ + const uint64_t got_floats = get_u64_le(hdr + 16); + if (got_floats != n_residual_floats) { + rpc_set_err(err, errlen, + "decode_recv: residual size mismatch (got %llu, expected %llu)", + (unsigned long long)got_floats, (unsigned long long)n_residual_floats); + return 1; + } + return io_read_full(h->fd,residual_hc, + (size_t)(n_residual_floats * sizeof(float)), err, errlen); +} + +int ds4_rpc_decode_reply(ds4_rpc_handle *h, + const float *logits, uint64_t n_logit_floats, + const uint32_t *drafts, uint32_t n_drafts, + char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "decode_reply: null arg"); return 1; } + if (n_drafts > DS4_RPC_MAX_DRAFTS) { + rpc_set_err(err, errlen, "decode_reply: %u drafts exceeds max %u", + n_drafts, DS4_RPC_MAX_DRAFTS); + return 1; + } + if (n_drafts > 0 && !drafts) { + rpc_set_err(err, errlen, "decode_reply: n_drafts>0 but drafts is NULL"); + return 1; + } + const uint64_t logit_bytes = n_logit_floats * sizeof(float); + const uint64_t draft_bytes = (uint64_t)n_drafts * sizeof(uint32_t); + const uint64_t total = 4u + 4u + 8u + draft_bytes + logit_bytes; + if (total > 
UINT32_MAX) { + rpc_set_err(err, errlen, "decode_reply: payload too large"); + return 1; + } + uint8_t *buf = (uint8_t *)malloc((size_t)total); + if (!buf) { rpc_set_err(err, errlen, "decode_reply: oom"); return 1; } + uint8_t *p = buf; + put_u32_le(p, logits ? 0u : 1u); p += 4; /* status */ + put_u32_le(p, n_drafts); p += 4; + put_u64_le(p, n_logit_floats); p += 8; + for (uint32_t i = 0; i < n_drafts; i++) { + put_u32_le(p, drafts[i]); p += 4; + } + if (logits) memcpy(p, logits, (size_t)logit_bytes); + else memset(p, 0, (size_t)logit_bytes); + int rc = frame_write(h->fd, DS4_RPC_OP_DECODE_REPLY, buf, (uint32_t)total, err, errlen); + free(buf); + return rc; +} + +int ds4_rpc_mtp_trim(ds4_rpc_handle *h, uint32_t accepted_drafts, + char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "mtp_trim: null"); return 1; } + uint8_t payload[4]; + put_u32_le(payload, accepted_drafts); + if (frame_write(h->fd, DS4_RPC_OP_MTP_TRIM, payload, sizeof(payload), + err, errlen)) return 1; + uint8_t op = 0; + uint32_t bytes = 0; + if (frame_read_header(h->fd,&op, &bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_MTP_TRIM_REPLY || bytes != 0) { + rpc_set_err(err, errlen, "mtp_trim: unexpected reply op=%u bytes=%u", op, bytes); + return 1; + } + return 0; +} + +int ds4_rpc_mtp_trim_recv(ds4_rpc_handle *h, uint32_t *accepted_drafts, + char *err, size_t errlen) { + if (!h || !accepted_drafts) { + rpc_set_err(err, errlen, "mtp_trim_recv: null"); + return 1; + } + uint8_t op = 0; + uint32_t bytes = 0; + if (frame_read_header(h->fd,&op, &bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_MTP_TRIM || bytes != 4) { + rpc_set_err(err, errlen, "mtp_trim_recv: unexpected op=%u bytes=%u", op, bytes); + return 1; + } + uint8_t buf[4]; + if (io_read_full(h->fd,buf, sizeof(buf), err, errlen)) return 1; + *accepted_drafts = get_u32_le(buf); + return 0; +} + +int ds4_rpc_mtp_trim_reply(ds4_rpc_handle *h, char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "mtp_trim_reply: null"); return 1; } + return frame_write(h->fd, DS4_RPC_OP_MTP_TRIM_REPLY, NULL, 0, err, errlen); +} + +/* Verify-batch wire format. 
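All integers are little-endian, like the rest of the wire.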
Request layout: + * u32 n_tokens + * u32 pos_start + * u32 n_expected (= n_tokens - 1; the drafts head wants verified) + * u32 reserved + * u64 n_residual_floats + * float32[n_residual_floats] batch_cur_hc rows + * u32[n_expected] expected_next tokens + * + * Reply layout: + * u32 status + * u32 n_accepted (0 = miss + KV reverted, n_tokens = full accept) + * u32 reserved + * u32 reserved + * u64 n_logit_floats + * float32[n_logit_floats] logits (only meaningful when n_accepted > 0) + */ +int ds4_rpc_verify_batch_request(ds4_rpc_handle *h, + uint32_t n_tokens, uint32_t pos_start, + const float *batch_residual, + uint64_t n_residual_floats, + const uint32_t *expected_next, + uint32_t n_expected, + uint32_t *out_n_accepted, + float *out_logits, uint64_t n_logit_floats, + char *err, size_t errlen) { + if (!h || !batch_residual || !out_n_accepted || !out_logits) { + rpc_set_err(err, errlen, "verify_batch_request: null arg"); + return 1; + } + if (n_expected > 0 && !expected_next) { + rpc_set_err(err, errlen, "verify_batch_request: n_expected>0 but null buffer"); + return 1; + } + *out_n_accepted = 0; + + const uint64_t residual_bytes = n_residual_floats * sizeof(float); + const uint64_t expected_bytes = (uint64_t)n_expected * sizeof(uint32_t); + const uint64_t total = 4u + 4u + 4u + 4u + 8u + residual_bytes + expected_bytes; + if (total > UINT32_MAX) { + rpc_set_err(err, errlen, "verify_batch_request: payload too large"); + return 1; + } + uint8_t *buf = (uint8_t *)malloc((size_t)total); + if (!buf) { rpc_set_err(err, errlen, "verify_batch_request: oom"); return 1; } + uint8_t *p = buf; + put_u32_le(p, n_tokens); p += 4; + put_u32_le(p, pos_start); p += 4; + put_u32_le(p, n_expected); p += 4; + put_u32_le(p, 0u); p += 4; + put_u64_le(p, n_residual_floats); p += 8; + memcpy(p, batch_residual, (size_t)residual_bytes); p += residual_bytes; + for (uint32_t i = 0; i < n_expected; i++) { + put_u32_le(p, expected_next[i]); p += 4; + } + int rc = frame_write(h->fd, DS4_RPC_OP_VERIFY_BATCH, buf, (uint32_t)total, + err, errlen); + free(buf); + if (rc) return 1; + + uint8_t op = 0; + uint32_t reply_bytes = 0; + if (frame_read_header(h->fd,&op, &reply_bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_VERIFY_BATCH_REPLY) { + rpc_set_err(err, errlen, "verify_batch_request: expected reply op=%u, got %u", + DS4_RPC_OP_VERIFY_BATCH_REPLY, op); + return 1; + } + const uint64_t reply_min = 4u + 4u + 4u + 4u + 8u; + if (reply_bytes < reply_min) { + rpc_set_err(err, errlen, "verify_batch_request: reply truncated (%u bytes)", reply_bytes); + return 1; + } + uint8_t hdr[24]; + if (io_read_full(h->fd,hdr, sizeof(hdr), err, errlen)) return 1; + const uint32_t status = get_u32_le(hdr); + const uint32_t n_accepted = get_u32_le(hdr + 4); + const uint64_t got_floats = get_u64_le(hdr + 16); + const uint64_t remaining = reply_bytes - 24u; + + if (status != 0) { + uint8_t tmp[4096]; + uint64_t left = remaining; + while (left > 0) { + uint64_t chunk = left < sizeof(tmp) ? 
left : sizeof(tmp); + if (io_read_full(h->fd,tmp, (size_t)chunk, NULL, 0)) break; + left -= chunk; + } + rpc_set_err(err, errlen, "verify_batch_request: tail status %u", status); + return 1; + } + + *out_n_accepted = n_accepted; + if (remaining != got_floats * sizeof(float)) { + rpc_set_err(err, errlen, + "verify_batch_request: reply payload %llu, expected %llu", + (unsigned long long)remaining, + (unsigned long long)(got_floats * sizeof(float))); + return 1; + } + if (got_floats == 0) return 0; + if (got_floats != n_logit_floats) { + rpc_set_err(err, errlen, + "verify_batch_request: tail sent %llu floats, expected %llu", + (unsigned long long)got_floats, + (unsigned long long)n_logit_floats); + return 1; + } + return io_read_full(h->fd,out_logits, (size_t)remaining, err, errlen); +} + +int ds4_rpc_verify_batch_recv(ds4_rpc_handle *h, + uint32_t *n_tokens, uint32_t *pos_start, + float *batch_residual, uint64_t max_residual_floats, + uint64_t *out_n_residual_floats, + uint32_t *expected_next, uint32_t max_expected, + uint32_t *out_n_expected, + char *err, size_t errlen) { + if (!h || !n_tokens || !pos_start || !batch_residual || + !out_n_residual_floats || !expected_next || !out_n_expected) { + rpc_set_err(err, errlen, "verify_batch_recv: null arg"); + return 1; + } + uint8_t op = 0; + uint32_t payload_bytes = 0; + if (frame_read_header(h->fd,&op, &payload_bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_VERIFY_BATCH) { + rpc_set_err(err, errlen, "verify_batch_recv: expected VERIFY_BATCH, got op=%u", op); + return 1; + } + if (payload_bytes < 24u) { + rpc_set_err(err, errlen, "verify_batch_recv: payload truncated (%u bytes)", payload_bytes); + return 1; + } + uint8_t hdr[24]; + if (io_read_full(h->fd,hdr, sizeof(hdr), err, errlen)) return 1; + *n_tokens = get_u32_le(hdr); + *pos_start = get_u32_le(hdr + 4); + const uint32_t n_expected_in = get_u32_le(hdr + 8); + const uint64_t n_floats = get_u64_le(hdr + 16); + const uint64_t expected_bytes = (uint64_t)n_expected_in * sizeof(uint32_t); + const uint64_t want_bytes = n_floats * sizeof(float); + if (payload_bytes - 24u != want_bytes + expected_bytes) { + rpc_set_err(err, errlen, + "verify_batch_recv: payload size mismatch (got %u, expected %llu + %llu)", + payload_bytes - 24u, + (unsigned long long)want_bytes, + (unsigned long long)expected_bytes); + return 1; + } + if (n_floats > max_residual_floats) { + rpc_set_err(err, errlen, + "verify_batch_recv: residual %llu floats exceeds buffer %llu", + (unsigned long long)n_floats, + (unsigned long long)max_residual_floats); + return 1; + } + if (n_expected_in > max_expected) { + rpc_set_err(err, errlen, + "verify_batch_recv: %u expected exceeds buffer %u", + n_expected_in, max_expected); + return 1; + } + *out_n_residual_floats = n_floats; + *out_n_expected = n_expected_in; + if (io_read_full(h->fd,batch_residual, (size_t)want_bytes, err, errlen)) return 1; + if (n_expected_in > 0) { + uint8_t tmp[16 * 4]; + if (expected_bytes > sizeof(tmp)) { + rpc_set_err(err, errlen, "verify_batch_recv: %u expected too many", + n_expected_in); + return 1; + } + if (io_read_full(h->fd,tmp, (size_t)expected_bytes, err, errlen)) return 1; + for (uint32_t i = 0; i < n_expected_in; i++) { + expected_next[i] = get_u32_le(tmp + i * 4); + } + } + return 0; +} + +int ds4_rpc_verify_batch_reply(ds4_rpc_handle *h, + uint32_t n_accepted, + const float *logits, uint64_t n_logit_floats, + char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "verify_batch_reply: null"); return 1; } + const bool has_logits = 
(logits != NULL && n_logit_floats > 0); + const uint64_t logit_bytes = has_logits ? n_logit_floats * sizeof(float) : 0u; + const uint64_t total = 4u + 4u + 4u + 4u + 8u + logit_bytes; + if (total > UINT32_MAX) { + rpc_set_err(err, errlen, "verify_batch_reply: payload too large"); + return 1; + } + uint8_t *buf = (uint8_t *)malloc((size_t)total); + if (!buf) { rpc_set_err(err, errlen, "verify_batch_reply: oom"); return 1; } + uint8_t *p = buf; + put_u32_le(p, 0u); p += 4; /* status */ + put_u32_le(p, n_accepted); p += 4; + put_u32_le(p, 0u); p += 4; + put_u32_le(p, 0u); p += 4; + put_u64_le(p, has_logits ? n_logit_floats : 0u); p += 8; + if (has_logits) memcpy(p, logits, (size_t)logit_bytes); + int rc = frame_write(h->fd, DS4_RPC_OP_VERIFY_BATCH_REPLY, buf, (uint32_t)total, + err, errlen); + free(buf); + return rc; +} + +/* Prefill request: one chunk of the prompt's batch_cur_hc, sized + * (n_tokens * DS4_N_HC * DS4_N_EMBD) floats. Reply is empty if + * !want_logits, otherwise carries one DS4_N_VOCAB-sized logits vector. + * Frame layout for the request: + * u32 n_tokens + * u32 pos_start + * u32 want_logits (0 or 1) + * u32 reserved + * u64 n_residual_floats + * float32[n_residual_floats] batch residual data + * Frame layout for the reply: + * u32 status (0 = ok, !=0 = error) + * u32 has_logits (0 or 1; on error always 0) + * u64 n_logit_floats (matches has_logits) + * float32[n_logit_floats] logits */ +int ds4_rpc_prefill_request(ds4_rpc_handle *h, + uint32_t n_tokens, uint32_t pos_start, + bool want_logits, + const float *batch_residual_hc, + uint64_t n_residual_floats, + float *out_logits, uint64_t n_logit_floats, + char *err, size_t errlen) { + if (!h || !batch_residual_hc) { + rpc_set_err(err, errlen, "prefill_request: null arg"); + return 1; + } + if (want_logits && (!out_logits || n_logit_floats == 0)) { + rpc_set_err(err, errlen, "prefill_request: want_logits set but no output buffer"); + return 1; + } + const uint64_t residual_bytes = n_residual_floats * sizeof(float); + const uint64_t payload_bytes = 4u + 4u + 4u + 4u + 8u + residual_bytes; + if (payload_bytes > UINT32_MAX) { + rpc_set_err(err, errlen, "prefill_request: chunk too large to frame"); + return 1; + } + + uint8_t *buf = (uint8_t *)malloc((size_t)payload_bytes); + if (!buf) { rpc_set_err(err, errlen, "prefill_request: oom"); return 1; } + uint8_t *p = buf; + put_u32_le(p, n_tokens); p += 4; + put_u32_le(p, pos_start); p += 4; + put_u32_le(p, want_logits ? 1u : 0u); p += 4; + put_u32_le(p, 0u); p += 4; + put_u64_le(p, n_residual_floats); p += 8; + memcpy(p, batch_residual_hc, (size_t)residual_bytes); + int rc = frame_write(h->fd, DS4_RPC_OP_PREFILL_REQ, buf, (uint32_t)payload_bytes, + err, errlen); + free(buf); + if (rc) return 1; + + uint8_t op = 0; + uint32_t reply_bytes = 0; + if (frame_read_header(h->fd,&op, &reply_bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_PREFILL_REPLY) { + rpc_set_err(err, errlen, "prefill_request: expected PREFILL_REPLY, got op=%u", op); + return 1; + } + if (reply_bytes < 4u + 4u + 8u) { + rpc_set_err(err, errlen, "prefill_request: reply header truncated (%u bytes)", reply_bytes); + return 1; + } + uint8_t hdr[16]; + if (io_read_full(h->fd,hdr, sizeof(hdr), err, errlen)) return 1; + const uint32_t status = get_u32_le(hdr); + const uint32_t has_logits = get_u32_le(hdr + 4); + const uint64_t got_floats = get_u64_le(hdr + 8); + const uint64_t remaining = reply_bytes - 16u; + + if (status != 0) { + /* Drain any payload bytes so the connection stays in sync. 
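Skipping the drain would leave the logit bytes queued, and the next frame_read_header would misparse them as a frame header.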
*/ + uint8_t tmp[4096]; + uint64_t left = remaining; + while (left > 0) { + uint64_t chunk = left < sizeof(tmp) ? left : sizeof(tmp); + if (io_read_full(h->fd,tmp, (size_t)chunk, NULL, 0)) break; + left -= chunk; + } + rpc_set_err(err, errlen, "prefill_request: tail returned error status %u", status); + return 1; + } + if (!want_logits) { + if (has_logits || got_floats != 0 || remaining != 0) { + rpc_set_err(err, errlen, "prefill_request: tail returned logits we did not request"); + return 1; + } + return 0; + } + if (!has_logits || got_floats != n_logit_floats) { + rpc_set_err(err, errlen, + "prefill_request: tail returned %llu floats (has=%u), expected %llu (has=1)", + (unsigned long long)got_floats, has_logits, + (unsigned long long)n_logit_floats); + return 1; + } + if (remaining != n_logit_floats * sizeof(float)) { + rpc_set_err(err, errlen, + "prefill_request: reply payload size mismatch (%llu vs %llu)", + (unsigned long long)remaining, + (unsigned long long)(n_logit_floats * sizeof(float))); + return 1; + } + return io_read_full(h->fd,out_logits, (size_t)remaining, err, errlen); +} + +int ds4_rpc_prefill_recv(ds4_rpc_handle *h, + uint32_t *n_tokens, uint32_t *pos_start, + bool *want_logits, + float *batch_residual_hc, uint64_t max_residual_floats, + uint64_t *out_n_residual_floats, + char *err, size_t errlen) { + if (!h || !n_tokens || !pos_start || !want_logits || !batch_residual_hc || + !out_n_residual_floats) { + rpc_set_err(err, errlen, "prefill_recv: null arg"); + return 1; + } + uint8_t op = 0; + uint32_t payload_bytes = 0; + if (frame_read_header(h->fd,&op, &payload_bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_PREFILL_REQ) { + rpc_set_err(err, errlen, "prefill_recv: expected PREFILL_REQ, got op=%u", op); + return 1; + } + if (payload_bytes < 4u + 4u + 4u + 4u + 8u) { + rpc_set_err(err, errlen, "prefill_recv: payload truncated (%u bytes)", payload_bytes); + return 1; + } + uint8_t hdr[24]; + if (io_read_full(h->fd,hdr, sizeof(hdr), err, errlen)) return 1; + *n_tokens = get_u32_le(hdr); + *pos_start = get_u32_le(hdr + 4); + *want_logits = get_u32_le(hdr + 8) != 0; + /* hdr+12..16 reserved */ + const uint64_t n_floats = get_u64_le(hdr + 16); + const uint64_t want_bytes = n_floats * sizeof(float); + if (payload_bytes - 24u != want_bytes) { + rpc_set_err(err, errlen, + "prefill_recv: residual size mismatch (header says %llu floats = %llu bytes, " + "frame has %llu bytes of payload after header)", + (unsigned long long)n_floats, + (unsigned long long)want_bytes, + (unsigned long long)(payload_bytes - 24u)); + return 1; + } + if (n_floats > max_residual_floats) { + rpc_set_err(err, errlen, + "prefill_recv: chunk needs %llu residual floats but caller buffer holds only %llu", + (unsigned long long)n_floats, + (unsigned long long)max_residual_floats); + return 1; + } + *out_n_residual_floats = n_floats; + return io_read_full(h->fd,batch_residual_hc, (size_t)want_bytes, err, errlen); +} + +int ds4_rpc_prefill_reply(ds4_rpc_handle *h, + bool has_logits, + const float *logits, uint64_t n_logit_floats, + char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "prefill_reply: null arg"); return 1; } + if (has_logits && (!logits || n_logit_floats == 0)) { + rpc_set_err(err, errlen, "prefill_reply: has_logits set but no logits provided"); + return 1; + } + const uint64_t logit_bytes = has_logits ? 
n_logit_floats * sizeof(float) : 0u; + const uint64_t payload_bytes = 4u + 4u + 8u + logit_bytes; + if (payload_bytes > UINT32_MAX) { + rpc_set_err(err, errlen, "prefill_reply: logits too large to frame"); + return 1; + } + uint8_t *buf = (uint8_t *)malloc((size_t)payload_bytes); + if (!buf) { rpc_set_err(err, errlen, "prefill_reply: oom"); return 1; } + uint8_t *p = buf; + put_u32_le(p, 0u); p += 4; /* status = ok */ + put_u32_le(p, has_logits ? 1u : 0u); p += 4; + put_u64_le(p, has_logits ? n_logit_floats : 0u); p += 8; + if (has_logits) memcpy(p, logits, (size_t)logit_bytes); + int rc = frame_write(h->fd, DS4_RPC_OP_PREFILL_REPLY, buf, (uint32_t)payload_bytes, + err, errlen); + free(buf); + return rc; +} + +int ds4_rpc_reset(ds4_rpc_handle *h, char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "reset: null"); return 1; } + if (frame_write(h->fd, DS4_RPC_OP_RESET, NULL, 0, err, errlen)) return 1; + uint8_t op = 0; + uint32_t bytes = 0; + if (frame_read_header(h->fd,&op, &bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_RESET_REPLY || bytes != 0) { + rpc_set_err(err, errlen, "reset: unexpected reply op=%u bytes=%u", op, bytes); + return 1; + } + return 0; +} + +/* Non-destructive peek of the next frame's opcode. Used by the tail worker + * to dispatch DECODE_REQ vs RESET vs SHUTDOWN. The op byte sits at offset 4 + * in the frame header; MSG_PEEK reads it without consuming bytes from the + * socket buffer, so the matched *_recv helper can read the full frame + * normally afterwards. */ +int ds4_rpc_recv_op(ds4_rpc_handle *h, ds4_rpc_op *op, + char *err, size_t errlen) { + if (!h || !op) { rpc_set_err(err, errlen, "recv_op: null"); return 1; } + uint8_t buf[RPC_FRAME_HDR_BYTES]; + size_t got = 0; + while (got < sizeof(buf)) { + /* MSG_PEEK re-reads from the front of the socket buffer on every call, + * so each retry must peek the whole header into buf from offset 0; + * peeking into buf + got would smear the first bytes over later ones. */ + ssize_t r = recv(h->fd, buf, sizeof(buf), MSG_PEEK); + if (r > 0) { got = (size_t)r; continue; } + if (r == 0) { + rpc_set_err(err, errlen, "rpc: peer closed before next frame"); + return 1; + } + if (errno == EINTR) continue; + rpc_set_err(err, errlen, "rpc: peek: %s", strerror(errno)); + return 1; + } + *op = (ds4_rpc_op)buf[4]; + return 0; +} + +/* Consume a control frame whose payload is empty (RESET, SHUTDOWN).
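The caller has already dispatched on the opcode via ds4_rpc_recv_op, so a mismatch here means the stream is desynced rather than merely mis-dispatched.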
*/ +static int consume_empty_frame(int fd, uint8_t expected_op, + char *err, size_t errlen) { + uint8_t op = 0; + uint32_t bytes = 0; + if (frame_read_header(fd, &op, &bytes, err, errlen)) return 1; + if (op != expected_op || bytes != 0) { + rpc_set_err(err, errlen, + "rpc: unexpected control frame op=%u bytes=%u (wanted op=%u, 0 bytes)", + op, bytes, expected_op); + return 1; + } + return 0; +} + +int ds4_rpc_reset_recv(ds4_rpc_handle *h, char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "reset_recv: null"); return 1; } + return consume_empty_frame(h->fd, DS4_RPC_OP_RESET, err, errlen); +} + +int ds4_rpc_shutdown_recv(ds4_rpc_handle *h, char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "shutdown_recv: null"); return 1; } + return consume_empty_frame(h->fd, DS4_RPC_OP_SHUTDOWN, err, errlen); +} + +int ds4_rpc_reset_reply(ds4_rpc_handle *h, char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "reset_reply: null"); return 1; } + return frame_write(h->fd, DS4_RPC_OP_RESET_REPLY, NULL, 0, err, errlen); +} + +int ds4_rpc_shutdown_send(ds4_rpc_handle *h) { + if (!h) return 1; + char err[64]; + return frame_write(h->fd, DS4_RPC_OP_SHUTDOWN, NULL, 0, err, sizeof(err)); +} + +int ds4_rpc_rewind(ds4_rpc_handle *h, uint32_t target_pos, + char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "rewind: null"); return 1; } + uint8_t payload[4]; + put_u32_le(payload, target_pos); + if (frame_write(h->fd, DS4_RPC_OP_REWIND, payload, sizeof(payload), + err, errlen)) return 1; + uint8_t op = 0; + uint32_t bytes = 0; + if (frame_read_header(h->fd,&op, &bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_REWIND_REPLY || bytes != 0) { + rpc_set_err(err, errlen, "rewind: unexpected reply op=%u bytes=%u", op, bytes); + return 1; + } + return 0; +} + +int ds4_rpc_rewind_recv(ds4_rpc_handle *h, uint32_t *target_pos, + char *err, size_t errlen) { + if (!h || !target_pos) { rpc_set_err(err, errlen, "rewind_recv: null"); return 1; } + uint8_t op = 0; + uint32_t bytes = 0; + if (frame_read_header(h->fd,&op, &bytes, err, errlen)) return 1; + if (op != DS4_RPC_OP_REWIND || bytes != 4) { + rpc_set_err(err, errlen, "rewind_recv: unexpected op=%u bytes=%u", op, bytes); + return 1; + } + uint8_t buf[4]; + if (io_read_full(h->fd,buf, sizeof(buf), err, errlen)) return 1; + *target_pos = get_u32_le(buf); + return 0; +} + +int ds4_rpc_rewind_reply(ds4_rpc_handle *h, char *err, size_t errlen) { + if (!h) { rpc_set_err(err, errlen, "rewind_reply: null"); return 1; } + return frame_write(h->fd, DS4_RPC_OP_REWIND_REPLY, NULL, 0, err, errlen); +} diff --git a/ds4_rpc.h b/ds4_rpc.h new file mode 100644 index 00000000..63d67ae9 --- /dev/null +++ b/ds4_rpc.h @@ -0,0 +1,225 @@ +#ifndef DS4_RPC_H +#define DS4_RPC_H + +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> + +#include "ds4.h" + +/* Pipeline-parallel RPC transport. + * + * The head process owns layers [0, L_mid), drives session state, samples + * tokens, and talks to one tail worker that owns [L_mid, DS4_N_LAYER) on a + * second machine. At each decode step the head ships (token, pos, residual) + * to the tail and reads back the full logits vector; batched prefill, MTP + * draft/trim, and speculative verify ride the same framed connection. The + * disk KV cache is not handled here; this layer is intentionally narrow so + * a first working multi-machine path can land before we expand the surface. + * + * Wire format: little-endian, length-prefixed frames carrying one opcode and + * one fixed payload shape per opcode.
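(A RESET frame, for instance, is exactly eight bytes: length field 4, op byte 5, three reserved zero bytes, and no payload.)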
The handshake validates that both + * sides agree on model shape, routed-quant bits, and the split point; a + * mismatch fails the connection rather than silently producing wrong output. + * + * This module is Metal-only by transitive dependency on the engine, but the + * transport itself is plain POSIX sockets and would work on Linux/FreeBSD if + * a CUDA/Vulkan engine ever needed it. */ + +#define DS4_RPC_MAGIC 0x43505244u /* "DRPC" little-endian */ +#define DS4_RPC_VERSION 1u + +typedef enum { + DS4_RPC_OP_HELLO_CLIENT = 1, /* head -> tail: config exchange */ + DS4_RPC_OP_HELLO_SERVER = 2, /* tail -> head: ack or error */ + DS4_RPC_OP_DECODE_REQ = 3, /* head -> tail: token + pos + residual */ + DS4_RPC_OP_DECODE_REPLY = 4, /* tail -> head: status + logits */ + DS4_RPC_OP_RESET = 5, /* head -> tail: drop session, fresh KV */ + DS4_RPC_OP_RESET_REPLY = 6, /* tail -> head: ack reset */ + DS4_RPC_OP_SHUTDOWN = 7, /* either side: clean disconnect */ + DS4_RPC_OP_PREFILL_REQ = 8, /* head -> tail: batch residual for a chunk */ + DS4_RPC_OP_PREFILL_REPLY = 9, /* tail -> head: status + maybe logits */ + DS4_RPC_OP_REWIND = 10, /* head -> tail: rewind to target_pos */ + DS4_RPC_OP_REWIND_REPLY = 11, /* tail -> head: ack */ + DS4_RPC_OP_MTP_TRIM = 12, /* head -> tail: trim MTP drafts to accepted */ + DS4_RPC_OP_MTP_TRIM_REPLY = 13, + DS4_RPC_OP_VERIFY_BATCH = 14, /* head -> tail: batched all-or-nothing verify */ + DS4_RPC_OP_VERIFY_BATCH_REPLY = 15, +} ds4_rpc_op; + +/* Configuration carried in the handshake. Both sides must agree on every + * field except role. The split point is the tail's range; the head's range + * is [0, tail_layer_start). */ +typedef struct { + uint32_t version; /* DS4_RPC_VERSION */ + uint32_t n_layer_total; /* DS4_N_LAYER, model-fixed */ + uint32_t n_embd; /* DS4_N_EMBD */ + uint32_t n_hc; /* DS4_N_HC */ + uint32_t n_vocab; /* DS4_N_VOCAB */ + uint32_t routed_quant_bits; /* 2 or 4 */ + uint32_t tail_layer_start; /* L_mid, the split point */ + uint32_t tail_layer_end; /* DS4_N_LAYER for a 2-machine setup */ + uint32_t ctx_size; /* must match between head and tail; tail's + KV cache is sized for this and overflows + if the head runs longer prompts */ + uint32_t tail_has_mtp; /* 1 if the worker has --mtp loaded, 0 else */ + uint32_t tail_mtp_draft_tokens; /* worker's --mtp-draft (1..16); ignored + when tail_has_mtp == 0 */ + uint32_t reserved0; + uint64_t model_file_bytes; /* size of GGUF on disk, fingerprint proxy */ + uint8_t model_sample[32]; /* first 32 bytes of GGUF (cheap fingerprint) */ +} ds4_rpc_config; + +#define DS4_RPC_MAX_DRAFTS 16 + +typedef struct ds4_rpc_handle ds4_rpc_handle; + +/* Lifecycle. */ +int ds4_rpc_dial(const char *host, uint16_t port, + ds4_rpc_handle **out, char *err, size_t errlen); +int ds4_rpc_listen_one(const char *bind_host, uint16_t port, + ds4_rpc_handle **out, char *err, size_t errlen); +void ds4_rpc_close(ds4_rpc_handle *h); +int ds4_rpc_fd(const ds4_rpc_handle *h); + +/* Handshake. The client sends its config; the server reads it, decides + * whether to accept, and replies with either ok or an error message. The + * server's accepted config is whatever it offered; the client should copy it + * back into the engine if proceeding. */ +int ds4_rpc_handshake_client(ds4_rpc_handle *h, const ds4_rpc_config *cfg, + char *err, size_t errlen); +/* Variant that also returns the server's accepted config in out_peer so the + * head can learn tail-side capabilities (MTP, etc.). 
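A minimal head-side sequence (sketch, error handling elided): fill a ds4_rpc_config from ds4_model_n_layer()/ds4_model_n_embd()/ds4_model_n_hc()/ds4_model_n_vocab() plus the engine's quant and split, ds4_rpc_dial() the worker, run this handshake, then consult out_peer->tail_has_mtp and out_peer->tail_mtp_draft_tokens before requesting drafts in decode calls.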
*/ +int ds4_rpc_handshake_client_peer(ds4_rpc_handle *h, + const ds4_rpc_config *cfg, + ds4_rpc_config *out_peer, + char *err, size_t errlen); +int ds4_rpc_handshake_server(ds4_rpc_handle *h, const ds4_rpc_config *cfg, + ds4_rpc_config *peer, char *err, size_t errlen); + +/* Head-side: ship one decode request and wait for the logits reply. + * out_drafts (capacity DS4_RPC_MAX_DRAFTS) is filled with the MTP draft + * tokens the tail produced, if any; *out_n_drafts is the count (may be 0). + * Pass NULL for out_drafts and 0 for max_drafts if the head doesn't want + * MTP support. */ +int ds4_rpc_decode_request(ds4_rpc_handle *h, + uint32_t token, uint32_t pos, + bool want_drafts, + const float *residual_hc, uint64_t n_residual_floats, + float *out_logits, uint64_t n_logit_floats, + uint32_t *out_drafts, uint32_t max_drafts, + uint32_t *out_n_drafts, + char *err, size_t errlen); + +/* Split halves of decode_request for pipelined operation: the head can ship + * the request, do work, then collect the reply later. Reply ordering on the + * socket is strict FIFO (tail processes serially); the head must call + * _recv_reply() once per outstanding _send(). */ +int ds4_rpc_decode_send(ds4_rpc_handle *h, + uint32_t token, uint32_t pos, + bool want_drafts, + const float *residual_hc, uint64_t n_residual_floats, + char *err, size_t errlen); +int ds4_rpc_decode_recv_reply(ds4_rpc_handle *h, + float *out_logits, uint64_t n_logit_floats, + uint32_t *out_drafts, uint32_t max_drafts, + uint32_t *out_n_drafts, + char *err, size_t errlen); + +/* Tail-side: receive one decode request (residual + metadata) and emit one + * decode reply (logits + optional MTP drafts). The two halves run in the + * worker's serve loop. */ +int ds4_rpc_decode_recv(ds4_rpc_handle *h, + uint32_t *token, uint32_t *pos, + bool *want_drafts, + float *residual_hc, uint64_t n_residual_floats, + char *err, size_t errlen); +int ds4_rpc_decode_reply(ds4_rpc_handle *h, + const float *logits, uint64_t n_logit_floats, + const uint32_t *drafts, uint32_t n_drafts, + char *err, size_t errlen); + +/* After speculative verification on the head, tell the tail how many of its + * MTP drafts were accepted so it can roll back the unused MTP cache rows. */ +int ds4_rpc_mtp_trim(ds4_rpc_handle *h, uint32_t accepted_drafts, + char *err, size_t errlen); +int ds4_rpc_mtp_trim_recv(ds4_rpc_handle *h, uint32_t *accepted_drafts, + char *err, size_t errlen); +int ds4_rpc_mtp_trim_reply(ds4_rpc_handle *h, char *err, size_t errlen); + +/* Batched all-or-nothing speculative verification. Head ships n_tokens + * residuals (its slice's hidden state for n_tokens consecutive draft tokens) + * plus the expected-next tokens. Tail runs batched prefill, computes per- + * row argmax via the output head, compares to expected, and either commits + * all n_tokens (returning logits at the final row for sampling the next- + * after-accepted) or reverts its KV via spec_frontier_snapshot/restore and + * returns accepted=0. The head does the same KV snapshot on its side. 
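Worked example: with sampled token t and tail-produced drafts d1..d3, the head ships n_tokens=4 rows (t, d1, d2, d3) and expected_next = {d1, d2, d3}; a full match returns n_accepted=4 plus the final row's logits for sampling the token after d3, while any miss returns n_accepted=0 and both sides restore their snapshots.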
*/ +int ds4_rpc_verify_batch_request(ds4_rpc_handle *h, + uint32_t n_tokens, uint32_t pos_start, + const float *batch_residual, + uint64_t n_residual_floats, + const uint32_t *expected_next, + uint32_t n_expected, + uint32_t *out_n_accepted, + float *out_logits, uint64_t n_logit_floats, + char *err, size_t errlen); +int ds4_rpc_verify_batch_recv(ds4_rpc_handle *h, + uint32_t *n_tokens, uint32_t *pos_start, + float *batch_residual, uint64_t max_residual_floats, + uint64_t *out_n_residual_floats, + uint32_t *expected_next, uint32_t max_expected, + uint32_t *out_n_expected, + char *err, size_t errlen); +int ds4_rpc_verify_batch_reply(ds4_rpc_handle *h, + uint32_t n_accepted, + const float *logits, uint64_t n_logit_floats, + char *err, size_t errlen); + +/* Head-side: ship one prefill chunk's batch residual and (optionally) wait + * for logits. n_tokens is the chunk size; pos_start is where this chunk + * begins in the absolute session position; want_logits is true only on the + * final chunk of the prompt. When want_logits is false, out_logits may be + * NULL and n_logit_floats may be 0. */ +int ds4_rpc_prefill_request(ds4_rpc_handle *h, + uint32_t n_tokens, uint32_t pos_start, + bool want_logits, + const float *batch_residual_hc, + uint64_t n_residual_floats, + float *out_logits, uint64_t n_logit_floats, + char *err, size_t errlen); + +/* Tail-side counterparts. ds4_rpc_prefill_recv writes the batch residual + * into the caller's buffer; ds4_rpc_prefill_reply ships back the requested + * logits (or empty reply if !want_logits). */ +int ds4_rpc_prefill_recv(ds4_rpc_handle *h, + uint32_t *n_tokens, uint32_t *pos_start, + bool *want_logits, + float *batch_residual_hc, uint64_t max_residual_floats, + uint64_t *out_n_residual_floats, + char *err, size_t errlen); +int ds4_rpc_prefill_reply(ds4_rpc_handle *h, + bool has_logits, + const float *logits, uint64_t n_logit_floats, + char *err, size_t errlen); + +/* Session control. ds4_rpc_recv_op peeks at the next frame without + * consuming it; the tail's serve loop uses it to dispatch DECODE_REQ + * (handled by ds4_rpc_decode_recv) vs RESET (ds4_rpc_reset_recv) vs + * SHUTDOWN (ds4_rpc_shutdown_recv). */ +int ds4_rpc_reset(ds4_rpc_handle *h, char *err, size_t errlen); +int ds4_rpc_recv_op(ds4_rpc_handle *h, ds4_rpc_op *op, + char *err, size_t errlen); +int ds4_rpc_reset_recv(ds4_rpc_handle *h, char *err, size_t errlen); +int ds4_rpc_reset_reply(ds4_rpc_handle *h, char *err, size_t errlen); +int ds4_rpc_shutdown_send(ds4_rpc_handle *h); +int ds4_rpc_shutdown_recv(ds4_rpc_handle *h, char *err, size_t errlen); + +/* Partial rewind: tell the tail to truncate its checkpoint to target_pos + * without dropping the whole session. Used by server-side tool-call + * canonicalization which rewinds a few tokens after sampling a tool block. */ +int ds4_rpc_rewind(ds4_rpc_handle *h, uint32_t target_pos, + char *err, size_t errlen); +int ds4_rpc_rewind_recv(ds4_rpc_handle *h, uint32_t *target_pos, + char *err, size_t errlen); +int ds4_rpc_rewind_reply(ds4_rpc_handle *h, char *err, size_t errlen); + +#endif diff --git a/ds4_rpc_worker.c b/ds4_rpc_worker.c new file mode 100644 index 00000000..82b7427a --- /dev/null +++ b/ds4_rpc_worker.c @@ -0,0 +1,567 @@ +/* ds4-rpc-worker: tail-side serve loop for pipeline-parallel inference. 
+ *
+ * The worker owns layers [start, end) of one DS4 GGUF, listens on a TCP port,
+ * accepts a single head connection, performs a handshake, and then runs a
+ * decode-only request/reply loop until the head sends SHUTDOWN or the
+ * connection drops. Single-process scope on purpose: the head talks to one
+ * worker per session. Reconnect and concurrent serving are Phase 5.
+ *
+ * This binary intentionally has no inference logic of its own. It is the
+ * normal ds4 engine plus session, configured with a partial layer range, plus
+ * a thin transport adapter. Single-host ds4 / ds4-server keep working
+ * unchanged whether or not a worker is running. */
+
+#define _POSIX_C_SOURCE 200809L
+#define _DARWIN_C_SOURCE
+#define _BSD_SOURCE
+#define _DEFAULT_SOURCE
+
+#include "ds4.h"
+#include "ds4_rpc.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#define WORKER_DEFAULT_PORT 46434u /* "GD43" loosely, easy to remember */
+#define WORKER_DEFAULT_CTX 4096
+
+typedef struct {
+    const char *model_path;
+    const char *quant;        /* --quant q2|q4, NULL if unspecified */
+    const char *mtp_path;     /* optional MTP GGUF for speculative drafting */
+    int mtp_draft_tokens;     /* max drafts per speculative step */
+    float mtp_margin;         /* confidence margin for fast verifier */
+    const char *bind_host;
+    uint16_t port;
+    int layer_start;
+    int layer_end;
+    int ctx_size;
+    int routed_quant_bits;    /* 2 or 4; if 0, infer from filename */
+    int n_layer_total;        /* expected DS4_N_LAYER; for handshake */
+    int n_embd;
+    int n_hc;
+    int n_vocab;
+} worker_opts;
+
+static void worker_usage(const char *prog) {
+    fprintf(stderr,
+            "Usage: %s --layer-start N --layer-end M [options]\n"
+            "\n"
+            "Required:\n"
+            "  --layer-start N        First layer owned (inclusive). Typical: 21.\n"
+            "  --layer-end M          End of the owned range (exclusive). Typical: 43.\n"
+            "\n"
+            "Common:\n"
+            "  -m, --model FILE       GGUF path. Wins over --quant. Default:\n"
+            "                         auto-detect Q2 or Q4 in ./gguf/ (prefers Q2).\n"
+            "  --quant Q              Pick canonical 'q2' or 'q4' file in ./gguf/.\n"
+            "                         Must match the head's --quant for the\n"
+            "                         handshake fingerprint to succeed.\n"
+            "  --listen HOST          Bind address. Default: 0.0.0.0\n"
+            "  --port N               TCP port. Default: %u\n"
+            "  --ctx N                Context size for session. Default: %u\n"
+            "  --routed-quant-bits N  2 or 4; declared in the handshake so the head\n"
+            "                         can refuse a mixed q2/q4 pairing.\n"
+            "  --mtp [FILE]           Enable MTP for speculative decoding. 
With\n" + " a FILE argument, load that GGUF; bare --mtp\n" + " resolves to the canonical MTP path in\n" + " ./gguf/ (from ./download_model.sh mtp).\n" + " When set, the worker runs MTP drafts after\n" + " each decode and ships them to the head.\n" + " --mtp-draft N Max draft tokens per speculative step (1-16).\n" + " --mtp-margin F Confidence margin for the fast verifier.\n" + " -h, --help This help.\n", + prog, WORKER_DEFAULT_PORT, WORKER_DEFAULT_CTX); +} + +static int parse_int_arg(const char *flag, const char *val, int *out) { + if (!val || !val[0]) { + fprintf(stderr, "ds4-rpc-worker: %s requires a value\n", flag); + return 1; + } + char *end = NULL; + long v = strtol(val, &end, 10); + if (end == val || *end != '\0') { + fprintf(stderr, "ds4-rpc-worker: %s: not an integer: %s\n", flag, val); + return 1; + } + *out = (int)v; + return 0; +} + +static int parse_args(int argc, char **argv, worker_opts *o) { + o->model_path = NULL; /* resolved after parsing via ds4_resolve_model_path */ + o->quant = NULL; + o->bind_host = "0.0.0.0"; + o->port = WORKER_DEFAULT_PORT; + o->layer_start = -1; + o->layer_end = -1; + o->ctx_size = WORKER_DEFAULT_CTX; + o->routed_quant_bits = 0; + o->mtp_path = NULL; + o->mtp_draft_tokens = 1; + o->mtp_margin = 0.0f; + + for (int i = 1; i < argc; i++) { + const char *a = argv[i]; + const char *next = (i + 1 < argc) ? argv[i + 1] : NULL; + if (!strcmp(a, "-h") || !strcmp(a, "--help")) { + worker_usage(argv[0]); + return -1; + } else if (!strcmp(a, "-m") || !strcmp(a, "--model")) { + if (!next) { fprintf(stderr, "%s needs a value\n", a); return 1; } + o->model_path = next; i++; + } else if (!strcmp(a, "--quant")) { + if (!next) { fprintf(stderr, "%s needs a value\n", a); return 1; } + o->quant = next; i++; + } else if (!strcmp(a, "--listen")) { + if (!next) { fprintf(stderr, "%s needs a value\n", a); return 1; } + o->bind_host = next; i++; + } else if (!strcmp(a, "--port")) { + int p = 0; + if (parse_int_arg(a, next, &p)) return 1; + if (p < 1 || p > 65535) { fprintf(stderr, "--port out of range\n"); return 1; } + o->port = (uint16_t)p; i++; + } else if (!strcmp(a, "--ctx")) { + if (parse_int_arg(a, next, &o->ctx_size)) return 1; + i++; + } else if (!strcmp(a, "--layer-start")) { + if (parse_int_arg(a, next, &o->layer_start)) return 1; + i++; + } else if (!strcmp(a, "--layer-end")) { + if (parse_int_arg(a, next, &o->layer_end)) return 1; + i++; + } else if (!strcmp(a, "--routed-quant-bits")) { + if (parse_int_arg(a, next, &o->routed_quant_bits)) return 1; + i++; + } else if (!strcmp(a, "--mtp")) { + /* Accept either "--mtp PATH" or bare "--mtp" (resolves to the + * canonical MTP GGUF in ./gguf/ via ds4_resolve_mtp_path). 
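+             *
+             * e.g. both spellings are accepted (paths illustrative):
+             *   ds4-rpc-worker --layer-start 21 --layer-end 43 --mtp
+             *   ds4-rpc-worker --layer-start 21 --layer-end 43 --mtp gguf/ds4-mtp.gguf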
+             */
+            if (next && next[0] && next[0] != '-') {
+                o->mtp_path = next; i++;
+            } else {
+                o->mtp_path = "auto";
+            }
+        } else if (!strcmp(a, "--mtp-draft")) {
+            if (parse_int_arg(a, next, &o->mtp_draft_tokens)) return 1;
+            if (o->mtp_draft_tokens < 1) o->mtp_draft_tokens = 1;
+            if (o->mtp_draft_tokens > 16) o->mtp_draft_tokens = 16;
+            i++;
+        } else if (!strcmp(a, "--mtp-margin")) {
+            if (!next) { fprintf(stderr, "%s needs a value\n", a); return 1; }
+            char *endp = NULL;
+            float v = strtof(next, &endp);
+            if (endp == next) {
+                fprintf(stderr, "ds4-rpc-worker: --mtp-margin not a float: %s\n", next);
+                return 1;
+            }
+            o->mtp_margin = v;
+            i++;
+        } else {
+            fprintf(stderr, "ds4-rpc-worker: unknown argument: %s\n", a);
+            worker_usage(argv[0]);
+            return 1;
+        }
+    }
+    if (o->layer_start < 0 || o->layer_end <= 0) {
+        fprintf(stderr, "ds4-rpc-worker: --layer-start and --layer-end are required\n");
+        worker_usage(argv[0]);
+        return 1;
+    }
+    if (o->layer_start >= o->layer_end) {
+        fprintf(stderr, "ds4-rpc-worker: empty layer range [%d, %d)\n",
+                o->layer_start, o->layer_end);
+        return 1;
+    }
+
+    /* Final model-path resolution: -m wins, else --quant, else filesystem
+     * probe (Q2 preferred), else fall back to the ds4flash.gguf symlink. The
+     * resolved path must match what the head ships in its handshake config
+     * for the fingerprint check to pass. */
+    {
+        char resolve_err[256] = {0};
+        const char *resolved = ds4_resolve_model_path(o->model_path, o->quant,
+                                                      resolve_err, sizeof(resolve_err));
+        if (!resolved) {
+            fprintf(stderr, "ds4-rpc-worker: %s\n", resolve_err);
+            return 1;
+        }
+        if (resolve_err[0]) fprintf(stderr, "ds4-rpc-worker: %s\n", resolve_err);
+        o->model_path = resolved;
+    }
+    /* MTP path: same pattern. If --mtp wasn't passed, mtp_path stays NULL
+     * and the engine just won't load MTP. If --mtp was passed (with or
+     * without an explicit path), resolve to either the explicit path or
+     * the canonical MTP GGUF under ./gguf/. */
+    if (o->mtp_path) {
+        char resolve_err[256] = {0};
+        const char *resolved = ds4_resolve_mtp_path(o->mtp_path,
+                                                    resolve_err, sizeof(resolve_err));
+        if (!resolved) {
+            fprintf(stderr, "ds4-rpc-worker: %s\n",
+                    resolve_err[0] ? resolve_err :
+                    "--mtp requested but no MTP GGUF found in ./gguf/");
+            return 1;
+        }
+        o->mtp_path = resolved;
+        fprintf(stderr, "ds4-rpc-worker: MTP path resolved to %s\n", resolved);
+    }
+    return 0;
+}
+
+/* Read file size and the first 32 bytes for the cheap fingerprint embedded
+ * in the handshake. Both sides compute the same thing; a difference rejects
+ * the connection. 
*/ +static int load_model_fingerprint(const char *path, uint64_t *out_size, + uint8_t out_sample[32]) { + struct stat st; + if (stat(path, &st) != 0) { + fprintf(stderr, "ds4-rpc-worker: stat(%s): %s\n", path, strerror(errno)); + return 1; + } + *out_size = (uint64_t)st.st_size; + + int fd = open(path, O_RDONLY); + if (fd < 0) { + fprintf(stderr, "ds4-rpc-worker: open(%s): %s\n", path, strerror(errno)); + return 1; + } + ssize_t r = read(fd, out_sample, 32); + close(fd); + if (r != 32) { + fprintf(stderr, "ds4-rpc-worker: short read of %s for fingerprint\n", path); + return 1; + } + return 0; +} + +int main(int argc, char **argv) { + worker_opts opts; + int parse = parse_args(argc, argv, &opts); + if (parse < 0) return 0; + if (parse != 0) return parse; + + fprintf(stderr, + "ds4-rpc-worker: model=%s range=[%d, %d) listen=%s:%u ctx=%d\n", + opts.model_path, opts.layer_start, opts.layer_end, + opts.bind_host, (unsigned)opts.port, opts.ctx_size); + + /* Open the engine with the partial layer range. Only the tail's layers + * are bound; the head's globals (token_embd) are skipped, and validation + * is gated to match. */ + ds4_engine_options eopt = { + .model_path = opts.model_path, + .mtp_path = opts.mtp_path, + .backend = DS4_BACKEND_METAL, + .n_threads = 0, + .mtp_draft_tokens = opts.mtp_draft_tokens, + .mtp_margin = opts.mtp_margin, + .warm_weights = false, + .quality = true, + .n_layer_start = opts.layer_start, + .n_layer_end = opts.layer_end, + }; + ds4_engine *engine = NULL; + if (ds4_engine_open(&engine, &eopt) != 0 || !engine) { + fprintf(stderr, "ds4-rpc-worker: failed to open engine\n"); + return 1; + } + + /* Build our half of the handshake config from the just-loaded engine and + * a quick file-fingerprint read. */ + uint64_t model_bytes = 0; + uint8_t model_sample[32] = {0}; + if (load_model_fingerprint(opts.model_path, &model_bytes, model_sample) != 0) { + ds4_engine_close(engine); + return 1; + } + const int quant_bits = opts.routed_quant_bits != 0 + ? opts.routed_quant_bits + : ds4_engine_routed_quant_bits(engine); + + const bool has_mtp = ds4_engine_has_mtp(engine); + ds4_rpc_config cfg = { + .version = DS4_RPC_VERSION, + .n_layer_total = ds4_model_n_layer(), + .n_embd = ds4_model_n_embd(), + .n_hc = ds4_model_n_hc(), + .n_vocab = ds4_model_n_vocab(), + .routed_quant_bits = (uint32_t)quant_bits, + .tail_layer_start = (uint32_t)opts.layer_start, + .tail_layer_end = (uint32_t)opts.layer_end, + .ctx_size = (uint32_t)opts.ctx_size, + .tail_has_mtp = has_mtp ? 1u : 0u, + .tail_mtp_draft_tokens = has_mtp ? (uint32_t)opts.mtp_draft_tokens : 0u, + .model_file_bytes = model_bytes, + }; + if (has_mtp) { + fprintf(stderr, + "ds4-rpc-worker: MTP enabled (max drafts/step = %d, margin=%.3f)\n", + opts.mtp_draft_tokens, (double)opts.mtp_margin); + } + memcpy(cfg.model_sample, model_sample, 32); + + /* Listen for the head; accept one connection. */ + ds4_rpc_handle *rpc = NULL; + char err[512] = {0}; + if (ds4_rpc_listen_one(opts.bind_host, opts.port, &rpc, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: listen: %s\n", err); + ds4_engine_close(engine); + return 1; + } + ds4_rpc_config peer = {0}; + if (ds4_rpc_handshake_server(rpc, &cfg, &peer, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: handshake: %s\n", err); + ds4_rpc_close(rpc); + ds4_engine_close(engine); + return 1; + } + fprintf(stderr, "ds4-rpc-worker: handshake ok, ready to serve\n"); + + /* Create the tail session. 
KV state for our owned layers lives here and + * persists across decode requests until the head sends RESET. */ + ds4_session *session = NULL; + if (ds4_session_create(&session, engine, opts.ctx_size) != 0 || !session) { + fprintf(stderr, "ds4-rpc-worker: session create failed\n"); + ds4_rpc_close(rpc); + ds4_engine_close(engine); + return 1; + } + + const uint64_t n_residual = ds4_residual_hc_floats(); + const uint64_t n_vocab = (uint64_t)cfg.n_vocab; + const uint32_t prefill_cap = ds4_session_prefill_cap(session); + const uint64_t batch_residual_floats = (uint64_t)prefill_cap * n_residual; + float *residual = (float *)malloc((size_t)(n_residual * sizeof(float))); + float *logits = (float *)malloc((size_t)(n_vocab * sizeof(float))); + float *batch_residual = (float *)malloc((size_t)(batch_residual_floats * sizeof(float))); + if (!residual || !logits || !batch_residual) { + fprintf(stderr, "ds4-rpc-worker: alloc failed\n"); + free(residual); free(logits); free(batch_residual); + ds4_session_free(session); + ds4_rpc_close(rpc); + ds4_engine_close(engine); + return 1; + } + fprintf(stderr, + "ds4-rpc-worker: prefill scratch %.2f MiB (prefill_cap=%u tokens)\n", + (double)(batch_residual_floats * sizeof(float)) / (1024.0 * 1024.0), + prefill_cap); + + /* Serve loop. */ + int rc = 0; + bool running = true; + uint64_t served = 0; + while (running) { + ds4_rpc_op op = 0; + if (ds4_rpc_recv_op(rpc, &op, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: peek op: %s\n", err); + rc = 1; + break; + } + switch (op) { + case DS4_RPC_OP_DECODE_REQ: { + uint32_t token = 0, pos = 0; + bool want_drafts = false; + if (ds4_rpc_decode_recv(rpc, &token, &pos, &want_drafts, + residual, n_residual, + err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: decode recv: %s\n", err); + rc = 1; running = false; break; + } + + if (ds4_session_import_residual_hc(session, residual, n_residual, + err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: import residual: %s\n", err); + (void)ds4_rpc_decode_reply(rpc, NULL, n_vocab, NULL, 0, NULL, 0); + rc = 1; running = false; break; + } + + const int eval_rc = want_drafts + ? ds4_session_eval(session, (int)token, err, sizeof(err)) + : ds4_session_eval_no_draft(session, (int)token, err, sizeof(err)); + if (eval_rc != 0) { + fprintf(stderr, "ds4-rpc-worker: session eval: %s\n", err); + (void)ds4_rpc_decode_reply(rpc, NULL, n_vocab, NULL, 0, NULL, 0); + rc = 1; running = false; break; + } + + const float *src = ds4_session_logits(session); + if (!src) { + fprintf(stderr, "ds4-rpc-worker: session has no logits after eval\n"); + (void)ds4_rpc_decode_reply(rpc, NULL, n_vocab, NULL, 0, NULL, 0); + rc = 1; running = false; break; + } + memcpy(logits, src, (size_t)(n_vocab * sizeof(float))); + + uint32_t drafts[DS4_RPC_MAX_DRAFTS] = {0}; + uint32_t n_drafts = 0; + if (want_drafts && ds4_engine_has_mtp(engine)) { + int produced = ds4_session_mtp_drafts_after_eval( + session, drafts, opts.mtp_draft_tokens, + err, sizeof(err)); + if (produced > 0) n_drafts = (uint32_t)produced; + else if (err[0]) { + fprintf(stderr, "ds4-rpc-worker: mtp drafts: %s\n", err); + } + } + + if (ds4_rpc_decode_reply(rpc, logits, n_vocab, + n_drafts > 0 ? 
drafts : NULL, n_drafts, + err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: reply: %s\n", err); + rc = 1; running = false; break; + } + served++; + break; + } + + case DS4_RPC_OP_VERIFY_BATCH: { + uint32_t v_n_tokens = 0, v_pos_start = 0, v_n_expected = 0; + uint64_t v_n_residual = 0; + uint32_t expected[16] = {0}; + if (ds4_rpc_verify_batch_recv(rpc, &v_n_tokens, &v_pos_start, + batch_residual, batch_residual_floats, + &v_n_residual, + expected, 16, &v_n_expected, + err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: verify_batch recv: %s\n", err); + rc = 1; running = false; break; + } + uint32_t n_accepted = 0; + if (ds4_session_verify_batch_imported_hc(session, batch_residual, + v_n_tokens, v_pos_start, + expected, v_n_expected, + &n_accepted, + logits, n_vocab, + err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: verify_batch eval: %s\n", err); + (void)ds4_rpc_verify_batch_reply(rpc, 0, NULL, 0, NULL, 0); + rc = 1; running = false; break; + } + if (ds4_rpc_verify_batch_reply(rpc, n_accepted, + n_accepted > 0 ? logits : NULL, + n_accepted > 0 ? n_vocab : 0, + err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: verify_batch reply: %s\n", err); + rc = 1; running = false; break; + } + break; + } + + case DS4_RPC_OP_MTP_TRIM: { + uint32_t accepted = 0; + if (ds4_rpc_mtp_trim_recv(rpc, &accepted, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: mtp_trim recv: %s\n", err); + rc = 1; running = false; break; + } + ds4_session_mtp_trim_drafts(session, accepted); + if (ds4_rpc_mtp_trim_reply(rpc, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: mtp_trim reply: %s\n", err); + rc = 1; running = false; break; + } + break; + } + + case DS4_RPC_OP_PREFILL_REQ: { + uint32_t n_tok = 0, pos_start = 0; + bool want_logits = false; + uint64_t n_recv_floats = 0; + if (ds4_rpc_prefill_recv(rpc, &n_tok, &pos_start, &want_logits, + batch_residual, batch_residual_floats, + &n_recv_floats, + err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: prefill recv: %s\n", err); + rc = 1; running = false; break; + } + + if (ds4_session_eval_batch_imported_hc(session, batch_residual, + n_tok, pos_start, want_logits, + err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: prefill eval: %s\n", err); + (void)ds4_rpc_prefill_reply(rpc, false, NULL, 0, NULL, 0); + rc = 1; running = false; break; + } + + const float *src = want_logits ? ds4_session_logits(session) : NULL; + if (want_logits && !src) { + fprintf(stderr, "ds4-rpc-worker: prefill eval produced no logits\n"); + (void)ds4_rpc_prefill_reply(rpc, false, NULL, 0, NULL, 0); + rc = 1; running = false; break; + } + if (want_logits) { + memcpy(logits, src, (size_t)(n_vocab * sizeof(float))); + } + if (ds4_rpc_prefill_reply(rpc, want_logits, + want_logits ? logits : NULL, + want_logits ? 
n_vocab : 0, + err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: prefill reply: %s\n", err); + rc = 1; running = false; break; + } + break; + } + + case DS4_RPC_OP_REWIND: { + uint32_t target = 0; + if (ds4_rpc_rewind_recv(rpc, &target, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: rewind recv: %s\n", err); + rc = 1; running = false; break; + } + ds4_session_rewind(session, (int)target); + if (ds4_rpc_rewind_reply(rpc, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: rewind reply: %s\n", err); + rc = 1; running = false; break; + } + break; + } + + case DS4_RPC_OP_RESET: { + if (ds4_rpc_reset_recv(rpc, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: reset recv: %s\n", err); + rc = 1; running = false; break; + } + ds4_session_free(session); + session = NULL; + if (ds4_session_create(&session, engine, opts.ctx_size) != 0 || !session) { + fprintf(stderr, "ds4-rpc-worker: session recreate failed\n"); + rc = 1; running = false; break; + } + if (ds4_rpc_reset_reply(rpc, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4-rpc-worker: reset reply: %s\n", err); + rc = 1; running = false; break; + } + break; + } + + case DS4_RPC_OP_SHUTDOWN: { + (void)ds4_rpc_shutdown_recv(rpc, err, sizeof(err)); + fprintf(stderr, "ds4-rpc-worker: head requested shutdown\n"); + running = false; + break; + } + + default: + fprintf(stderr, "ds4-rpc-worker: unknown op %u\n", (unsigned)op); + rc = 1; running = false; break; + } + } + + fprintf(stderr, "ds4-rpc-worker: served %llu decode requests\n", + (unsigned long long)served); + + free(residual); + free(logits); + free(batch_residual); + if (session) ds4_session_free(session); + ds4_rpc_close(rpc); + ds4_engine_close(engine); + return rc; +} diff --git a/ds4_server.c b/ds4_server.c index b6933a92..c3829383 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -5922,6 +5922,10 @@ static bool kv_cache_store_live_prefix(server *s, const ds4_tokens *tokens, int store_len, const char *reason) { kv_disk_cache *kc = &s->kv; if (!kc->enabled) return false; + /* Under pipeline-parallel RPC the head engine only owns half the layers, + * so a snapshot would cover an incomplete graph -- useless for resume + * and confusing for the load path which expects full state. 
*/ + if (ds4_engine_has_rpc_peer(s->engine)) return false; if (!tokens || store_len < kc->opt.min_tokens) return false; const int original_len = tokens->len; @@ -7112,7 +7116,10 @@ static void generate_job(server *s, job *j) { if (in_tool_call && !dsml_decode_state_uses_payload_sampling(dsml_state)) { temperature = 0.0f; } + const bool spec_debug = getenv("DS4_RPC_SPEC_DEBUG") != NULL; + if (spec_debug) fprintf(stderr, "ds4-spec: server: about to sample\n"); int token = ds4_session_sample(s->session, temperature, top_k, top_p, min_p, &rng); + if (spec_debug) fprintf(stderr, "ds4-spec: server: sampled token=%d\n", token); if (token == ds4_token_eos(s->engine)) { finish = "stop"; break; @@ -7124,6 +7131,7 @@ static void generate_job(server *s, job *j) { ds4_engine_mtp_draft_tokens(s->engine) > 1 && getenv("DS4_MTP_SPEC_DISABLE") == NULL) { + if (spec_debug) fprintf(stderr, "ds4-spec: server: calling speculative_argmax(token=%d)\n", token); ntok = ds4_session_eval_speculative_argmax(s->session, token, max_tokens - completion, @@ -7132,6 +7140,7 @@ static void generate_job(server *s, job *j) { (int)(sizeof(toks) / sizeof(toks[0])), err, sizeof(err)); + if (spec_debug) fprintf(stderr, "ds4-spec: server: speculative_argmax returned ntok=%d\n", ntok); if (ntok < 0) { finish = "error"; break; @@ -7795,6 +7804,7 @@ typedef struct { bool kv_cache_reject_different_quant; bool disable_exact_dsml_tool_replay; int tool_memory_max_ids; + const char *quant; /* --quant q2|q4, NULL if unspecified */ } server_config; static int parse_int_arg(const char *s, const char *opt) { @@ -7870,7 +7880,14 @@ static void usage(FILE *fp) { "\n" "Model and runtime:\n" " -m, --model FILE\n" - " GGUF model path. Default: ds4flash.gguf\n" + " GGUF model path. Wins over --quant. Default: auto-detect Q2 or Q4\n" + " in ./gguf/, preferring Q2 if both are present; fall back to\n" + " ds4flash.gguf when neither is found.\n" + " --quant Q\n" + " Pick the canonical Q2 or Q4 file in ./gguf/ by name. Use 'q2'\n" + " for the 128 GB-friendly 86 GB model; use 'q4' only if you have\n" + " >=256 GB or are running pipeline-parallel with --rpc-peer.\n" + " Ignored when -m is also given.\n" " --mtp FILE\n" " Optional MTP support GGUF used for draft-token probes.\n" " --mtp-draft N\n" @@ -7947,7 +7964,7 @@ static void usage(FILE *fp) { static server_config parse_options(int argc, char **argv) { server_config c = { .engine = { - .model_path = "ds4flash.gguf", + .model_path = NULL, /* resolved after parsing via ds4_resolve_model_path */ .backend = DS4_BACKEND_METAL, .mtp_draft_tokens = 1, .mtp_margin = 3.0f, @@ -7967,8 +7984,17 @@ static server_config parse_options(int argc, char **argv) { exit(0); } else if (!strcmp(arg, "-m") || !strcmp(arg, "--model")) { c.engine.model_path = need_arg(&i, argc, argv, arg); + } else if (!strcmp(arg, "--quant")) { + c.quant = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--mtp")) { - c.engine.mtp_path = need_arg(&i, argc, argv, arg); + /* Accept either "--mtp PATH" or bare "--mtp" (resolves to the + * canonical MTP GGUF in ./gguf/). */ + const char *next = (i + 1 < argc) ? 
argv[i + 1] : NULL; + if (next && next[0] && next[0] != '-') { + c.engine.mtp_path = next; i++; + } else { + c.engine.mtp_path = "auto"; + } } else if (!strcmp(arg, "--mtp-draft")) { c.engine.mtp_draft_tokens = parse_int_arg(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--mtp-margin")) { @@ -8009,6 +8035,37 @@ static server_config parse_options(int argc, char **argv) { c.engine.quality = true; } else if (!strcmp(arg, "--warm-weights")) { c.engine.warm_weights = true; + } else if (!strcmp(arg, "--rpc-peer")) { + /* Pipeline-parallel: ship the second half of layers to a tail + * worker reachable at host:port (default port 46434). Use with + * --rpc-split to set the boundary. */ + if (++i >= argc) { + server_log(DS4_LOG_DEFAULT, "ds4-server: --rpc-peer needs a value"); + exit(2); + } + const char *spec = argv[i]; + const char *colon = strrchr(spec, ':'); + if (colon && colon != spec) { + size_t host_len = (size_t)(colon - spec); + char *host = (char *)malloc(host_len + 1); + if (!host) { + server_log(DS4_LOG_DEFAULT, "ds4-server: oom parsing --rpc-peer"); + exit(2); + } + memcpy(host, spec, host_len); + host[host_len] = '\0'; + c.engine.rpc_peer_host = host; + c.engine.rpc_peer_port = (int)strtol(colon + 1, NULL, 10); + } else { + c.engine.rpc_peer_host = spec; + c.engine.rpc_peer_port = 46434; + } + } else if (!strcmp(arg, "--rpc-split")) { + if (++i >= argc) { + server_log(DS4_LOG_DEFAULT, "ds4-server: --rpc-split needs a value"); + exit(2); + } + c.engine.n_layer_end = parse_int_arg(argv[i], arg); } else if (!strcmp(arg, "--cpu") || !strcmp(arg, "--backend")) { server_log(DS4_LOG_DEFAULT, "ds4-server: server mode is Metal-only"); exit(2); @@ -8025,6 +8082,37 @@ static server_config parse_options(int argc, char **argv) { "ds4-server: --kv-cache-cold-max-tokens must be 0 or >= --kv-cache-min-tokens"); exit(2); } + + /* Final model-path resolution. Priority: -m wins, then --quant, then + * filesystem probe of the canonical Q2/Q4 paths (preferring Q2), then + * the historical ds4flash.gguf symlink fallback. On a 128 GB Mac Q4 is + * too big to load on its own and can kernel-panic macOS, so the probe + * defaults to Q2 when both files are present. */ + { + char resolve_err[256] = {0}; + const char *resolved = ds4_resolve_model_path(c.engine.model_path, + c.quant, + resolve_err, sizeof(resolve_err)); + if (!resolved) { + server_log(DS4_LOG_DEFAULT, "ds4-server: %s", resolve_err); + exit(2); + } + if (resolve_err[0]) server_log(DS4_LOG_DEFAULT, "ds4-server: %s", resolve_err); + c.engine.model_path = resolved; + } + if (c.engine.mtp_path) { + char resolve_err[256] = {0}; + const char *resolved = ds4_resolve_mtp_path(c.engine.mtp_path, + resolve_err, sizeof(resolve_err)); + if (!resolved) { + server_log(DS4_LOG_DEFAULT, "ds4-server: %s", + resolve_err[0] ? resolve_err : + "--mtp requested but no MTP GGUF found in ./gguf/"); + exit(2); + } + c.engine.mtp_path = resolved; + } + return c; } @@ -8040,6 +8128,9 @@ int main(int argc, char **argv) { server_config cfg = parse_options(argc, argv); + /* Propagate --ctx into the engine so the RPC handshake can assert + * head and tail agree on KV window size. No-op for single-host. 
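+     *
+     * Illustrative two-host launch (hostname and split point are examples;
+     * 46434 is the worker's default port, and the worker's --ctx must match
+     * the head's context size):
+     *
+     *   tail: ./ds4-rpc-worker --layer-start 21 --layer-end 43 --quant q4
+     *   head: ./ds4-server --quant q4 --rpc-peer tail-host:46434 --rpc-split 21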
+     */
+    cfg.engine.rpc_ctx_size = cfg.ctx_size;
     ds4_engine *engine = NULL;
     if (ds4_engine_open(&engine, &cfg.engine) != 0) return 1;
diff --git a/tests/ds4_test.c b/tests/ds4_test.c
index d41ff74a..7989eede 100644
--- a/tests/ds4_test.c
+++ b/tests/ds4_test.c
@@ -560,6 +560,155 @@ static void test_tool_call_quality(void) {
 #endif
+#ifndef DS4_NO_METAL
+/* Pipeline-parallel correctness gate: a single-host engine [0, 43) and a
+ * daisy chain of head [0, L_mid) + tail [L_mid, 43) should produce the same
+ * top-k logprobs for the same first token. Bit-identical isn't guaranteed
+ * because Metal command-buffer split points differ and floating-point
+ * reductions are not associative, so reordering perturbs the low bits; we
+ * assert top-3 token agreement and a tight tolerance on logit values. All
+ * three engines coexist in this same process under the refcounted instance
+ * lock. */
+static void test_pipeline_daisy_chain(void) {
+    const char *model = test_model_path();
+    const int ctx_size = 4096;
+    char err[512];
+
+    /* 1. Reference: full-range single-host engine, tokenize "Hello" to get
+     * a deterministic first token, eval one step, snapshot top logprobs. */
+    ds4_engine_options opt_full = {
+        .model_path = model,
+        .backend = DS4_BACKEND_METAL,
+    };
+    ds4_engine *e_full = NULL;
+    TEST_ASSERT(ds4_engine_open(&e_full, &opt_full) == 0);
+    if (!e_full) return;
+
+    ds4_tokens tokens = {0};
+    ds4_tokenize_text(e_full, "Hello", &tokens);
+    TEST_ASSERT(tokens.len > 0);
+    if (tokens.len == 0) {
+        ds4_engine_close(e_full);
+        return;
+    }
+    const int token = tokens.v[0];
+    ds4_tokens_free(&tokens);
+
+    ds4_session *s_full = NULL;
+    TEST_ASSERT(ds4_session_create(&s_full, e_full, ctx_size) == 0);
+    if (!s_full) { ds4_engine_close(e_full); return; }
+
+    err[0] = '\0';
+    TEST_ASSERT(ds4_session_eval(s_full, token, err, sizeof(err)) == 0);
+    if (err[0]) fprintf(stderr, "pipeline daisy chain (full eval): %s\n", err);
+
+    ds4_token_score top_full[16];
+    const int k_full = ds4_session_top_logprobs(s_full, top_full, 16);
+    TEST_ASSERT(k_full > 0);
+
+    ds4_session_free(s_full);
+    ds4_engine_close(e_full);
+
+    /* 2. Head engine: [0, L_mid). */
+    const int L_mid = 21;
+
+    ds4_engine_options opt_head = {
+        .model_path = model,
+        .backend = DS4_BACKEND_METAL,
+        .n_layer_start = 0,
+        .n_layer_end = L_mid,
+    };
+    ds4_engine *e_head = NULL;
+    TEST_ASSERT(ds4_engine_open(&e_head, &opt_head) == 0);
+    if (!e_head) return;
+
+    ds4_session *s_head = NULL;
+    TEST_ASSERT(ds4_session_create(&s_head, e_head, ctx_size) == 0);
+    if (!s_head) { ds4_engine_close(e_head); return; }
+
+    err[0] = '\0';
+    TEST_ASSERT(ds4_session_eval(s_head, token, err, sizeof(err)) == 0);
+    if (err[0]) fprintf(stderr, "pipeline daisy chain (head eval): %s\n", err);
+
+    const uint64_t n = ds4_residual_hc_floats();
+    float *residual = malloc((size_t)n * sizeof(float));
+    TEST_ASSERT(residual != NULL);
+    if (!residual) {
+        ds4_session_free(s_head);
+        ds4_engine_close(e_head);
+        return;
+    }
+
+    err[0] = '\0';
+    TEST_ASSERT(ds4_session_export_residual_hc(s_head, residual, n, err, sizeof(err)) == 0);
+    if (err[0]) fprintf(stderr, "pipeline daisy chain (export): %s\n", err);
+
+    ds4_session_free(s_head);
+    ds4_engine_close(e_head);
+
+    /* 3. Tail engine: [L_mid, DS4_N_LAYER). Import head's residual, then
+     * eval one step: encode skips embed (n_layer_start > 0), runs the
+     * remaining layers, and applies the output projection (n_layer_end ==
+     * DS4_N_LAYER), so s_tail->logits ends up as the daisy-chain logits. 
*/ + ds4_engine_options opt_tail = { + .model_path = model, + .backend = DS4_BACKEND_METAL, + .n_layer_start = L_mid, + .n_layer_end = 43, + }; + ds4_engine *e_tail = NULL; + TEST_ASSERT(ds4_engine_open(&e_tail, &opt_tail) == 0); + if (!e_tail) { free(residual); return; } + + ds4_session *s_tail = NULL; + TEST_ASSERT(ds4_session_create(&s_tail, e_tail, ctx_size) == 0); + if (!s_tail) { + free(residual); + ds4_engine_close(e_tail); + return; + } + + err[0] = '\0'; + TEST_ASSERT(ds4_session_import_residual_hc(s_tail, residual, n, err, sizeof(err)) == 0); + if (err[0]) fprintf(stderr, "pipeline daisy chain (import): %s\n", err); + + err[0] = '\0'; + TEST_ASSERT(ds4_session_eval(s_tail, token, err, sizeof(err)) == 0); + if (err[0]) fprintf(stderr, "pipeline daisy chain (tail eval): %s\n", err); + + ds4_token_score top_tail[16]; + const int k_tail = ds4_session_top_logprobs(s_tail, top_tail, 16); + TEST_ASSERT(k_tail > 0); + + /* 4. Compare. Bit-exact reductions across distinct command-buffer + * shapes are not guaranteed; require top-3 token agreement and a small + * logit tolerance for the wider window. */ + const int k = k_full < k_tail ? k_full : k_tail; + int mismatches = 0; + float max_logit_diff = 0.0f; + for (int i = 0; i < k; i++) { + if (i < 3 && top_full[i].id != top_tail[i].id) { + mismatches++; + fprintf(stderr, + "pipeline daisy chain: top-%d token mismatch full=%d tail=%d " + "(logits %g vs %g)\n", + i, top_full[i].id, top_tail[i].id, + (double)top_full[i].logit, (double)top_tail[i].logit); + } + const float d = fabsf(top_full[i].logit - top_tail[i].logit); + if (d > max_logit_diff) max_logit_diff = d; + } + fprintf(stderr, + "pipeline daisy chain: token=%d split_at=%d compared=%d top3_mismatches=%d max_diff=%g\n", + token, L_mid, k, mismatches, (double)max_logit_diff); + TEST_ASSERT(mismatches == 0); + TEST_ASSERT(max_logit_diff < 1.0e-1f); + + free(residual); + ds4_session_free(s_tail); + ds4_engine_close(e_tail); +} +#endif + static void test_server_unit_group(void) { ds4_server_unit_tests_run(); } @@ -578,6 +727,7 @@ static const ds4_test_entry test_entries[] = { {"--long-context", "long-context", "long Metal continuation regression", test_long_security_continuation}, {"--tool-call-quality", "tool-call-quality", "model emits valid DSML tool calls", test_tool_call_quality}, {"--logprob-vectors", "logprob-vectors", "official API top-logprob vector comparison", test_official_logprob_vectors}, + {"--pipeline-daisy-chain", "pipeline-daisy-chain", "head/tail split produces matching top logprobs", test_pipeline_daisy_chain}, {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_f16_matvec_fast_nr0_4}, #endif {"--server", "server", "server parser/rendering/cache unit tests", test_server_unit_group},