diff --git a/CLAUDE.md b/CLAUDE.md index e3066a9..6fe9f3a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,3 +1,103 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +**NOTE:** README.md is a symlink to this file. Keep content useful for both Claude Code and GitHub readers. + +## Build & Run + +Two backends: `metal_infer/` for Apple Silicon, `cuda_infer/` for NVIDIA GPUs. + +### CUDA backend (NVIDIA GPUs) + +See [`cuda_infer/README.md`](cuda_infer/README.md) for full documentation. Quick start: + +```bash +cd cuda_infer +make # requires CUDA 12.8+ and libcufile +./infer --prompt "Hello" --tokens 20 +``` + +### Metal backend (Apple Silicon) + +All binaries are built from `metal_infer/`. Metal shaders compile at runtime (no offline metal compiler needed). + +```bash +cd metal_infer +make # builds metal_infer (benchmark) + infer (inference engine) +make chat # builds chat TUI client (separate target) +make clean # remove build artifacts +``` + +### Inference engine (`infer`) + +```bash +./infer --prompt "Hello" --tokens 50 # basic generation +./infer --prompt "Hello" --tokens 50 --2bit # 2-bit mode (faster, breaks JSON) +./infer --prompt "Hello" --tokens 20 --timing # per-layer timing breakdown +./infer --serve 8080 # HTTP server (OpenAI-compatible API) +./infer --prompt "Hello" --tokens 20 --freq # expert frequency tracking +./infer --prompt "Hello" --tokens 20 --cache-telemetry # cold vs eviction miss analysis +``` + +### Chat TUI (`chat`) + +Thin HTTP/SSE client that connects to the inference server. Sessions persist to `~/.flash-moe/sessions/.jsonl`. + +```bash +./chat # connect to default port +./chat --port 8000 # specify server port +./chat --show-think # show thinking tokens +./chat --resume # resume previous session +``` + +### Custom system prompt + +Place a file at `~/.flash-moe/system.md` to override the default system prompt used by the serve mode. + +### MoE benchmark (`metal_infer`) + +```bash +make run # single expert forward pass +make verify # Metal vs CPU reference verification +make bench # benchmark single expert (10 iterations) +make moe # full MoE forward pass (K experts, single layer) +make full # full 60-layer forward pass (K=4) +make fullbench # benchmark full 60-layer forward (3 iterations) +``` + +## Code Architecture + +Three Objective-C files, one Metal shader file, one C header — no frameworks, no dependencies beyond Apple system libraries. + +- **`infer.m`** (~7000 lines) — The entire inference engine in one file: model loading, Metal pipeline setup, all 60-layer forward pass, tokenization, sampling, HTTP server (OpenAI-compatible SSE), tool calling, KV cache management. This is the core of the project. +- **`shaders.metal`** (~1200 lines) — All Metal compute kernels: 4-bit/2-bit dequant matvec (multiple optimization levels), SwiGLU, RMS norm, attention (Q@K^T, softmax, scores@V), RoPE, MoE combine+residual. +- **`chat.m`** — Thin HTTP/SSE client with linenoise line editing. Connects to the `--serve` mode of `infer`. No model logic. +- **`main.m`** — Standalone MoE benchmark. Tests expert forward pass in isolation, verifies Metal vs CPU. +- **`tokenizer.h`** — Single-header C BPE tokenizer (449 lines). + +### Key design constraints + +- **Single-file engine**: All inference logic lives in `infer.m`. This is intentional — the entire forward pass, server, and tool calling in one file for simplicity. +- **No custom caching**: Expert data relies entirely on the OS page cache ("Trust the OS"). Every custom cache we tried was slower. +- **Serial GPU→SSD→GPU pipeline**: On Apple Silicon unified memory, SSD DMA and GPU compute share the memory controller. Overlapping them causes GPU latency spikes. The serial pipeline is hardware-optimal. +- **Metal shaders compile at runtime** via `MTLDevice newLibraryWithSource:`. No offline `.metallib` needed (though `make metallib` exists as an option). + +### Per-layer pipeline (3 command buffers) + +``` +CMD3(prev) → CMD1: attention projections + delta-net [GPU] + → CPU: flush results + → CMD2: o_proj + norm + routing + shared [GPU] + → CPU: softmax + topK routing + → I/O: parallel pread K=4 experts [SSD] + → CMD3: expert forward + combine + norm [GPU, DEFERRED] +``` + +CMD3 is submitted without waiting (deferred). The GPU serializes CMD3(N-1) then CMD1(N) via queue ordering. + +--- + # Flash-MoE: Running a 397B Parameter Model on a Laptop > **[Read the paper](paper/flash_moe.pdf)** — Full technical details, 90+ experiments, and the story of how an AI and a human built this in 24 hours. diff --git a/bench_q4k.cu b/bench_q4k.cu new file mode 100644 index 0000000..b933b3d --- /dev/null +++ b/bench_q4k.cu @@ -0,0 +1,156 @@ +/* + * bench_q4k.cu — Compare MLX affine 4-bit vs GGML Q4_K kernel performance + * + * Build: + * nvcc -O2 -o bench_q4k bench_q4k.cu -lpthread + */ + +#include +#include +#include +#include +#include + +// Need ROWS_PER_BLOCK and GROUP_SIZE before including kernels +#define GROUP_SIZE 64 +#include "kernels.cuh" + +#define CHECK_CUDA(call) do { \ + cudaError_t err = (call); \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err)); exit(1); \ + } \ +} while(0) + +static inline float bf16_to_f32_h(uint16_t bf16) { + uint32_t tmp = (uint32_t)bf16 << 16; + float f; memcpy(&f, &tmp, sizeof(f)); return f; +} + +int main() { + // Test dimensions matching expert projections + struct { int out_dim; int in_dim; const char *name; } tests[] = { + {1024, 4096, "gate/up_proj"}, + {4096, 1024, "down_proj"}, + {512, 4096, "routing"}, + {248320, 4096, "lm_head"}, + }; + + cudaDeviceProp prop; + CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + printf("GPU: %s, %d SMs, %.0f GB/s\n", prop.name, + prop.multiProcessorCount, + prop.memoryBusWidth / 8.0 * prop.memoryClockRate * 2.0 / 1e6); + + int iters = 200; + + for (int t = 0; t < 4; t++) { + int out_dim = tests[t].out_dim; + int in_dim = tests[t].in_dim; + printf("\n=== %s [%d, %d] ===\n", tests[t].name, out_dim, in_dim); + + // Allocate input vector + float *h_x = (float *)malloc(in_dim * sizeof(float)); + for (int i = 0; i < in_dim; i++) h_x[i] = (float)(rand() % 1000) / 1000.0f - 0.5f; + float *d_x, *d_out; + CHECK_CUDA(cudaMalloc(&d_x, in_dim * sizeof(float))); + CHECK_CUDA(cudaMalloc(&d_out, out_dim * sizeof(float))); + CHECK_CUDA(cudaMemcpy(d_x, h_x, in_dim * sizeof(float), cudaMemcpyHostToDevice)); + + // ---- MLX format ---- + uint32_t packed_cols = in_dim / 8; + uint32_t num_groups = in_dim / GROUP_SIZE; + size_t mlx_w_sz = out_dim * packed_cols * sizeof(uint32_t); + size_t mlx_s_sz = out_dim * num_groups * sizeof(uint16_t); + + uint32_t *h_W = (uint32_t *)malloc(mlx_w_sz); + uint16_t *h_S = (uint16_t *)malloc(mlx_s_sz); + uint16_t *h_B = (uint16_t *)malloc(mlx_s_sz); + for (size_t i = 0; i < out_dim * packed_cols; i++) h_W[i] = rand(); + for (size_t i = 0; i < out_dim * num_groups; i++) { + float sv = 0.01f; uint32_t tmp; memcpy(&tmp, &sv, 4); h_S[i] = tmp >> 16; + float bv = -0.5f; memcpy(&tmp, &bv, 4); h_B[i] = tmp >> 16; + } + + uint32_t *d_W; uint16_t *d_S, *d_B; + CHECK_CUDA(cudaMalloc(&d_W, mlx_w_sz)); + CHECK_CUDA(cudaMalloc(&d_S, mlx_s_sz)); + CHECK_CUDA(cudaMalloc(&d_B, mlx_s_sz)); + CHECK_CUDA(cudaMemcpy(d_W, h_W, mlx_w_sz, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(d_S, h_S, mlx_s_sz, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(d_B, h_B, mlx_s_sz, cudaMemcpyHostToDevice)); + + // Warmup + launch_dequant_matvec(d_W, d_S, d_B, d_x, d_out, out_dim, in_dim); + CHECK_CUDA(cudaDeviceSynchronize()); + + cudaEvent_t start, stop; + CHECK_CUDA(cudaEventCreate(&start)); + CHECK_CUDA(cudaEventCreate(&stop)); + + CHECK_CUDA(cudaEventRecord(start)); + for (int i = 0; i < iters; i++) + launch_dequant_matvec(d_W, d_S, d_B, d_x, d_out, out_dim, in_dim); + CHECK_CUDA(cudaEventRecord(stop)); + CHECK_CUDA(cudaEventSynchronize(stop)); + float mlx_ms; + CHECK_CUDA(cudaEventElapsedTime(&mlx_ms, start, stop)); + mlx_ms /= iters; + + size_t mlx_total = mlx_w_sz + mlx_s_sz * 2; + printf(" MLX affine 4-bit: %.3f ms (%.1f GB/s, data=%.1f MB)\n", + mlx_ms, mlx_total / (mlx_ms / 1000.0) / 1e9, mlx_total / 1e6); + + // ---- Q4_K format ---- + uint32_t blocks_per_row = in_dim / QK_K; + size_t q4k_row_sz = blocks_per_row * Q4_K_BLOCK_SIZE; + size_t q4k_total = (size_t)out_dim * q4k_row_sz; + + uint8_t *h_Q4K = (uint8_t *)malloc(q4k_total); + // Fill with synthetic Q4_K data + for (size_t row = 0; row < (size_t)out_dim; row++) { + for (uint32_t bi = 0; bi < blocks_per_row; bi++) { + uint8_t *block = h_Q4K + row * q4k_row_sz + bi * Q4_K_BLOCK_SIZE; + __half d_val = __float2half(0.01f); + __half dmin_val = __float2half(0.005f); + memcpy(block, &d_val, 2); + memcpy(block + 2, &dmin_val, 2); + for (int i = 0; i < 12; i++) block[4 + i] = rand() & 0x3F; + for (int i = 0; i < 128; i++) block[16 + i] = rand(); + } + } + + uint8_t *d_Q4K; + CHECK_CUDA(cudaMalloc(&d_Q4K, q4k_total)); + CHECK_CUDA(cudaMemcpy(d_Q4K, h_Q4K, q4k_total, cudaMemcpyHostToDevice)); + + // Warmup + launch_dequant_matvec_q4k(d_Q4K, d_x, d_out, out_dim, in_dim); + CHECK_CUDA(cudaDeviceSynchronize()); + + CHECK_CUDA(cudaEventRecord(start)); + for (int i = 0; i < iters; i++) + launch_dequant_matvec_q4k(d_Q4K, d_x, d_out, out_dim, in_dim); + CHECK_CUDA(cudaEventRecord(stop)); + CHECK_CUDA(cudaEventSynchronize(stop)); + float q4k_ms; + CHECK_CUDA(cudaEventElapsedTime(&q4k_ms, start, stop)); + q4k_ms /= iters; + + printf(" GGML Q4_K: %.3f ms (%.1f GB/s, data=%.1f MB)\n", + q4k_ms, q4k_total / (q4k_ms / 1000.0) / 1e9, q4k_total / 1e6); + + float ratio = q4k_ms / mlx_ms; + printf(" Ratio Q4_K/MLX: %.2fx %s\n", ratio, + ratio < 1.05 ? "(comparable)" : ratio < 1.2 ? "(slightly slower)" : "(slower)"); + + CHECK_CUDA(cudaFree(d_W)); CHECK_CUDA(cudaFree(d_S)); CHECK_CUDA(cudaFree(d_B)); + CHECK_CUDA(cudaFree(d_Q4K)); + free(h_W); free(h_S); free(h_B); free(h_Q4K); + CHECK_CUDA(cudaEventDestroy(start)); CHECK_CUDA(cudaEventDestroy(stop)); + CHECK_CUDA(cudaFree(d_x)); CHECK_CUDA(cudaFree(d_out)); + free(h_x); + } + + return 0; +} diff --git a/bench_transfer.cu b/bench_transfer.cu new file mode 100644 index 0000000..d2e14f6 --- /dev/null +++ b/bench_transfer.cu @@ -0,0 +1,721 @@ +/* + * bench_transfer.cu — Benchmark SSD→CPU→GPU and SSD→GPU (GDS) transfer paths + * + * Tests the critical data paths for MoE expert streaming on NVIDIA: + * 1. pread() SSD → CPU RAM (single, cold cache) + * 2. cudaMemcpy CPU → GPU (PCIe transfer) + * 3. pread() + cudaMemcpy end-to-end + * 4. cuFileRead SSD → GPU (GPUDirect Storage) + * 5. Parallel pread K=4 (cold cache) + * 6. Parallel pread K=4 + cudaMemcpy (full cold pipeline) + * 7. Parallel cuFileRead K=4 → GPU (GDS parallel) + * 8. Warm cache: pread K=4 + cudaMemcpy (page cache hits) + * 9. 4-bit FMA dequant matvec CUDA kernel (GPU compute benchmark) + * + * Build: + * nvcc -O2 -o bench_transfer bench_transfer.cu -lcufile -lpthread + * + * Run: + * ./bench_transfer [--file /path/to/testfile] [--size 7077888] [--iters 50] + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Expert layout matching Flash-MoE Qwen3.5-397B +#define EXPERT_SIZE 7077888 // bytes per expert (from packed layout) +#define K_EXPERTS 4 // active experts per layer +#define NUM_LAYERS 60 +#define DEFAULT_ITERS 50 +#define ALIGN 4096 // O_DIRECT alignment + +// Expert projection dimensions +#define HIDDEN_DIM 4096 +#define MoE_INTERMEDIATE 1024 +#define GROUP_SIZE 64 + +#define CHECK_CUDA(call) do { \ + cudaError_t err = (call); \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(err)); \ + exit(1); \ + } \ +} while(0) + +static double now_ms(void) { + struct timeval tv; + gettimeofday(&tv, NULL); + return tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0; +} + +static void drop_caches(void) { + (void)system("sync; echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null 2>&1"); +} + +static void create_test_file(const char *path, size_t total_size) { + int fd = open(path, O_WRONLY | O_CREAT | O_TRUNC, 0644); + if (fd < 0) { perror("create_test_file"); exit(1); } + size_t chunk = 1024 * 1024; + char *buf = (char *)malloc(chunk); + for (size_t i = 0; i < chunk; i++) buf[i] = (char)(i & 0xFF); + size_t written = 0; + while (written < total_size) { + size_t w = (total_size - written < chunk) ? total_size - written : chunk; + ssize_t r = write(fd, buf, w); + if (r < 0) { perror("write"); exit(1); } + written += r; + } + close(fd); + free(buf); + drop_caches(); +} + +// ============================================================================ +// BFloat16 helpers (host-side) +// ============================================================================ + +static inline float bf16_to_f32(uint16_t bf16) { + uint32_t tmp = (uint32_t)bf16 << 16; + float f; + memcpy(&f, &tmp, sizeof(f)); + return f; +} + +// ============================================================================ +// CUDA Kernel: 4-bit FMA dequant matvec (port of Metal v3 kernel) +// ============================================================================ +// +// Quantization format (MLX affine 4-bit, group_size=64): +// - Weights: uint32, each holding 8 x 4-bit values +// - Per-group scale and bias in bfloat16 +// - Dequant: value = nibble * scale + bias +// +// FMA optimization: (nibble * scale + bias) * x = fma(nibble, scale*x, bias*x) +// Pre-compute scale*x and bias*x per element, then use __fmaf_rn. +// +// Thread layout: blockDim.x = 32 (one warp per row), blockDim.y = ROWS_PER_BLOCK +// Each warp processes one output row. Lane k handles packed columns k, k+32, k+64... +// Warp shuffle reduction to sum across lanes. + +#define ROWS_PER_BLOCK 8 + +__device__ __forceinline__ float device_bf16_to_f32(uint16_t bf16) { + return __uint_as_float((uint32_t)bf16 << 16); +} + +__global__ void dequant_matvec_4bit_fma( + const uint32_t* __restrict__ W_packed, // [out_dim, in_dim/8] + const uint16_t* __restrict__ scales, // [out_dim, num_groups] bf16 + const uint16_t* __restrict__ biases, // [out_dim, num_groups] bf16 + const float* __restrict__ x, // [in_dim] + float* __restrict__ out, // [out_dim] + uint32_t out_dim, + uint32_t in_dim +) { + // Shared memory cache for input vector + extern __shared__ float x_shared[]; + + const uint32_t lane = threadIdx.x; // 0..31 (warp lane) + const uint32_t warp_id = threadIdx.y; // 0..ROWS_PER_BLOCK-1 + const uint32_t row = blockIdx.x * ROWS_PER_BLOCK + warp_id; + + const uint32_t packed_cols = in_dim / 8; + const uint32_t num_groups = in_dim / GROUP_SIZE; + const uint32_t packed_per_group = GROUP_SIZE / 8; // 8 + + // Cooperative load of input vector into shared memory + const uint32_t total_threads = blockDim.x * blockDim.y; + const uint32_t tid = warp_id * 32 + lane; + for (uint32_t i = tid; i < in_dim; i += total_threads) { + x_shared[i] = x[i]; + } + __syncthreads(); + + if (row >= out_dim) return; + + const uint32_t* w_row = W_packed + row * packed_cols; + const uint16_t* s_row = scales + row * num_groups; + const uint16_t* b_row = biases + row * num_groups; + + float acc = 0.0f; + + // Each lane handles columns: lane, lane+32, lane+64, ... + // Adjacent lanes read adjacent uint32 words → coalesced + for (uint32_t col = lane; col < packed_cols; col += 32) { + uint32_t g = col / packed_per_group; + float scale = device_bf16_to_f32(s_row[g]); + float bias = device_bf16_to_f32(b_row[g]); + + uint32_t packed = w_row[col]; + uint32_t x_base = col * 8; + + // FMA optimization: (nibble * scale + bias) * x = fma(nibble, scale*x, bias*x) + float sx0 = scale * x_shared[x_base + 0]; float bx0 = bias * x_shared[x_base + 0]; + float sx1 = scale * x_shared[x_base + 1]; float bx1 = bias * x_shared[x_base + 1]; + float sx2 = scale * x_shared[x_base + 2]; float bx2 = bias * x_shared[x_base + 2]; + float sx3 = scale * x_shared[x_base + 3]; float bx3 = bias * x_shared[x_base + 3]; + float sx4 = scale * x_shared[x_base + 4]; float bx4 = bias * x_shared[x_base + 4]; + float sx5 = scale * x_shared[x_base + 5]; float bx5 = bias * x_shared[x_base + 5]; + float sx6 = scale * x_shared[x_base + 6]; float bx6 = bias * x_shared[x_base + 6]; + float sx7 = scale * x_shared[x_base + 7]; float bx7 = bias * x_shared[x_base + 7]; + + acc += __fmaf_rn((float)((packed >> 0) & 0xF), sx0, bx0); + acc += __fmaf_rn((float)((packed >> 4) & 0xF), sx1, bx1); + acc += __fmaf_rn((float)((packed >> 8) & 0xF), sx2, bx2); + acc += __fmaf_rn((float)((packed >> 12) & 0xF), sx3, bx3); + acc += __fmaf_rn((float)((packed >> 16) & 0xF), sx4, bx4); + acc += __fmaf_rn((float)((packed >> 20) & 0xF), sx5, bx5); + acc += __fmaf_rn((float)((packed >> 24) & 0xF), sx6, bx6); + acc += __fmaf_rn((float)((packed >> 28) & 0xF), sx7, bx7); + } + + // Warp reduction (sum across 32 lanes) + for (int offset = 16; offset > 0; offset >>= 1) { + acc += __shfl_down_sync(0xFFFFFFFF, acc, offset); + } + + if (lane == 0) { + out[row] = acc; + } +} + +// ============================================================================ +// Transfer benchmarks +// ============================================================================ + +typedef struct { + int fd; + size_t size; + off_t offset; + void *buf; +} PreadArg; + +static void *pread_thread(void *arg) { + PreadArg *a = (PreadArg *)arg; + (void)pread(a->fd, a->buf, a->size, a->offset); + return NULL; +} + +// Test 1: pread SSD → CPU +static double bench_pread_cpu(int fd, size_t size, int iters) { + void *buf; + (void)posix_memalign(&buf, ALIGN, size); + double t0 = now_ms(); + for (int i = 0; i < iters; i++) { + off_t offset = (off_t)((i % 128) * size); + pread(fd, buf, size, offset); + } + double elapsed = now_ms() - t0; + free(buf); + return elapsed / iters; +} + +// Test 2: cudaMemcpy CPU → GPU +static double bench_cudamemcpy(size_t size, int iters) { + void *h_buf, *d_buf; + CHECK_CUDA(cudaMallocHost(&h_buf, size)); + CHECK_CUDA(cudaMalloc(&d_buf, size)); + memset(h_buf, 0xAB, size); + CHECK_CUDA(cudaMemcpy(d_buf, h_buf, size, cudaMemcpyHostToDevice)); // warmup + double t0 = now_ms(); + for (int i = 0; i < iters; i++) + CHECK_CUDA(cudaMemcpy(d_buf, h_buf, size, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaDeviceSynchronize()); + double elapsed = now_ms() - t0; + CHECK_CUDA(cudaFree(d_buf)); + CHECK_CUDA(cudaFreeHost(h_buf)); + return elapsed / iters; +} + +// Test 3: pread + cudaMemcpy +static double bench_pread_then_cuda(int fd, size_t size, int iters) { + void *h_buf, *d_buf; + CHECK_CUDA(cudaMallocHost(&h_buf, size)); + CHECK_CUDA(cudaMalloc(&d_buf, size)); + double t0 = now_ms(); + for (int i = 0; i < iters; i++) { + pread(fd, h_buf, size, (off_t)((i % 128) * size)); + CHECK_CUDA(cudaMemcpy(d_buf, h_buf, size, cudaMemcpyHostToDevice)); + } + CHECK_CUDA(cudaDeviceSynchronize()); + double elapsed = now_ms() - t0; + CHECK_CUDA(cudaFree(d_buf)); + CHECK_CUDA(cudaFreeHost(h_buf)); + return elapsed / iters; +} + +// Test 4: GDS cuFileRead (single) +static double bench_gds(const char *path, size_t size, int iters) { + CUfileError_t status = cuFileDriverOpen(); + if (status.err != CU_FILE_SUCCESS) { + fprintf(stderr, " cuFileDriverOpen failed: %d\n", status.err); + return -1.0; + } + int fd = open(path, O_RDONLY | O_DIRECT); + if (fd < 0) { cuFileDriverClose(); return -1.0; } + + CUfileDescr_t desc = {}; + desc.handle.fd = fd; + desc.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; + CUfileHandle_t handle; + status = cuFileHandleRegister(&handle, &desc); + if (status.err != CU_FILE_SUCCESS) { close(fd); cuFileDriverClose(); return -1.0; } + + void *d_buf; + CHECK_CUDA(cudaMalloc(&d_buf, size)); + cuFileBufRegister(d_buf, size, 0); + + ssize_t ret = cuFileRead(handle, d_buf, size, 0, 0); // warmup + if (ret < 0) { + fprintf(stderr, " cuFileRead failed: %zd\n", ret); + cuFileBufDeregister(d_buf); CHECK_CUDA(cudaFree(d_buf)); + cuFileHandleDeregister(handle); close(fd); cuFileDriverClose(); + return -1.0; + } + + double t0 = now_ms(); + for (int i = 0; i < iters; i++) { + off_t offset = (off_t)(((i % 128) * size) / ALIGN * ALIGN); + cuFileRead(handle, d_buf, size, offset, 0); + } + CHECK_CUDA(cudaDeviceSynchronize()); + double elapsed = now_ms() - t0; + + cuFileBufDeregister(d_buf); CHECK_CUDA(cudaFree(d_buf)); + cuFileHandleDeregister(handle); close(fd); cuFileDriverClose(); + return elapsed / iters; +} + +// Test 5: Parallel pread K=4 +static double bench_parallel_pread(int fd, size_t size, int k, int iters) { + void *bufs[8]; + pthread_t threads[8]; + PreadArg args[8]; + for (int i = 0; i < k; i++) (void)posix_memalign(&bufs[i], ALIGN, size); + double t0 = now_ms(); + for (int iter = 0; iter < iters; iter++) { + for (int i = 0; i < k; i++) { + args[i] = (PreadArg){fd, size, (off_t)(((iter*k+i) % 128) * size), bufs[i]}; + pthread_create(&threads[i], NULL, pread_thread, &args[i]); + } + for (int i = 0; i < k; i++) pthread_join(threads[i], NULL); + } + double elapsed = now_ms() - t0; + for (int i = 0; i < k; i++) free(bufs[i]); + return elapsed / iters; +} + +// Test 6: Parallel pread K=4 + cudaMemcpy +static double bench_parallel_pread_cuda(int fd, size_t size, int k, int iters) { + void *h_bufs[8], *d_bufs[8]; + cudaStream_t streams[8]; + for (int i = 0; i < k; i++) { + CHECK_CUDA(cudaMallocHost(&h_bufs[i], size)); + CHECK_CUDA(cudaMalloc(&d_bufs[i], size)); + CHECK_CUDA(cudaStreamCreate(&streams[i])); + } + pthread_t threads[8]; + PreadArg args[8]; + + double t0 = now_ms(); + for (int iter = 0; iter < iters; iter++) { + for (int i = 0; i < k; i++) { + args[i] = (PreadArg){fd, size, (off_t)(((iter*k+i) % 128) * size), h_bufs[i]}; + pthread_create(&threads[i], NULL, pread_thread, &args[i]); + } + for (int i = 0; i < k; i++) pthread_join(threads[i], NULL); + for (int i = 0; i < k; i++) + CHECK_CUDA(cudaMemcpyAsync(d_bufs[i], h_bufs[i], size, cudaMemcpyHostToDevice, streams[i])); + CHECK_CUDA(cudaDeviceSynchronize()); + } + double elapsed = now_ms() - t0; + for (int i = 0; i < k; i++) { + CHECK_CUDA(cudaStreamDestroy(streams[i])); + CHECK_CUDA(cudaFree(d_bufs[i])); + CHECK_CUDA(cudaFreeHost(h_bufs[i])); + } + return elapsed / iters; +} + +// Test 7: Parallel GDS cuFileRead K=4 +typedef struct { + CUfileHandle_t handle; + void *d_buf; + size_t size; + off_t offset; +} GDSReadArg; + +static void *gds_read_thread(void *arg) { + GDSReadArg *a = (GDSReadArg *)arg; + cuFileRead(a->handle, a->d_buf, a->size, a->offset, 0); + return NULL; +} + +static double bench_parallel_gds(const char *path, size_t size, int k, int iters) { + CUfileError_t status = cuFileDriverOpen(); + if (status.err != CU_FILE_SUCCESS) return -1.0; + + int fd = open(path, O_RDONLY | O_DIRECT); + if (fd < 0) { cuFileDriverClose(); return -1.0; } + + CUfileDescr_t desc = {}; + desc.handle.fd = fd; + desc.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; + CUfileHandle_t handle; + status = cuFileHandleRegister(&handle, &desc); + if (status.err != CU_FILE_SUCCESS) { close(fd); cuFileDriverClose(); return -1.0; } + + void *d_bufs[8]; + for (int i = 0; i < k; i++) { + CHECK_CUDA(cudaMalloc(&d_bufs[i], size)); + cuFileBufRegister(d_bufs[i], size, 0); + } + + // Warmup + cuFileRead(handle, d_bufs[0], size, 0, 0); + + pthread_t threads[8]; + GDSReadArg args[8]; + + double t0 = now_ms(); + for (int iter = 0; iter < iters; iter++) { + for (int i = 0; i < k; i++) { + off_t offset = (off_t)((((iter*k+i) % 128) * size) / ALIGN * ALIGN); + args[i] = (GDSReadArg){handle, d_bufs[i], size, offset}; + pthread_create(&threads[i], NULL, gds_read_thread, &args[i]); + } + for (int i = 0; i < k; i++) pthread_join(threads[i], NULL); + } + CHECK_CUDA(cudaDeviceSynchronize()); + double elapsed = now_ms() - t0; + + for (int i = 0; i < k; i++) { + cuFileBufDeregister(d_bufs[i]); + CHECK_CUDA(cudaFree(d_bufs[i])); + } + cuFileHandleDeregister(handle); close(fd); cuFileDriverClose(); + return elapsed / iters; +} + +// Test 8: Warm cache — read same data repeatedly (page cache hits) +static double bench_warm_pread_cuda(int fd, size_t size, int k, int iters) { + void *h_bufs[8], *d_bufs[8]; + cudaStream_t streams[8]; + for (int i = 0; i < k; i++) { + CHECK_CUDA(cudaMallocHost(&h_bufs[i], size)); + CHECK_CUDA(cudaMalloc(&d_bufs[i], size)); + CHECK_CUDA(cudaStreamCreate(&streams[i])); + } + pthread_t threads[8]; + PreadArg args[8]; + + // Warm up: read the same offsets to populate page cache + for (int i = 0; i < k; i++) { + off_t offset = (off_t)(i * size); + pread(fd, h_bufs[i], size, offset); + } + + double t0 = now_ms(); + for (int iter = 0; iter < iters; iter++) { + // Always read from same offsets → page cache hit + for (int i = 0; i < k; i++) { + args[i] = (PreadArg){fd, size, (off_t)(i * size), h_bufs[i]}; + pthread_create(&threads[i], NULL, pread_thread, &args[i]); + } + for (int i = 0; i < k; i++) pthread_join(threads[i], NULL); + for (int i = 0; i < k; i++) + CHECK_CUDA(cudaMemcpyAsync(d_bufs[i], h_bufs[i], size, cudaMemcpyHostToDevice, streams[i])); + CHECK_CUDA(cudaDeviceSynchronize()); + } + double elapsed = now_ms() - t0; + for (int i = 0; i < k; i++) { + CHECK_CUDA(cudaStreamDestroy(streams[i])); + CHECK_CUDA(cudaFree(d_bufs[i])); + CHECK_CUDA(cudaFreeHost(h_bufs[i])); + } + return elapsed / iters; +} + +// ============================================================================ +// Test 9: CUDA dequant_matvec kernel benchmark +// ============================================================================ +static void bench_dequant_kernel(int iters) { + // Simulate gate_proj: [1024, 4096] (out=1024, in=4096) + uint32_t out_dim = MoE_INTERMEDIATE; // 1024 + uint32_t in_dim = HIDDEN_DIM; // 4096 + uint32_t packed_cols = in_dim / 8; // 512 + uint32_t num_groups = in_dim / GROUP_SIZE; // 64 + + // Allocate and fill with synthetic quantized data + size_t w_size = out_dim * packed_cols * sizeof(uint32_t); + size_t s_size = out_dim * num_groups * sizeof(uint16_t); + + uint32_t *h_W = (uint32_t *)malloc(w_size); + uint16_t *h_s = (uint16_t *)malloc(s_size); + uint16_t *h_b = (uint16_t *)malloc(s_size); + float *h_x = (float *)malloc(in_dim * sizeof(float)); + float *h_out = (float *)malloc(out_dim * sizeof(float)); + + // Fill with realistic-ish data + srand(42); + for (uint32_t i = 0; i < out_dim * packed_cols; i++) + h_W[i] = rand(); + for (uint32_t i = 0; i < out_dim * num_groups; i++) { + // bf16 encoding of small floats + float sv = 0.01f * (rand() % 100) / 100.0f; + float bv = -0.5f + (rand() % 100) / 100.0f; + uint32_t tmp; + memcpy(&tmp, &sv, 4); h_s[i] = (uint16_t)(tmp >> 16); + memcpy(&tmp, &bv, 4); h_b[i] = (uint16_t)(tmp >> 16); + } + for (uint32_t i = 0; i < in_dim; i++) + h_x[i] = -1.0f + 2.0f * (rand() % 10000) / 10000.0f; + + // Device allocations + uint32_t *d_W; uint16_t *d_s, *d_b; float *d_x, *d_out; + CHECK_CUDA(cudaMalloc(&d_W, w_size)); + CHECK_CUDA(cudaMalloc(&d_s, s_size)); + CHECK_CUDA(cudaMalloc(&d_b, s_size)); + CHECK_CUDA(cudaMalloc(&d_x, in_dim * sizeof(float))); + CHECK_CUDA(cudaMalloc(&d_out, out_dim * sizeof(float))); + + CHECK_CUDA(cudaMemcpy(d_W, h_W, w_size, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(d_s, h_s, s_size, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(d_b, h_b, s_size, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(d_x, h_x, in_dim * sizeof(float), cudaMemcpyHostToDevice)); + + // Kernel launch config: 32 threads/warp × ROWS_PER_BLOCK warps per block + dim3 block(32, ROWS_PER_BLOCK); + dim3 grid((out_dim + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK); + size_t shared_mem = in_dim * sizeof(float); + + // Warmup + dequant_matvec_4bit_fma<<>>(d_W, d_s, d_b, d_x, d_out, out_dim, in_dim); + CHECK_CUDA(cudaDeviceSynchronize()); + + // Verify: compute on CPU and compare + float *cpu_out = (float *)calloc(out_dim, sizeof(float)); + for (uint32_t row = 0; row < out_dim; row++) { + float acc = 0.0f; + for (uint32_t col = 0; col < packed_cols; col++) { + uint32_t g = col / (GROUP_SIZE / 8); + float scale = bf16_to_f32(h_s[row * num_groups + g]); + float bias = bf16_to_f32(h_b[row * num_groups + g]); + uint32_t packed = h_W[row * packed_cols + col]; + for (int n = 0; n < 8; n++) { + float nibble = (float)((packed >> (n * 4)) & 0xF); + acc += (nibble * scale + bias) * h_x[col * 8 + n]; + } + } + cpu_out[row] = acc; + } + CHECK_CUDA(cudaMemcpy(h_out, d_out, out_dim * sizeof(float), cudaMemcpyDeviceToHost)); + float max_err = 0.0f; + for (uint32_t i = 0; i < out_dim; i++) { + float err = fabsf(h_out[i] - cpu_out[i]); + float rel = (fabsf(cpu_out[i]) > 1e-6f) ? err / fabsf(cpu_out[i]) : err; + if (rel > max_err) max_err = rel; + } + printf(" Verification: max relative error = %.2e %s\n", max_err, + max_err < 1e-3 ? "(OK)" : "(WARNING: large error)"); + free(cpu_out); + + // Benchmark gate_proj [1024, 4096] + cudaEvent_t start, stop; + CHECK_CUDA(cudaEventCreate(&start)); + CHECK_CUDA(cudaEventCreate(&stop)); + + CHECK_CUDA(cudaEventRecord(start)); + for (int i = 0; i < iters; i++) { + dequant_matvec_4bit_fma<<>>(d_W, d_s, d_b, d_x, d_out, out_dim, in_dim); + } + CHECK_CUDA(cudaEventRecord(stop)); + CHECK_CUDA(cudaEventSynchronize(stop)); + float gate_ms = 0; + CHECK_CUDA(cudaEventElapsedTime(&gate_ms, start, stop)); + gate_ms /= iters; + + printf(" gate_proj [%d, %d]: %.3f ms\n", out_dim, in_dim, gate_ms); + + // Benchmark down_proj [4096, 1024] + uint32_t down_out = HIDDEN_DIM; // 4096 + uint32_t down_in = MoE_INTERMEDIATE; // 1024 + uint32_t down_packed = down_in / 8; // 128 + uint32_t down_groups = down_in / GROUP_SIZE; // 16 + size_t dw_size = down_out * down_packed * sizeof(uint32_t); + size_t ds_size = down_out * down_groups * sizeof(uint16_t); + + uint32_t *d_dW; uint16_t *d_ds, *d_db; float *d_dx, *d_dout; + CHECK_CUDA(cudaMalloc(&d_dW, dw_size)); + CHECK_CUDA(cudaMalloc(&d_ds, ds_size)); + CHECK_CUDA(cudaMalloc(&d_db, ds_size)); + CHECK_CUDA(cudaMalloc(&d_dx, down_in * sizeof(float))); + CHECK_CUDA(cudaMalloc(&d_dout, down_out * sizeof(float))); + + dim3 down_grid((down_out + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK); + size_t down_shared = down_in * sizeof(float); + + // Warmup + dequant_matvec_4bit_fma<<>>(d_dW, d_ds, d_db, d_dx, d_dout, down_out, down_in); + CHECK_CUDA(cudaDeviceSynchronize()); + + CHECK_CUDA(cudaEventRecord(start)); + for (int i = 0; i < iters; i++) { + dequant_matvec_4bit_fma<<>>(d_dW, d_ds, d_db, d_dx, d_dout, down_out, down_in); + } + CHECK_CUDA(cudaEventRecord(stop)); + CHECK_CUDA(cudaEventSynchronize(stop)); + float down_ms = 0; + CHECK_CUDA(cudaEventElapsedTime(&down_ms, start, stop)); + down_ms /= iters; + + printf(" down_proj [%d, %d]: %.3f ms\n", down_out, down_in, down_ms); + + // Full expert forward: gate + up + SwiGLU + down (K=1) + float expert_ms = gate_ms * 2 + down_ms; // gate + up (same dims) + down + printf(" Full expert (gate+up+down): %.3f ms\n", expert_ms); + printf(" K=%d experts: %.3f ms\n", K_EXPERTS, expert_ms * K_EXPERTS); + + CHECK_CUDA(cudaEventDestroy(start)); + CHECK_CUDA(cudaEventDestroy(stop)); + CHECK_CUDA(cudaFree(d_W)); CHECK_CUDA(cudaFree(d_s)); CHECK_CUDA(cudaFree(d_b)); + CHECK_CUDA(cudaFree(d_x)); CHECK_CUDA(cudaFree(d_out)); + CHECK_CUDA(cudaFree(d_dW)); CHECK_CUDA(cudaFree(d_ds)); CHECK_CUDA(cudaFree(d_db)); + CHECK_CUDA(cudaFree(d_dx)); CHECK_CUDA(cudaFree(d_dout)); + free(h_W); free(h_s); free(h_b); free(h_x); free(h_out); +} + +// ============================================================================ +// Main +// ============================================================================ +int main(int argc, char **argv) { + const char *testfile = "/tmp/flash_moe_bench.dat"; + size_t expert_size = EXPERT_SIZE; + int iters = DEFAULT_ITERS; + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--file") == 0 && i+1 < argc) testfile = argv[++i]; + else if (strcmp(argv[i], "--size") == 0 && i+1 < argc) expert_size = atol(argv[++i]); + else if (strcmp(argv[i], "--iters") == 0 && i+1 < argc) iters = atoi(argv[++i]); + } + + printf("=== Flash-MoE NVIDIA Transfer + Compute Benchmark ===\n"); + printf("Expert size: %.2f MB, K=%d, %d iterations\n\n", + expert_size / (1024.0 * 1024.0), K_EXPERTS, iters); + + size_t file_size = expert_size * 256; + printf("Creating test file (%zu MB)...\n", file_size / (1024*1024)); + create_test_file(testfile, file_size); + + int fd = open(testfile, O_RDONLY); + if (fd < 0) { perror("open"); return 1; } + + cudaDeviceProp prop; + CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + printf("GPU: %s, VRAM: %zu MB, SM: %d, Mem BW: %.0f GB/s\n\n", + prop.name, prop.totalGlobalMem / (1024*1024), + prop.multiProcessorCount, + prop.memoryBusWidth / 8.0 * prop.memoryClockRate * 2.0 / 1e6); + + double ms; + + // Test 1 + printf("Test 1: pread() SSD → CPU (1x %.1fMB, cold)\n", expert_size/1048576.0); + drop_caches(); + ms = bench_pread_cpu(fd, expert_size, iters); + printf(" %.2f ms (%.2f GB/s)\n\n", ms, (expert_size / (ms/1000.0)) / 1e9); + + // Test 2 + printf("Test 2: cudaMemcpy CPU → GPU (pinned, %.1fMB)\n", expert_size/1048576.0); + ms = bench_cudamemcpy(expert_size, iters); + printf(" %.2f ms (%.2f GB/s)\n\n", ms, (expert_size / (ms/1000.0)) / 1e9); + + // Test 3 + printf("Test 3: pread + cudaMemcpy (cold, %.1fMB)\n", expert_size/1048576.0); + drop_caches(); + ms = bench_pread_then_cuda(fd, expert_size, iters); + printf(" %.2f ms (%.2f GB/s)\n\n", ms, (expert_size / (ms/1000.0)) / 1e9); + + // Test 4 + printf("Test 4: GDS cuFileRead (1x %.1fMB, cold)\n", expert_size/1048576.0); + close(fd); + drop_caches(); + ms = bench_gds(testfile, expert_size, iters); + if (ms > 0) printf(" %.2f ms (%.2f GB/s)\n\n", ms, (expert_size / (ms/1000.0)) / 1e9); + else printf(" (not available)\n\n"); + + fd = open(testfile, O_RDONLY); + + // Test 5 + printf("Test 5: Parallel pread K=%d → CPU (cold)\n", K_EXPERTS); + drop_caches(); + ms = bench_parallel_pread(fd, expert_size, K_EXPERTS, iters); + printf(" %.2f ms (%.2f GB/s agg)\n\n", ms, (K_EXPERTS*expert_size / (ms/1000.0)) / 1e9); + + // Test 6 + printf("Test 6: Parallel pread K=%d + cudaMemcpy (cold, full pipeline)\n", K_EXPERTS); + drop_caches(); + ms = bench_parallel_pread_cuda(fd, expert_size, K_EXPERTS, iters); + double cold_pipeline_ms = ms; + printf(" %.2f ms (%.2f GB/s)\n\n", ms, (K_EXPERTS*expert_size / (ms/1000.0)) / 1e9); + + // Test 7 + printf("Test 7: Parallel GDS cuFileRead K=%d → GPU (cold)\n", K_EXPERTS); + close(fd); + drop_caches(); + ms = bench_parallel_gds(testfile, expert_size, K_EXPERTS, iters); + double gds_pipeline_ms = ms; + if (ms > 0) printf(" %.2f ms (%.2f GB/s)\n\n", ms, (K_EXPERTS*expert_size / (ms/1000.0)) / 1e9); + else printf(" (not available)\n\n"); + + fd = open(testfile, O_RDONLY); + + // Test 8 + printf("Test 8: Warm cache pread K=%d + cudaMemcpy (page cache hits)\n", K_EXPERTS); + ms = bench_warm_pread_cuda(fd, expert_size, K_EXPERTS, iters); + double warm_pipeline_ms = ms; + printf(" %.2f ms (%.2f GB/s)\n\n", ms, (K_EXPERTS*expert_size / (ms/1000.0)) / 1e9); + + // Test 9 + printf("Test 9: CUDA 4-bit FMA dequant matvec kernel\n"); + bench_dequant_kernel(iters * 10); // more iters for GPU kernel + printf("\n"); + + // Summary + printf("========================================\n"); + printf("=== Per-Token Estimate (60 layers) ===\n"); + printf("========================================\n"); + printf(" Cold cache (pread+cuda): %.1f ms/layer → %.0f ms/tok → %.2f tok/s\n", + cold_pipeline_ms, cold_pipeline_ms * NUM_LAYERS, 1000.0 / (cold_pipeline_ms * NUM_LAYERS)); + if (gds_pipeline_ms > 0) + printf(" Cold cache (GDS K=%d): %.1f ms/layer → %.0f ms/tok → %.2f tok/s\n", + K_EXPERTS, gds_pipeline_ms, gds_pipeline_ms * NUM_LAYERS, 1000.0 / (gds_pipeline_ms * NUM_LAYERS)); + printf(" Warm cache (page cache): %.1f ms/layer → %.0f ms/tok → %.2f tok/s\n", + warm_pipeline_ms, warm_pipeline_ms * NUM_LAYERS, 1000.0 / (warm_pipeline_ms * NUM_LAYERS)); + + // Mixed estimate: ~30% cache hit rate with 64GB RAM + double mixed_ms = 0.7 * cold_pipeline_ms + 0.3 * warm_pipeline_ms; + printf(" Mixed (~30%% cache hit): %.1f ms/layer → %.0f ms/tok → %.2f tok/s\n", + mixed_ms, mixed_ms * NUM_LAYERS, 1000.0 / (mixed_ms * NUM_LAYERS)); + + double best_cold = (gds_pipeline_ms > 0 && gds_pipeline_ms < cold_pipeline_ms) ? gds_pipeline_ms : cold_pipeline_ms; + double mixed_best = 0.7 * best_cold + 0.3 * warm_pipeline_ms; + printf(" Mixed (best cold path): %.1f ms/layer → %.0f ms/tok → %.2f tok/s\n", + mixed_best, mixed_best * NUM_LAYERS, 1000.0 / (mixed_best * NUM_LAYERS)); + + printf("\n (GPU compute adds ~0.1-0.3ms/layer on RTX 4090 — see Test 9)\n"); + + close(fd); + unlink(testfile); + return 0; +} diff --git a/cuda_infer/Makefile b/cuda_infer/Makefile new file mode 100644 index 0000000..7b39ef4 --- /dev/null +++ b/cuda_infer/Makefile @@ -0,0 +1,18 @@ +NVCC = /usr/local/cuda-12.8/bin/nvcc +CFLAGS = -O2 -Wno-deprecated-gpu-targets +CUDA_LIB = /usr/local/cuda-12.8/targets/x86_64-linux/lib +CUDA_INC = /usr/local/cuda-12.8/targets/x86_64-linux/include +LDFLAGS = -lpthread -L$(CUDA_LIB) -lcufile + +TARGET = infer + +tokenizer_impl.o: tokenizer_impl.c ../metal_infer/tokenizer.h + gcc -O2 -c tokenizer_impl.c -o tokenizer_impl.o + +$(TARGET): infer.cu kernels.cuh ../metal_infer/tokenizer.h tokenizer_impl.o + $(NVCC) $(CFLAGS) -o $(TARGET) infer.cu tokenizer_impl.o $(LDFLAGS) + +clean: + rm -f $(TARGET) + +.PHONY: clean diff --git a/cuda_infer/README.md b/cuda_infer/README.md new file mode 100644 index 0000000..a2e738f --- /dev/null +++ b/cuda_infer/README.md @@ -0,0 +1,330 @@ +# Flash-MoE CUDA: Running Qwen3.5-397B on a Single NVIDIA GPU + +CUDA/C port of [Flash-MoE](../CLAUDE.md) for x86 PCs with NVIDIA GPUs. Runs **Qwen3.5-397B-A17B** (397 billion parameter MoE model) on a single RTX 4090 with 24GB VRAM, streaming 209GB of expert weights from NVMe SSD. + +**5.35 tokens/second** (avg on RTX 4090, 5.86 peak) with tool calling, OpenAI-compatible API, and SSE streaming. No Python. No frameworks. One CUDA file + one kernel header. + +## How It Works + +The full model is 209GB at 4-bit quantization. Only 5.2GB of non-expert weights fit in GPU VRAM. The remaining 203GB of expert weights (512 experts per layer, K=4 activated per token) stream from NVMe SSD on demand: + +``` +SSD (203GB experts) ──pread──> CPU RAM (page cache) ──cudaMemcpy──> GPU VRAM + ↕ + VRAM expert cache (17GB, ~2500 experts) + ↕ + CUDA kernels +``` + +Each token requires loading 4 experts × 60 layers = 240 expert reads (~6.75MB each). A three-tier caching hierarchy minimizes SSD access: + +1. **VRAM expert cache** (~17GB, ~2500 experts): LRU cache in GPU memory. Hot experts are served instantly without any I/O. After a few requests, ~95% of expert accesses hit the VRAM cache. +2. **OS page cache** (~50GB with 64GB RAM): Experts not in VRAM are read via `pread()`, which populates the OS page cache. Repeat accesses hit RAM at ~10 GB/s. +3. **NVMe SSD**: Cold misses go to SSD at ~5-7 GB/s. + +The VRAM cache warms progressively: cold → warm → hot over a few requests. GDS (direct NVMe-to-GPU DMA) is available for low-RAM systems via `ENABLE_GDS=1` but bypasses the page cache, so it's slower for sustained generation. + +## Results + +### Multi-Hardware Benchmarks + +| GPU | VRAM | RAM | Disk | VRAM Cache | Avg tok/s | Peak tok/s | +|-----|------|-----|------|------------|-----------|-----------| +| **RTX 4090** | 24 GB | 64 GB | NVMe 7 GB/s | 2565 experts | **5.35** | **5.86** | +| **RTX 3060** | 12 GB | 755 GB | NVMe 9 GB/s | 840 experts | **2.92** | **3.23** | +| RTX 2080 Ti | 11 GB | 16 GB | virtio 520 MB/s | 647 experts | 0.51 | 0.54 | +| Apple M3 Max | 48 GB unified | — | NVMe 17.5 GB/s | — | 4.36 | — | + +### VRAM Cache Warm-Up (RTX 4090) + +| Request | tok/s | Improvement | +|---------|-------|-------------| +| 1 (cold) | 2.49 | baseline | +| 2 | 3.22 | +29% | +| 4 | 5.25 | +111% | +| 8 (hot) | 5.86 | +135% | + +### Comparison with Other Solutions + +| System | Qwen3.5-397B | Hardware Required | Approach | +|--------|-------------|-------------------|----------| +| **Flash-MoE CUDA** | **5.35 tok/s** | **1x RTX 4090 + 16GB+ RAM + NVMe** | VRAM cache + page cache + SSD | +| KTransformers | ~14 tok/s* | 1x RTX 4090 + **384GB RAM** | CPU expert compute (AMX), GPU attention | +| llama.cpp (offload) | ~1-2 tok/s | 1x RTX 4090 + **256GB RAM** | CPU/GPU layer split, GGUF | +| KTransformers (full) | ~150 tok/s | **4x RTX 4090 + 800GB RAM** | Multi-GPU tensor parallelism | + +*KTransformers single-GPU numbers are for Qwen3-235B (smaller model); 397B numbers not published for single GPU. + +**Key finding**: VRAM size is the dominant performance factor. The RTX 3060 (12GB) with 755GB RAM and fast NVMe is slower than the RTX 4090 (24GB) with 64GB RAM because the smaller VRAM cache (840 vs 2565 experts) means more cache misses requiring SSD or page cache reads. + +**Key advantage**: Flash-MoE CUDA requires only **16GB RAM** (process uses 5.5GB; more RAM = better page cache but not required) vs 256-384GB for alternatives. + +## Hardware Requirements + +- **GPU**: NVIDIA GPU with 16GB+ VRAM (tested on RTX 4090) +- **RAM**: 16GB minimum, 64GB+ recommended (process uses 5.5GB; extra RAM serves as page cache for experts — 64GB caches ~50% of expert data, significantly improving throughput) +- **SSD**: NVMe SSD with 250GB+ free space, PCIe 4.0+ recommended +- **CUDA**: 12.8+ with GDS support (optional but recommended) +- **OS**: Linux (tested on Ubuntu 24.04) + +## Quick Start + +### 1. Build + +```bash +cd cuda_infer +make # requires CUDA toolkit 12.8+ and libcufile +``` + +### 2. Download and prepare model weights + +```bash +# Install Python dependencies +python3 -m venv flash-moe-env +source flash-moe-env/bin/activate +pip install huggingface_hub safetensors numpy + +# Download MLX 4-bit quantized model (~209GB) +python3 -c " +from huggingface_hub import snapshot_download +snapshot_download('mlx-community/Qwen3.5-397B-A17B-4bit', local_dir='model-safetensors') +" + +# Build expert index and repack into per-layer binary files (~203GB) +python3 build_expert_index.py --model model-safetensors --output expert_index.json +python3 ../repack_experts.py --index expert_index.json + +# Extract non-expert weights (~5.2GB) +python3 ../metal_infer/extract_weights.py --model model-safetensors --output . + +# Export tokenizer and vocabulary +python3 ../metal_infer/export_tokenizer.py model-safetensors/tokenizer.json tokenizer.bin +python3 export_vocab.py model-safetensors/tokenizer.json vocab.bin +``` + +### 3. Run + +```bash +# Direct generation +./infer --prompt "Explain quantum computing" --tokens 50 + +# HTTP server (OpenAI-compatible API) +./infer --serve 8080 + +# With timing breakdown +./infer --prompt "Hello" --tokens 20 --timing +``` + +## HTTP Server (OpenAI-Compatible API) + +Start the server with `--serve PORT`: + +```bash +./infer --serve 8080 +``` + +On startup, the server prefills and caches the system prompt (~4s). All subsequent requests restore from this snapshot instantly — no repeated prefill cost. Custom system prompt can be placed at `~/.flash-moe/system.md`. + +### Endpoints + +- `POST /v1/chat/completions` — OpenAI Chat Completions API (SSE streaming) +- `POST /v1/messages` — Anthropic Messages API (SSE streaming) +- `GET /v1/models` — List available models +- `GET /health` — Health check + +Both chat endpoints support tool calling and produce correct streaming events for their respective formats. + +### Basic chat + +```bash +curl -N http://localhost:8080/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role": "user", "content": "Hello!"}], + "max_tokens": 100, + "stream": true + }' +``` + +### Tool calling (function calling) + +The server supports OpenAI-compatible tool calling. Pass `tools` in the request and the model will generate `tool_calls` in the response: + +```bash +curl -N http://localhost:8080/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role": "user", "content": "What is the weather in Tokyo?"}], + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather for a location", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"] + } + } + }], + "max_tokens": 200, + "stream": true + }' +``` + +Response includes tool calls in OpenAI format: + +```json +{"choices": [{"delta": {"tool_calls": [{"id": "call_1", "type": "function", + "function": {"name": "get_weather", "arguments": "{\"location\": \"Tokyo\"}"}}]}}]} +``` + +The model correctly generates `` tags which are parsed and converted to OpenAI `tool_calls` format. Generation stops after the tool call so the client can execute the function and send results back. + +### Sending tool results back + +```bash +curl -N http://localhost:8080/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [ + {"role": "user", "content": "What is the weather in Tokyo?"}, + {"role": "assistant", "content": null, "tool_calls": [ + {"id": "call_1", "type": "function", + "function": {"name": "get_weather", "arguments": "{\"location\": \"Tokyo\"}"}} + ]}, + {"role": "tool", "tool_call_id": "call_1", + "content": "{\"temperature\": 22, \"condition\": \"sunny\"}"} + ], + "max_tokens": 200, + "stream": true + }' +``` + +### Using with Claude Code + +The server natively supports the Anthropic Messages API (`POST /v1/messages`) — no proxy needed: + +```bash +# Start the Flash-MoE server +./infer --serve 8080 + +# Point Claude Code at it +export ANTHROPIC_BASE_URL=http://localhost:8080 +claude --model qwen3.5-397b-a17b +``` + +This gives you a 397B parameter model with tool calling running locally through Claude Code's agent framework — reading files, running commands, editing code — all on a single GPU. + +### Using with other OpenAI-compatible clients + +The server works directly with any OpenAI-compatible client: + +```python +# Python (openai SDK) +from openai import OpenAI +client = OpenAI(base_url="http://localhost:8080/v1", api_key="unused") +response = client.chat.completions.create( + model="qwen3.5-397b", + messages=[{"role": "user", "content": "Hello!"}], + stream=True +) +for chunk in response: + print(chunk.choices[0].delta.content or "", end="") +``` + +```bash +# aider +aider --model openai/qwen3.5-397b --openai-api-base http://localhost:8080/v1 + +# continue.dev (VS Code) — add to config.json: +# {"models": [{"provider": "openai", "model": "qwen3.5-397b", +# "apiBase": "http://localhost:8080/v1"}]} +``` + +### Custom system prompt + +Place a file at `~/.flash-moe/system.md` to override the default system prompt. + +## Architecture + +### Files + +``` +cuda_infer/ + infer.cu # Complete inference engine + HTTP server (~1700 lines) + kernels.cuh # 15 CUDA compute kernels (~570 lines) + Makefile + build_expert_index.py # Generate expert_index.json from safetensors + export_vocab.py # Generate vocab.bin from tokenizer.json + +bench_transfer.cu # Transfer path benchmarks (GDS, pread, cudaMemcpy) +``` + +### CUDA Kernels (ported from Metal) + +| Kernel | Purpose | +|--------|---------| +| `dequant_matvec_4bit_fma` | FMA-optimized 4-bit dequant matrix-vector multiply | +| `swiglu_fused` | SiLU(gate) × up activation | +| `rms_norm` / `rms_norm_bf16` | RMS normalization with f32/bf16 weights | +| `gated_delta_net_step` | GatedDeltaNet linear attention recurrence | +| `conv1d_step` | Depthwise conv1d (kernel=4) with SiLU | +| `attn_scores` / `attn_softmax` / `attn_values` | Full attention (Q@K^T, softmax, scores@V) | +| `moe_combine_residual` | Weighted expert sum + shared expert + residual | + +### Transfer Path Benchmarks + +Measured on RTX 4090 + Samsung 990 EVO Plus (PCIe 4.0 x4): + +| Path | Time (K=4/layer) | Throughput | +|------|-----------------|-----------| +| pread → cudaMemcpy (cold) | 8.3 ms | 3.4 GB/s | +| **GDS cuFileRead (cold)** | **5.3 ms** | **5.3 GB/s** | +| Warm cache (page cache hit) | 2.7 ms | 10.4 GB/s | +| GPU dequant K=4 experts | 0.08 ms | negligible | + +For cold reads, GDS is 37% faster than pread+cudaMemcpy. However, **pread with page cache** is the default because hot experts cached in RAM (2.7ms) beat GDS cold reads (5.3ms). With 64GB RAM, the page cache grows to ~50GB during sustained generation, caching roughly half the expert data. Set `ENABLE_GDS=1` to force GDS on low-RAM systems. + +### Key Differences from Apple Silicon Version + +| Aspect | Apple Silicon (Metal) | NVIDIA (CUDA) | +|--------|----------------------|---------------| +| Memory | Unified (GPU=CPU=SSD) | Discrete (PCIe bus) | +| SSD bandwidth | 17.5 GB/s | 5-7 GB/s (PCIe 4.0 x4) | +| GPU memory BW | ~400 GB/s | 1008 GB/s | +| SSD→GPU path | Direct (shared memory) | GDS or pread+cudaMemcpy | +| I/O+compute overlap | Cannot overlap (shared bus) | **Can overlap** (separate buses) | +| Pipeline | 3 Metal command buffers | CUDA streams | + +## Technical Details + +### Per-Token Pipeline (60 layers) + +For each layer: +1. **RMS norm** (input layernorm) — GPU +2. **Attention projections** (4-bit dequant matvec) — GPU +3. **Attention compute**: + - Linear (45 layers): conv1d → RMS norm Q/K → decay/beta → GatedDeltaNet recurrence → gated RMS norm + - Full (15 layers): Q/K RMS norm → RoPE → KV cache update → Q@K^T → softmax → scores@V → sigmoid gate +4. **Output projection** (dequant matvec) — GPU +5. **Residual + post-attention RMS norm** — GPU +6. **MoE routing**: dequant matvec → softmax → topK — GPU + CPU +7. **Shared expert forward**: gate+up → SwiGLU → down — GPU (overlapped with expert I/O) +8. **Expert loading**: K=4 parallel reads from SSD — GDS or pread +9. **Expert forward**: gate+up → SwiGLU → down × K — GPU +10. **MoE combine + residual** — GPU + +### Memory Usage + +| Component | Size | Location | +|-----------|------|----------| +| Non-expert weights | 5.2 GB | GPU VRAM | +| **VRAM expert cache** | **~17 GB** | **GPU VRAM (LRU, ~2500 experts)** | +| Scratch buffers | ~200 MB | GPU VRAM | +| KV cache (15 full-attn layers) | ~200 MB | GPU VRAM | +| Delta-net state (45 linear layers) | ~180 MB | GPU VRAM | +| **Total GPU VRAM** | **~23 GB** | | +| Process RSS | ~5.5 GB | CPU RAM | +| OS page cache | ~50 GB | CPU RAM (dynamic, caches SSD reads) | +| Expert data on disk | 203 GB | NVMe SSD | diff --git a/cuda_infer/build_expert_index.py b/cuda_infer/build_expert_index.py new file mode 100644 index 0000000..2fe30a6 --- /dev/null +++ b/cuda_infer/build_expert_index.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +"""Build expert_index.json from model safetensors for repack_experts.py. + +Scans the safetensors headers to find expert weight locations and writes +an index file compatible with repack_experts.py. + +Usage: + python build_expert_index.py --model /path/to/model-safetensors --output expert_index.json +""" + +import argparse +import json +import struct +import os +import re +from pathlib import Path + + +def parse_safetensors_header(filepath): + """Parse safetensors header, return (header_dict, data_start_offset).""" + with open(filepath, 'rb') as f: + header_len = struct.unpack(' filename + expert_tensors = {} # (layer, "gate_proj.weight") -> (tensor_name, filename) + for name, filename in weight_map.items(): + m = expert_pattern.match(name) + if m: + layer = int(m.group(1)) + proj = m.group(2) + comp = m.group(3) + key = (layer, f"{proj}.{comp}") + expert_tensors[key] = (name, filename) + + # Collect unique layers + layers = sorted(set(l for l, _ in expert_tensors.keys())) + print(f"Found {len(layers)} layers with expert weights") + print(f"Found {len(expert_tensors)} expert tensor entries") + + # Parse headers for all needed files + needed_files = set(fn for _, fn in expert_tensors.values()) + print(f"Parsing {len(needed_files)} safetensors headers...") + + header_cache = {} + for fn in sorted(needed_files): + fp = model_path / fn + header_cache[fn] = parse_safetensors_header(str(fp)) + + # Build expert_reads index + # For each layer and component, compute: + # - file: safetensors filename + # - abs_offset: absolute byte offset to expert 0's data + # - expert_stride: byte stride between consecutive experts + # - expert_size: bytes per expert for this component + # - total_size: total bytes for all 512 experts + # - shape: [num_experts, out_dim, packed_cols_or_groups] + + expert_reads = {} + for layer in layers: + layer_reads = {} + for comp_key in ['gate_proj.weight', 'gate_proj.scales', 'gate_proj.biases', + 'up_proj.weight', 'up_proj.scales', 'up_proj.biases', + 'down_proj.weight', 'down_proj.scales', 'down_proj.biases']: + key = (layer, comp_key) + if key not in expert_tensors: + print(f" WARNING: missing {comp_key} for layer {layer}") + continue + + tensor_name, filename = expert_tensors[key] + header, data_start = header_cache[filename] + + # Find tensor in header (skip __metadata__) + tensor_info = None + for k, v in header.items(): + if k == '__metadata__': + continue + if k == tensor_name: + tensor_info = v + break + + if tensor_info is None: + print(f" WARNING: tensor {tensor_name} not found in {filename}") + continue + + offsets = tensor_info['data_offsets'] + shape = tensor_info['shape'] + dtype = tensor_info['dtype'] + + abs_offset = data_start + offsets[0] + total_size = offsets[1] - offsets[0] + + # shape is [num_experts, out_dim, packed_dim] + num_experts = shape[0] + expert_size = total_size // num_experts + expert_stride = expert_size + + layer_reads[comp_key] = { + 'file': filename, + 'abs_offset': abs_offset, + 'expert_stride': expert_stride, + 'expert_size': expert_size, + 'total_size': total_size, + 'shape': shape, + } + + expert_reads[str(layer)] = layer_reads + + # Write index + index = { + 'model_path': str(model_path), + 'expert_reads': expert_reads, + } + + with open(args.output, 'w') as f: + json.dump(index, f, indent=2) + + print(f"\nWrote {args.output}") + print(f" {len(layers)} layers, 9 components each") + + # Verify sizes are consistent across layers + first_layer = expert_reads[str(layers[0])] + expert_size_total = sum(first_layer[c]['expert_size'] for c in first_layer) + num_experts = first_layer[list(first_layer.keys())[0]]['shape'][0] + print(f" Experts per layer: {num_experts}") + print(f" Expert size: {expert_size_total} bytes ({expert_size_total/1024/1024:.2f} MB)") + for comp, info in first_layer.items(): + print(f" {comp:25s} {info['expert_size']:>8d} bytes shape={info['shape']}") + + ok = True + for layer in layers[1:]: + for comp in first_layer: + if comp not in expert_reads[str(layer)]: + print(f" MISSING: layer {layer} {comp}") + ok = False + continue + if expert_reads[str(layer)][comp]['expert_size'] != first_layer[comp]['expert_size']: + print(f" MISMATCH: layer {layer} {comp}") + ok = False + if ok: + print(" Cross-layer consistency: OK") + + +if __name__ == '__main__': + main() diff --git a/cuda_infer/configure.py b/cuda_infer/configure.py new file mode 100644 index 0000000..eeddf85 --- /dev/null +++ b/cuda_infer/configure.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""Generate build command for any MoE model from its model_weights.json config. + +Usage: + python3 configure.py [--manifest model_weights.json] [--output Makefile.model] + +Reads the config section from model_weights.json and outputs either: + 1. A Makefile with the correct -D flags (default) + 2. The nvcc command to stdout (--print-cmd) +""" +import json +import sys +import argparse + +def main(): + parser = argparse.ArgumentParser(description='Configure build for a specific MoE model') + parser.add_argument('--manifest', default='model_weights.json', + help='Path to model_weights.json') + parser.add_argument('--output', default=None, + help='Output Makefile (default: print command to stdout)') + parser.add_argument('--print-cmd', action='store_true', + help='Print nvcc command instead of writing Makefile') + args = parser.parse_args() + + with open(args.manifest) as f: + manifest = json.load(f) + + cfg = manifest.get('config', {}) + + # Map config keys to C #define names + defines = { + 'HIDDEN_DIM': cfg.get('hidden_size', 4096), + 'NUM_LAYERS': cfg.get('num_hidden_layers', 60), + 'NUM_ATTN_HEADS': cfg.get('num_attention_heads', 32), + 'NUM_KV_HEADS': cfg.get('num_key_value_heads', 2), + 'HEAD_DIM': cfg.get('head_dim', 256), + 'VOCAB_SIZE': cfg.get('vocab_size', 248320), + 'NUM_EXPERTS': cfg.get('num_experts', 512), + 'MOE_INTERMEDIATE': cfg.get('moe_intermediate_size', 1024), + 'SHARED_INTERMEDIATE': cfg.get('shared_expert_intermediate_size', 1024), + 'FULL_ATTN_INTERVAL': cfg.get('full_attention_interval', 4), + 'LINEAR_NUM_V_HEADS': cfg.get('linear_num_value_heads', 64), + 'LINEAR_NUM_K_HEADS': cfg.get('linear_num_key_heads', 16), + 'LINEAR_KEY_DIM': cfg.get('linear_key_head_dim', 128), + 'LINEAR_VALUE_DIM': cfg.get('linear_value_head_dim', 128), + 'CONV_KERNEL_SIZE': cfg.get('linear_conv_kernel_dim', 4), + } + + # Float defines + float_defines = { + 'ROPE_THETA': cfg.get('rope_theta', 10000000.0), + 'PARTIAL_ROTARY': cfg.get('partial_rotary_factor', 0.25), + } + + # Build -D flags + dflags = ' '.join(f'-D{k}={v}' for k, v in defines.items()) + dflags += ' ' + ' '.join(f'-D{k}={v}f' for k, v in float_defines.items()) + + # Model name for the binary + hidden = defines['HIDDEN_DIM'] + layers = defines['NUM_LAYERS'] + experts = defines['NUM_EXPERTS'] + model_name = f"qwen_{hidden}x{layers}x{experts}" + + nvcc_cmd = ( + f'nvcc -O2 -Wno-deprecated-gpu-targets -diag-suppress 1650 ' + f'{dflags} ' + f'-o infer_{model_name} infer.cu tokenizer_impl.o -lpthread' + ) + + if args.print_cmd: + print(nvcc_cmd) + return + + print(f'Model: {model_name}') + print(f' hidden={hidden}, layers={layers}, experts={experts}') + print(f' K={cfg.get("num_experts_per_tok", "?")}, ' + f'intermediate={defines["MOE_INTERMEDIATE"]}') + + if args.output: + with open(args.output, 'w') as f: + f.write(f'# Auto-generated for {model_name}\n') + f.write(f'# From: {args.manifest}\n\n') + f.write(f'NVCC ?= nvcc\n') + f.write(f'MODEL_DFLAGS = {dflags}\n\n') + f.write(f'infer_{model_name}: infer.cu kernels.cuh tokenizer_impl.o\n') + f.write(f'\t$(NVCC) -O2 -Wno-deprecated-gpu-targets -diag-suppress 1650 ' + f'$(MODEL_DFLAGS) -o $@ infer.cu tokenizer_impl.o -lpthread\n\n') + f.write(f'tokenizer_impl.o: tokenizer_impl.c ../metal_infer/tokenizer.h\n') + f.write(f'\tgcc -O2 -c tokenizer_impl.c -o tokenizer_impl.o\n') + print(f'Wrote {args.output}') + print(f'Build with: make -f {args.output}') + else: + print(f'\nBuild command:') + print(f' {nvcc_cmd}') + +if __name__ == '__main__': + main() diff --git a/cuda_infer/export_vocab.py b/cuda_infer/export_vocab.py new file mode 100644 index 0000000..8755023 --- /dev/null +++ b/cuda_infer/export_vocab.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +import json, struct, os, sys + +tok_path = sys.argv[1] if len(sys.argv) > 1 else 'tokenizer.json' +out_path = sys.argv[2] if len(sys.argv) > 2 else 'vocab.bin' + +with open(tok_path) as f: + t = json.load(f) + +vocab = t['model']['vocab'] +added = {tok['content']: tok['id'] for tok in t['added_tokens']} + +all_tokens = dict(vocab) +all_tokens.update(added) + +max_id = max(all_tokens.values()) +num_entries = max_id + 1 + +id_to_str = [''] * num_entries +for s, i in all_tokens.items(): + id_to_str[i] = s + +with open(out_path, 'wb') as f: + f.write(struct.pack(' +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kernels.cuh" + +// tokenizer.h: declarations only (impl in tokenizer_impl.c, linked separately) +extern "C" { +#include "../metal_infer/tokenizer.h" +} + +// ============================================================================ +// Model constants — defaults for Qwen3.5-397B-A17B +// Override at compile time with -DHIDDEN_DIM=3072 etc. for other models. +// ============================================================================ + +#ifndef HIDDEN_DIM +#define HIDDEN_DIM 4096 +#endif +#ifndef NUM_LAYERS +#define NUM_LAYERS 60 +#endif +#ifndef NUM_ATTN_HEADS +#define NUM_ATTN_HEADS 32 +#endif +#ifndef NUM_KV_HEADS +#define NUM_KV_HEADS 2 +#endif +#ifndef HEAD_DIM +#define HEAD_DIM 256 +#endif +#ifndef VOCAB_SIZE +#define VOCAB_SIZE 248320 +#endif +#define RMS_NORM_EPS 1e-6f +#ifndef NUM_EXPERTS +#define NUM_EXPERTS 512 +#endif +#ifndef MOE_INTERMEDIATE +#define MOE_INTERMEDIATE 1024 +#endif +#ifndef SHARED_INTERMEDIATE +#define SHARED_INTERMEDIATE 1024 +#endif +#ifndef FULL_ATTN_INTERVAL +#define FULL_ATTN_INTERVAL 4 +#endif +#define GROUP_SIZE_C 64 + +// Linear attention (GatedDeltaNet) +#ifndef LINEAR_NUM_V_HEADS +#define LINEAR_NUM_V_HEADS 64 +#endif +#ifndef LINEAR_NUM_K_HEADS +#define LINEAR_NUM_K_HEADS 16 +#endif +#ifndef LINEAR_KEY_DIM +#define LINEAR_KEY_DIM 128 +#endif +#ifndef LINEAR_VALUE_DIM +#define LINEAR_VALUE_DIM 128 +#endif +#define LINEAR_TOTAL_KEY (LINEAR_NUM_K_HEADS * LINEAR_KEY_DIM) +#define LINEAR_TOTAL_VALUE (LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM) +#define LINEAR_CONV_DIM (LINEAR_TOTAL_KEY * 2 + LINEAR_TOTAL_VALUE) +#ifndef CONV_KERNEL_SIZE +#define CONV_KERNEL_SIZE 4 +#endif + +// Full attention +#ifndef ROPE_THETA +#define ROPE_THETA 10000000.0f +#endif +#ifndef PARTIAL_ROTARY +#define PARTIAL_ROTARY 0.25f +#endif +#define ROTARY_DIM ((int)(HEAD_DIM * PARTIAL_ROTARY)) +#define MAX_SEQ_LEN 4096 + +// Expert layout — computed from dimensions if not overridden +#ifndef EXPERT_SIZE +#define EXPERT_SIZE (MOE_INTERMEDIATE * (HIDDEN_DIM/8) * 4 \ + + MOE_INTERMEDIATE * (HIDDEN_DIM/GROUP_SIZE_C) * 2 * 2 \ + + MOE_INTERMEDIATE * (HIDDEN_DIM/8) * 4 \ + + MOE_INTERMEDIATE * (HIDDEN_DIM/GROUP_SIZE_C) * 2 * 2 \ + + HIDDEN_DIM * (MOE_INTERMEDIATE/8) * 4 \ + + HIDDEN_DIM * (MOE_INTERMEDIATE/GROUP_SIZE_C) * 2 * 2) +#endif +#define MAX_K 8 + +// Quant format (0 = MLX affine 4-bit, 1 = GGUF) +static int g_quant_format = 0; // set at startup from manifest + +// Runtime expert size — defaults to compile-time MLX value, overridden for GGUF +static size_t g_expert_size = EXPERT_SIZE; + +// GGUF expert component offsets and types (populated from layout.json for GGUF) +static size_t g_gguf_gate_offset = 0; +static size_t g_gguf_gate_size = 0; +static int g_gguf_gate_type = 12; // Q4_K default +static size_t g_gguf_up_offset = 0; +static size_t g_gguf_up_size = 0; +static int g_gguf_up_type = 12; +static size_t g_gguf_down_offset = 0; +static size_t g_gguf_down_size = 0; +static int g_gguf_down_type = 14; // Q6_K default +static int g_gguf_down_type_per_layer[256] = {0}; // per-layer down type (i1 quant mixes Q4_K/Q6_K) + +#define CHECK_CUDA(call) do { \ + cudaError_t err = (call); \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(err)); \ + exit(1); \ + } \ +} while(0) + +static double now_ms(void) { + struct timeval tv; + gettimeofday(&tv, NULL); + return tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0; +} + +static inline float bf16_to_f32_host(uint16_t bf16) { + uint32_t tmp = (uint32_t)bf16 << 16; + float f; + memcpy(&f, &tmp, sizeof(f)); + return f; +} + +// ============================================================================ +// BPE byte-unicode decoding (Ġ→space, Ċ→newline, etc.) +// ============================================================================ +// GPT-2 BPE maps bytes 0-255 to Unicode codepoints to avoid control chars. +// We need to reverse this mapping when displaying tokens. + +static int g_bpe_byte_table[256]; // unicode codepoint → original byte +static int g_bpe_byte_table_built = 0; + +static void build_bpe_decode_table(void) { + if (g_bpe_byte_table_built) return; + // Build forward map: byte → unicode (same as GPT-2 bytes_to_unicode) + int unicode_map[256]; + int n = 0; + for (int b = 0; b < 256; b++) { + if ((b >= '!' && b <= '~') || (b >= 0xA1 && b <= 0xAC) || (b >= 0xAE && b <= 0xFF)) { + unicode_map[b] = b; + } else { + unicode_map[b] = 256 + n; + n++; + } + } + // Build reverse: unicode → byte + memset(g_bpe_byte_table, -1, sizeof(g_bpe_byte_table)); + for (int b = 0; b < 256; b++) { + if (unicode_map[b] < 256) + g_bpe_byte_table[unicode_map[b]] = b; + } + // Handle the 256+ range (Ġ=288→0x20=space, Ċ=266→0x0A=newline, etc.) + g_bpe_byte_table_built = 1; +} + +// Decode a BPE token string to raw bytes, return length +static int bpe_decode_token(const char *token, char *out, int max_out) { + build_bpe_decode_table(); + int j = 0; + const unsigned char *p = (const unsigned char *)token; + while (*p && j < max_out - 1) { + // Decode UTF-8 codepoint + uint32_t cp; + int bytes; + if (*p < 0x80) { cp = *p; bytes = 1; } + else if ((*p & 0xE0) == 0xC0) { cp = *p & 0x1F; bytes = 2; } + else if ((*p & 0xF0) == 0xE0) { cp = *p & 0x0F; bytes = 3; } + else { cp = *p & 0x07; bytes = 4; } + for (int i = 1; i < bytes && p[i]; i++) + cp = (cp << 6) | (p[i] & 0x3F); + p += bytes; + + // Map unicode codepoint back to byte + if (cp < 256 && g_bpe_byte_table[cp] >= 0) { + out[j++] = (char)g_bpe_byte_table[cp]; + } else if (cp >= 256 && cp < 512) { + // Extended range: codepoints 256-511 map to bytes that were remapped + // Build on-the-fly: find the byte whose unicode_map == cp + int found = 0; + int n = 0; + for (int b = 0; b < 256 && !found; b++) { + if ((b >= '!' && b <= '~') || (b >= 0xA1 && b <= 0xAC) || (b >= 0xAE && b <= 0xFF)) + continue; + if (256 + n == (int)cp) { out[j++] = (char)b; found = 1; } + n++; + } + if (!found) out[j++] = '?'; + } else { + // Pass through other Unicode as-is (UTF-8 encode back) + if (cp < 0x80) out[j++] = cp; + else if (cp < 0x800 && j + 1 < max_out) { + out[j++] = 0xC0 | (cp >> 6); + out[j++] = 0x80 | (cp & 0x3F); + } else if (cp < 0x10000 && j + 2 < max_out) { + out[j++] = 0xE0 | (cp >> 12); + out[j++] = 0x80 | ((cp >> 6) & 0x3F); + out[j++] = 0x80 | (cp & 0x3F); + } + } + } + out[j] = '\0'; + return j; +} + +static void print_token(const char *token) { + char decoded[1024]; + bpe_decode_token(token, decoded, sizeof(decoded)); + printf("%s", decoded); +} + +// ============================================================================ +// Minimal JSON parser for model_weights.json manifest +// ============================================================================ + +typedef struct { + char name[256]; + size_t offset; + size_t size; + char dtype[8]; // "U32", "BF16", "F32" + int shape[4]; + int ndim; + int gguf_type; // GGML quant type (0=F32, 12=Q4_K, 14=Q6_K, etc.) — only used for GGUF +} TensorInfo; + +typedef struct { + TensorInfo *tensors; + int num_tensors; +} TensorManifest; + +// Simple JSON string extraction (no dependencies) +static const char *json_find_key(const char *json, const char *key) { + char pattern[512]; + snprintf(pattern, sizeof(pattern), "\"%s\"", key); + return strstr(json, pattern); +} + +static TensorManifest *load_manifest(const char *path) { + FILE *f = fopen(path, "r"); + if (!f) { fprintf(stderr, "Cannot open manifest %s\n", path); return NULL; } + fseek(f, 0, SEEK_END); + long sz = ftell(f); + fseek(f, 0, SEEK_SET); + char *json = (char *)malloc(sz + 1); + fread(json, 1, sz, f); + json[sz] = '\0'; + fclose(f); + + // Find "tensors" section + const char *tensors_start = json_find_key(json, "tensors"); + if (!tensors_start) { fprintf(stderr, "No tensors in manifest\n"); free(json); return NULL; } + + // Count tensors (count "offset" occurrences) + int count = 0; + const char *p = tensors_start; + while ((p = strstr(p + 1, "\"offset\"")) != NULL) count++; + + TensorManifest *m = (TensorManifest *)calloc(1, sizeof(TensorManifest)); + m->tensors = (TensorInfo *)calloc(count, sizeof(TensorInfo)); + m->num_tensors = 0; + + // Parse each tensor entry + p = tensors_start; + while (m->num_tensors < count) { + // Find next tensor name (key before the {) + p = strchr(p + 1, '"'); + if (!p) break; + p++; // skip opening quote + const char *name_end = strchr(p, '"'); + if (!name_end) break; + + // Skip non-tensor keys + const char *brace = strchr(name_end, '{'); + if (!brace) break; + + // Check if this is a tensor entry (has "offset") + const char *next_brace = strchr(brace + 1, '}'); + if (!next_brace) break; + char *offset_key = strstr((char *)brace, "\"offset\""); + if (!offset_key || offset_key > next_brace) { + p = next_brace; + continue; + } + + TensorInfo *t = &m->tensors[m->num_tensors]; + size_t nlen = name_end - p; + if (nlen >= sizeof(t->name)) nlen = sizeof(t->name) - 1; + memcpy(t->name, p, nlen); + t->name[nlen] = '\0'; + + // Parse offset + char *colon = strchr(offset_key, ':'); + if (colon) t->offset = strtoul(colon + 1, NULL, 10); + + // Parse size + char *size_key = strstr((char *)brace, "\"size\""); + if (size_key && size_key < next_brace) { + colon = strchr(size_key, ':'); + if (colon) t->size = strtoul(colon + 1, NULL, 10); + } + + // Parse dtype + char *dtype_key = strstr((char *)brace, "\"dtype\""); + if (dtype_key && dtype_key < next_brace) { + char *dq = strchr(dtype_key + 7, '"'); + if (dq) { + dq++; + char *dq_end = strchr(dq, '"'); + if (dq_end) { + size_t dl = dq_end - dq; + if (dl < sizeof(t->dtype)) { memcpy(t->dtype, dq, dl); t->dtype[dl] = '\0'; } + } + } + } + + // Parse shape + char *shape_key = strstr((char *)brace, "\"shape\""); + if (shape_key && shape_key < next_brace) { + char *sb = strchr(shape_key, '['); + if (sb) { + t->ndim = 0; + char *sp = sb + 1; + while (t->ndim < 4) { + while (*sp == ' ' || *sp == ',') sp++; + if (*sp == ']') break; + t->shape[t->ndim++] = atoi(sp); + while (*sp && *sp != ',' && *sp != ']') sp++; + } + } + } + + // Parse gguf_type (GGML quant type, only present for GGUF format) + t->gguf_type = 0; + char *gguf_key = strstr((char *)brace, "\"gguf_type\""); + if (gguf_key && gguf_key < next_brace) { + char *gcolon = strchr(gguf_key, ':'); + if (gcolon) t->gguf_type = atoi(gcolon + 1); + } + + m->num_tensors++; + p = next_brace; + } + + free(json); + return m; +} + +// Raw JSON for fallback tensor lookup +static char *g_manifest_json = NULL; +static size_t g_manifest_json_len = 0; + +// ============================================================================ +// Weight file (mmap'd model_weights.bin) +// ============================================================================ + +typedef struct { + void *data; + size_t size; + TensorManifest *manifest; +} WeightFile; + +static WeightFile *open_weights(const char *bin_path, const char *json_path) { + // Load raw JSON for fallback tensor lookup + { + FILE *jf = fopen(json_path, "r"); + if (jf) { + fseek(jf, 0, SEEK_END); + g_manifest_json_len = ftell(jf); + fseek(jf, 0, SEEK_SET); + g_manifest_json = (char *)malloc(g_manifest_json_len + 1); + fread(g_manifest_json, 1, g_manifest_json_len, jf); + g_manifest_json[g_manifest_json_len] = '\0'; + fclose(jf); + } + } + + int fd = open(bin_path, O_RDONLY); + if (fd < 0) { perror(bin_path); return NULL; } + struct stat st; + fstat(fd, &st); + void *data = mmap(NULL, st.st_size, PROT_READ, MAP_PRIVATE, fd, 0); + close(fd); + if (data == MAP_FAILED) { perror("mmap"); return NULL; } + madvise(data, st.st_size, MADV_SEQUENTIAL); + + WeightFile *wf = (WeightFile *)calloc(1, sizeof(WeightFile)); + wf->data = data; + wf->size = st.st_size; + wf->manifest = load_manifest(json_path); + if (!wf->manifest) { munmap(data, st.st_size); free(wf); return NULL; } + + printf("[weights] Loaded %.2f GB, %d tensors\n", + wf->size / (1024.0*1024*1024), wf->manifest->num_tensors); + return wf; +} + +static TensorInfo *find_tensor(WeightFile *wf, const char *name) { + for (int i = 0; i < wf->manifest->num_tensors; i++) { + if (strcmp(wf->manifest->tensors[i].name, name) == 0) + return &wf->manifest->tensors[i]; + } + // Fallback: if not found in parsed manifest, search raw JSON + // (the custom parser may have missed some entries) + return NULL; +} + +// Forward declarations for raw JSON fallback +// (defined below, used in upload_tensor and open_weights) + +static int find_tensor_in_json(const char *name, size_t *out_offset, size_t *out_size) { + if (!g_manifest_json) return 0; + // Search for "name": { ... "offset": N, "size": N ... } + char pattern[512]; + snprintf(pattern, sizeof(pattern), "\"%s\"", name); + const char *pos = strstr(g_manifest_json, pattern); + if (!pos) return 0; + const char *brace = strchr(pos + strlen(pattern), '{'); + if (!brace) return 0; + const char *end_brace = strchr(brace, '}'); + if (!end_brace) return 0; + + // Extract offset + const char *off_key = strstr(brace, "\"offset\""); + if (off_key && off_key < end_brace) { + const char *colon = strchr(off_key, ':'); + if (colon) *out_offset = strtoul(colon + 1, NULL, 10); + } else return 0; + + // Extract size + const char *sz_key = strstr(brace, "\"size\""); + if (sz_key && sz_key < end_brace) { + const char *colon = strchr(sz_key, ':'); + if (colon) *out_size = strtoul(colon + 1, NULL, 10); + } else return 0; + + return 1; +} + +// Extended version that also extracts gguf_type +static int find_tensor_gguf_type_in_json(const char *name) { + if (!g_manifest_json) return 0; + char pattern[512]; + snprintf(pattern, sizeof(pattern), "\"%s\"", name); + const char *pos = strstr(g_manifest_json, pattern); + if (!pos) return 0; + const char *brace = strchr(pos + strlen(pattern), '{'); + if (!brace) return 0; + const char *end_brace = strchr(brace, '}'); + if (!end_brace) return 0; + const char *gt_key = strstr(brace, "\"gguf_type\""); + if (gt_key && gt_key < end_brace) { + const char *colon = strchr(gt_key, ':'); + if (colon) return atoi(colon + 1); + } + return 0; +} + +static void *get_tensor_ptr(WeightFile *wf, const char *name) { + TensorInfo *t = find_tensor(wf, name); + if (!t) return NULL; + return (char *)wf->data + t->offset; +} + +// ============================================================================ +// GPU weight storage — all non-expert weights uploaded to VRAM +// ============================================================================ + +typedef struct { + // Per-layer weight pointers (on GPU) + struct { + // Input/post-attention norms (bf16) + uint16_t *input_norm_w; + uint16_t *post_attn_norm_w; + + // Full attention (15 layers) + uint32_t *q_w; uint16_t *q_s, *q_b; + uint32_t *k_w; uint16_t *k_s, *k_b; + uint32_t *v_w; uint16_t *v_s, *v_b; + uint32_t *o_w; uint16_t *o_s, *o_b; + uint16_t *q_norm_w, *k_norm_w; + + // Linear attention (45 layers) + uint32_t *qkv_w; uint16_t *qkv_s, *qkv_b; + uint32_t *z_w; uint16_t *z_s, *z_b; + uint32_t *b_w; uint16_t *b_s, *b_b; // beta projection + uint32_t *a_w; uint16_t *a_s, *a_b; // alpha projection + uint16_t *conv1d_w; + float *A_log; + uint16_t *dt_bias; + uint16_t *gated_norm_w; + uint32_t *out_proj_w; uint16_t *out_proj_s, *out_proj_b; + + // MoE routing + shared expert + uint32_t *gate_w; uint16_t *gate_s, *gate_b; + uint32_t *sg_w; uint16_t *sg_s, *sg_b; + uint32_t *su_w; uint16_t *su_s, *su_b; + uint32_t *sd_w; uint16_t *sd_s, *sd_b; + uint32_t *seg_w; uint16_t *seg_s, *seg_b; // shared_expert_gate + + // Fused QKV for full attention in GGUF mode (separate Q/K/V for MLX) + uint32_t *full_qkv_w; uint16_t *full_qkv_s, *full_qkv_b; + + int is_full; + + // GGUF quant types (populated at init, 0 for MLX) + int qt_q, qt_k, qt_v, qt_o; // full attention + int qt_qkv, qt_z, qt_b, qt_a, qt_out; // linear attention + int qt_gate, qt_sg, qt_su, qt_sd, qt_seg; // MoE + int qt_full_qkv; // fused QKV for full attn (GGUF) + } layers[NUM_LAYERS]; + + // Global weights + uint32_t *embed_w; uint16_t *embed_s, *embed_b; + uint32_t *lm_head_w; uint16_t *lm_head_s, *lm_head_b; + uint16_t *final_norm_w; + + // GGUF quant types for global weights + int qt_embed, qt_lm_head; + + // Scratch buffers (GPU) + float *buf_hidden; // [HIDDEN_DIM] + float *buf_normed; // [HIDDEN_DIM] + float *buf_residual; // [HIDDEN_DIM] + float *buf_attn_out; // [max(NUM_ATTN_HEADS*HEAD_DIM, LINEAR_TOTAL_VALUE)] + + // Attention projection outputs + float *buf_q_proj; // [NUM_ATTN_HEADS * HEAD_DIM * 2] or [LINEAR_CONV_DIM] + float *buf_k_proj; // [NUM_KV_HEADS * HEAD_DIM] + float *buf_v_proj; // [NUM_KV_HEADS * HEAD_DIM] + float *buf_z_proj; // [LINEAR_TOTAL_VALUE] + float *buf_beta_proj; // [LINEAR_NUM_V_HEADS] + float *buf_alpha_proj; // [LINEAR_NUM_V_HEADS] + + // Post-attention + float *buf_h_mid; // [HIDDEN_DIM] after o_proj + residual + float *buf_gate_scores; // [NUM_EXPERTS] + float *buf_shared_gate; // [SHARED_INTERMEDIATE] + float *buf_shared_up; // [SHARED_INTERMEDIATE] + float *buf_shared_out; // [HIDDEN_DIM] + + // Expert buffers + float *buf_expert_outs; // [MAX_K * HIDDEN_DIM] + void *buf_expert_data; // [MAX_K * EXPERT_SIZE] raw expert data on GPU + + // Linear attention state (persistent across tokens) + float *delta_state[NUM_LAYERS]; // [64 * 128 * 128] per linear layer + float *conv_state[NUM_LAYERS]; // [3 * 12288] per linear layer + float *buf_conv_output; // [LINEAR_CONV_DIM] + float *buf_g_decay; // [LINEAR_NUM_V_HEADS] + float *buf_beta_gate; // [LINEAR_NUM_V_HEADS] + float *buf_delta_output; // [LINEAR_TOTAL_VALUE] + + // Full attention KV cache (persistent) + float *kv_k[NUM_LAYERS]; // [MAX_SEQ_LEN * NUM_KV_HEADS * HEAD_DIM] + float *kv_v[NUM_LAYERS]; // same + int kv_len[NUM_LAYERS]; // current seq length per full-attn layer + float *buf_attn_scores; // [NUM_ATTN_HEADS * MAX_SEQ_LEN] + float *buf_q; // [NUM_ATTN_HEADS * HEAD_DIM] deinterleaved Q + float *buf_q_gate; // [NUM_ATTN_HEADS * HEAD_DIM] Q gate + + // Logits + float *buf_logits; // [VOCAB_SIZE] on GPU + float *h_logits; // [VOCAB_SIZE] pinned host memory + + // Expert I/O staging (pinned host) + void *h_expert_buf[MAX_K]; + + // Pre-allocated expert weights buffer (avoids per-layer cudaMalloc) + float *buf_expert_weights; // [MAX_K] on GPU + + // GDS handles (NULL if GDS not available) + int gds_available; + CUfileHandle_t gds_handles[NUM_LAYERS]; + + // CUDA streams for I/O overlap + cudaStream_t stream_compute; + cudaStream_t stream_transfer; + + // Expert file descriptors + int expert_fds[NUM_LAYERS]; + + // GPU memory for bulk weight upload + void *d_weights; // single allocation for all non-expert weights + size_t d_weights_size; + + // ---- VRAM expert cache ---- + // Frequency-weighted LRU cache of experts in GPU memory. + // Eviction score = access_count * FREQ_WEIGHT + last_used. + // Hot experts (high access_count) survive even if not used for a few tokens. + void *vram_cache_pool; // [vram_cache_capacity * EXPERT_SIZE] GPU memory + int vram_cache_capacity; // max experts that fit + int vram_cache_used; // current fill level + uint64_t vram_cache_clock; // clock (increments per access) + // Direct-mapped lookup: cache_map[layer][expert] = slot index (-1 = not cached) + int cache_map[NUM_LAYERS][NUM_EXPERTS]; + // Per-slot metadata + struct { + int layer; + int expert_id; + uint64_t last_used; // clock value at last access + uint32_t access_count; // total accesses since cached + } *cache_slots; + +} Model; + +// ============================================================================ +// Upload a tensor from mmap to GPU, return device pointer +// ============================================================================ + +static void *upload_tensor(WeightFile *wf, const char *name, void *d_base, size_t *d_offset) { + TensorInfo *t = find_tensor(wf, name); + size_t t_offset = 0, t_size = 0; + if (t) { + t_offset = t->offset; + t_size = t->size; + } else if (find_tensor_in_json(name, &t_offset, &t_size)) { + // Fallback: found in raw JSON + } else { + fprintf(stderr, "WARNING: tensor '%s' not found\n", name); + return NULL; + } + void *src = (char *)wf->data + t_offset; + void *dst = (char *)d_base + *d_offset; + CHECK_CUDA(cudaMemcpy(dst, src, t_size, cudaMemcpyHostToDevice)); + *d_offset += (t_size + 63) & ~63ULL; + return dst; +} + +// Upload F32 tensor data as BF16 (for GGUF: norms, dt_bias, etc. are F32 but kernels expect bf16) +static void *upload_tensor_f32_as_bf16(WeightFile *wf, const char *name, void *d_base, size_t *d_offset) { + TensorInfo *t = find_tensor(wf, name); + size_t t_offset = 0, t_size = 0; + if (t) { t_offset = t->offset; t_size = t->size; } + else if (find_tensor_in_json(name, &t_offset, &t_size)) {} + else { fprintf(stderr, "WARNING: tensor '%s' not found\n", name); return NULL; } + + // Convert F32 → BF16 on CPU + size_t n_floats = t_size / sizeof(float); + const float *src = (const float *)((char *)wf->data + t_offset); + uint16_t *bf16_buf = (uint16_t *)malloc(n_floats * sizeof(uint16_t)); + for (size_t i = 0; i < n_floats; i++) { + uint32_t bits; + memcpy(&bits, &src[i], 4); + bf16_buf[i] = (uint16_t)(bits >> 16); // F32 → BF16: take upper 16 bits + } + void *dst = (char *)d_base + *d_offset; + size_t bf16_size = n_floats * sizeof(uint16_t); + CHECK_CUDA(cudaMemcpy(dst, bf16_buf, bf16_size, cudaMemcpyHostToDevice)); + free(bf16_buf); + *d_offset += (bf16_size + 63) & ~63ULL; + return dst; +} + +// Helper macro for uploading weight triplets (weight, scales, biases) +#define UPLOAD_WEIGHT_TRIPLET(prefix, w_field, s_field, b_field) do { \ + char _n[256]; \ + snprintf(_n, sizeof(_n), "%s.weight", prefix); \ + model->w_field = (uint32_t *)upload_tensor(wf, _n, model->d_weights, &off); \ + snprintf(_n, sizeof(_n), "%s.scales", prefix); \ + model->s_field = (uint16_t *)upload_tensor(wf, _n, model->d_weights, &off); \ + snprintf(_n, sizeof(_n), "%s.biases", prefix); \ + model->b_field = (uint16_t *)upload_tensor(wf, _n, model->d_weights, &off); \ +} while(0) + +#define UPLOAD_LAYER_TRIPLET(layer_idx, prefix, w_field, s_field, b_field) do { \ + char _n[256]; \ + snprintf(_n, sizeof(_n), "model.layers.%d." prefix ".weight", layer_idx); \ + model->layers[layer_idx].w_field = (uint32_t *)upload_tensor(wf, _n, model->d_weights, &off); \ + snprintf(_n, sizeof(_n), "model.layers.%d." prefix ".scales", layer_idx); \ + model->layers[layer_idx].s_field = (uint16_t *)upload_tensor(wf, _n, model->d_weights, &off); \ + snprintf(_n, sizeof(_n), "model.layers.%d." prefix ".biases", layer_idx); \ + model->layers[layer_idx].b_field = (uint16_t *)upload_tensor(wf, _n, model->d_weights, &off); \ +} while(0) + +// GGUF: single tensor per weight (no separate scales/biases) +#define UPLOAD_WEIGHT_SINGLE(prefix, w_field) do { \ + char _n[256]; \ + snprintf(_n, sizeof(_n), "%s.weight", prefix); \ + model->w_field = (uint32_t *)upload_tensor(wf, _n, model->d_weights, &off); \ +} while(0) + +#define UPLOAD_LAYER_SINGLE(layer_idx, prefix, w_field) do { \ + char _n[256]; \ + snprintf(_n, sizeof(_n), "model.layers.%d." prefix ".weight", layer_idx); \ + model->layers[layer_idx].w_field = (uint32_t *)upload_tensor(wf, _n, model->d_weights, &off); \ +} while(0) + +// Helper to look up gguf_type for a tensor name +static int lookup_gguf_type(WeightFile *wf, const char *name) { + TensorInfo *t = find_tensor(wf, name); + if (t) return t->gguf_type; + return find_tensor_gguf_type_in_json(name); +} + +// ============================================================================ +// Model initialization +// ============================================================================ + +static Model *model_init(WeightFile *wf, const char *expert_dir, int K) { + Model *model = (Model *)calloc(1, sizeof(Model)); + + printf("[init] Uploading %.2f GB of non-expert weights to GPU...\n", + wf->size / (1024.0*1024*1024)); + double t0 = now_ms(); + + // Allocate single GPU buffer for all weights (slightly over-allocate for alignment) + model->d_weights_size = wf->size + NUM_LAYERS * 64 * 100; // extra for alignment padding + CHECK_CUDA(cudaMalloc(&model->d_weights, model->d_weights_size)); + + size_t off = 0; + + if (g_quant_format == 1) { + // ================================================================ + // GGUF: single tensor per weight (no separate scales/biases) + // S and B pointers remain NULL (calloc'd to 0). + // ================================================================ + UPLOAD_WEIGHT_SINGLE("model.embed_tokens", embed_w); + model->qt_embed = lookup_gguf_type(wf, "model.embed_tokens.weight"); + UPLOAD_WEIGHT_SINGLE("lm_head", lm_head_w); + model->qt_lm_head = lookup_gguf_type(wf, "lm_head.weight"); + + { + char n[] = "model.norm.weight"; + model->final_norm_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + } + + for (int i = 0; i < NUM_LAYERS; i++) { + int is_full = ((i + 1) % FULL_ATTN_INTERVAL == 0); + model->layers[i].is_full = is_full; + char n[256]; + + // Norms (always f32 or bf16 in GGUF — uploaded as raw data) + snprintf(n, sizeof(n), "model.layers.%d.input_layernorm.weight", i); + model->layers[i].input_norm_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.post_attention_layernorm.weight", i); + model->layers[i].post_attn_norm_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + + if (is_full) { + // GGUF full attention: separate Q/K/V (not fused) + UPLOAD_LAYER_SINGLE(i, "self_attn.q_proj", q_w); + { snprintf(n, sizeof(n), "model.layers.%d.self_attn.q_proj.weight", i); + model->layers[i].qt_q = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "self_attn.k_proj", k_w); + { snprintf(n, sizeof(n), "model.layers.%d.self_attn.k_proj.weight", i); + model->layers[i].qt_k = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "self_attn.v_proj", v_w); + { snprintf(n, sizeof(n), "model.layers.%d.self_attn.v_proj.weight", i); + model->layers[i].qt_v = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "self_attn.o_proj", o_w); + { snprintf(n, sizeof(n), "model.layers.%d.self_attn.o_proj.weight", i); + model->layers[i].qt_o = lookup_gguf_type(wf, n); } + + // Q/K norms (F32 in GGUF → convert to bf16) + snprintf(n, sizeof(n), "model.layers.%d.self_attn.q_norm.weight", i); + model->layers[i].q_norm_w = (uint16_t *)upload_tensor_f32_as_bf16(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.self_attn.k_norm.weight", i); + model->layers[i].k_norm_w = (uint16_t *)upload_tensor_f32_as_bf16(wf, n, model->d_weights, &off); + } else { + UPLOAD_LAYER_SINGLE(i, "linear_attn.in_proj_qkv", qkv_w); + { snprintf(n, sizeof(n), "model.layers.%d.linear_attn.in_proj_qkv.weight", i); + model->layers[i].qt_qkv = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "linear_attn.in_proj_z", z_w); + { snprintf(n, sizeof(n), "model.layers.%d.linear_attn.in_proj_z.weight", i); + model->layers[i].qt_z = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "linear_attn.in_proj_b", b_w); + { snprintf(n, sizeof(n), "model.layers.%d.linear_attn.in_proj_b.weight", i); + model->layers[i].qt_b = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "linear_attn.in_proj_a", a_w); + { snprintf(n, sizeof(n), "model.layers.%d.linear_attn.in_proj_a.weight", i); + model->layers[i].qt_a = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "linear_attn.out_proj", out_proj_w); + { snprintf(n, sizeof(n), "model.layers.%d.linear_attn.out_proj.weight", i); + model->layers[i].qt_out = lookup_gguf_type(wf, n); } + + // F32 tensors that kernels read as bf16 — convert during upload + snprintf(n, sizeof(n), "model.layers.%d.linear_attn.conv1d.weight", i); + model->layers[i].conv1d_w = (uint16_t *)upload_tensor_f32_as_bf16(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.linear_attn.A_log", i); + model->layers[i].A_log = (float *)upload_tensor(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.linear_attn.dt_bias", i); + // GGUF: keep dt_bias as F32 (store pointer in dt_bias field, kernel reads as F32) + model->layers[i].dt_bias = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.linear_attn.norm.weight", i); + model->layers[i].gated_norm_w = (uint16_t *)upload_tensor_f32_as_bf16(wf, n, model->d_weights, &off); + } + + // MoE routing + shared expert (all layers) + UPLOAD_LAYER_SINGLE(i, "mlp.gate", gate_w); + { snprintf(n, sizeof(n), "model.layers.%d.mlp.gate.weight", i); + model->layers[i].qt_gate = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "mlp.shared_expert.gate_proj", sg_w); + { snprintf(n, sizeof(n), "model.layers.%d.mlp.shared_expert.gate_proj.weight", i); + model->layers[i].qt_sg = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "mlp.shared_expert.up_proj", su_w); + { snprintf(n, sizeof(n), "model.layers.%d.mlp.shared_expert.up_proj.weight", i); + model->layers[i].qt_su = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "mlp.shared_expert.down_proj", sd_w); + { snprintf(n, sizeof(n), "model.layers.%d.mlp.shared_expert.down_proj.weight", i); + model->layers[i].qt_sd = lookup_gguf_type(wf, n); } + UPLOAD_LAYER_SINGLE(i, "mlp.shared_expert_gate", seg_w); + { snprintf(n, sizeof(n), "model.layers.%d.mlp.shared_expert_gate.weight", i); + model->layers[i].qt_seg = lookup_gguf_type(wf, n); } + } + } else { + // ================================================================ + // MLX affine 4-bit: weight triplets (W, scales, biases) + // ================================================================ + UPLOAD_WEIGHT_TRIPLET("model.embed_tokens", embed_w, embed_s, embed_b); + UPLOAD_WEIGHT_TRIPLET("lm_head", lm_head_w, lm_head_s, lm_head_b); + + { + char n[] = "model.norm.weight"; + model->final_norm_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + } + + // Per-layer weights + for (int i = 0; i < NUM_LAYERS; i++) { + int is_full = ((i + 1) % FULL_ATTN_INTERVAL == 0); + model->layers[i].is_full = is_full; + + // Norms + { + char n[256]; + snprintf(n, sizeof(n), "model.layers.%d.input_layernorm.weight", i); + model->layers[i].input_norm_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.post_attention_layernorm.weight", i); + model->layers[i].post_attn_norm_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + } + + if (is_full) { + UPLOAD_LAYER_TRIPLET(i, "self_attn.q_proj", q_w, q_s, q_b); + UPLOAD_LAYER_TRIPLET(i, "self_attn.k_proj", k_w, k_s, k_b); + UPLOAD_LAYER_TRIPLET(i, "self_attn.v_proj", v_w, v_s, v_b); + UPLOAD_LAYER_TRIPLET(i, "self_attn.o_proj", o_w, o_s, o_b); + { + char n[256]; + snprintf(n, sizeof(n), "model.layers.%d.self_attn.q_norm.weight", i); + model->layers[i].q_norm_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.self_attn.k_norm.weight", i); + model->layers[i].k_norm_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + } + } else { + UPLOAD_LAYER_TRIPLET(i, "linear_attn.in_proj_qkv", qkv_w, qkv_s, qkv_b); + UPLOAD_LAYER_TRIPLET(i, "linear_attn.in_proj_z", z_w, z_s, z_b); + UPLOAD_LAYER_TRIPLET(i, "linear_attn.in_proj_b", b_w, b_s, b_b); + UPLOAD_LAYER_TRIPLET(i, "linear_attn.in_proj_a", a_w, a_s, a_b); + UPLOAD_LAYER_TRIPLET(i, "linear_attn.out_proj", out_proj_w, out_proj_s, out_proj_b); + { + char n[256]; + snprintf(n, sizeof(n), "model.layers.%d.linear_attn.conv1d.weight", i); + model->layers[i].conv1d_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.linear_attn.A_log", i); + model->layers[i].A_log = (float *)upload_tensor(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.linear_attn.dt_bias", i); + model->layers[i].dt_bias = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + snprintf(n, sizeof(n), "model.layers.%d.linear_attn.norm.weight", i); + model->layers[i].gated_norm_w = (uint16_t *)upload_tensor(wf, n, model->d_weights, &off); + } + } + + // MoE routing + shared expert (all layers) + UPLOAD_LAYER_TRIPLET(i, "mlp.gate", gate_w, gate_s, gate_b); + UPLOAD_LAYER_TRIPLET(i, "mlp.shared_expert.gate_proj", sg_w, sg_s, sg_b); + UPLOAD_LAYER_TRIPLET(i, "mlp.shared_expert.up_proj", su_w, su_s, su_b); + UPLOAD_LAYER_TRIPLET(i, "mlp.shared_expert.down_proj", sd_w, sd_s, sd_b); + UPLOAD_LAYER_TRIPLET(i, "mlp.shared_expert_gate", seg_w, seg_s, seg_b); + } + } + + printf("[init] Uploaded %.2f GB in %.1f ms (offset=%zu)\n", + off / (1024.0*1024*1024), now_ms() - t0, off); + + // Allocate scratch buffers + CHECK_CUDA(cudaMalloc(&model->buf_hidden, HIDDEN_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_normed, HIDDEN_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_residual, HIDDEN_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_attn_out, LINEAR_TOTAL_VALUE * sizeof(float))); + + int max_proj = NUM_ATTN_HEADS * HEAD_DIM * 2; // full attn q_proj + if (LINEAR_CONV_DIM > max_proj) max_proj = LINEAR_CONV_DIM; + // GGUF fused QKV: Q+gate + K + V = q_proj_dim + 2*kv_dim + if (g_quant_format == 1) { + int fused_qkv = NUM_ATTN_HEADS * HEAD_DIM * 2 + 2 * NUM_KV_HEADS * HEAD_DIM; + if (fused_qkv > max_proj) max_proj = fused_qkv; + } + CHECK_CUDA(cudaMalloc(&model->buf_q_proj, max_proj * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_k_proj, NUM_KV_HEADS * HEAD_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_v_proj, NUM_KV_HEADS * HEAD_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_z_proj, LINEAR_TOTAL_VALUE * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_beta_proj, LINEAR_NUM_V_HEADS * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_alpha_proj, LINEAR_NUM_V_HEADS * sizeof(float))); + + CHECK_CUDA(cudaMalloc(&model->buf_h_mid, HIDDEN_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_gate_scores, NUM_EXPERTS * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_shared_gate, SHARED_INTERMEDIATE * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_shared_up, SHARED_INTERMEDIATE * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_shared_out, HIDDEN_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_expert_outs, MAX_K * HIDDEN_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_expert_data, MAX_K * g_expert_size)); + + // Linear attention persistent state + for (int i = 0; i < NUM_LAYERS; i++) { + if (!model->layers[i].is_full) { + CHECK_CUDA(cudaMalloc(&model->delta_state[i], + LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM * LINEAR_KEY_DIM * sizeof(float))); + CHECK_CUDA(cudaMemset(model->delta_state[i], 0, + LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM * LINEAR_KEY_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->conv_state[i], + (CONV_KERNEL_SIZE - 1) * LINEAR_CONV_DIM * sizeof(float))); + CHECK_CUDA(cudaMemset(model->conv_state[i], 0, + (CONV_KERNEL_SIZE - 1) * LINEAR_CONV_DIM * sizeof(float))); + } + } + CHECK_CUDA(cudaMalloc(&model->buf_conv_output, LINEAR_CONV_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_g_decay, LINEAR_NUM_V_HEADS * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_beta_gate, LINEAR_NUM_V_HEADS * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_delta_output, LINEAR_TOTAL_VALUE * sizeof(float))); + + // Full attention KV caches + int kv_size = MAX_SEQ_LEN * NUM_KV_HEADS * HEAD_DIM; + for (int i = 0; i < NUM_LAYERS; i++) { + if (model->layers[i].is_full) { + CHECK_CUDA(cudaMalloc(&model->kv_k[i], kv_size * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->kv_v[i], kv_size * sizeof(float))); + CHECK_CUDA(cudaMemset(model->kv_k[i], 0, kv_size * sizeof(float))); + CHECK_CUDA(cudaMemset(model->kv_v[i], 0, kv_size * sizeof(float))); + model->kv_len[i] = 0; + } + } + CHECK_CUDA(cudaMalloc(&model->buf_attn_scores, NUM_ATTN_HEADS * MAX_SEQ_LEN * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_q, NUM_ATTN_HEADS * HEAD_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&model->buf_q_gate, NUM_ATTN_HEADS * HEAD_DIM * sizeof(float))); + + // Logits + CHECK_CUDA(cudaMalloc(&model->buf_logits, VOCAB_SIZE * sizeof(float))); + CHECK_CUDA(cudaMallocHost(&model->h_logits, VOCAB_SIZE * sizeof(float))); + + // Pre-allocated expert weights buffer + CHECK_CUDA(cudaMalloc(&model->buf_expert_weights, MAX_K * sizeof(float))); + + // CUDA streams for I/O overlap + CHECK_CUDA(cudaStreamCreate(&model->stream_compute)); + CHECK_CUDA(cudaStreamCreate(&model->stream_transfer)); + + // Expert staging (pinned host) + for (int i = 0; i < K; i++) + CHECK_CUDA(cudaMallocHost(&model->h_expert_buf[i], g_expert_size)); + + // Open expert files + for (int i = 0; i < NUM_LAYERS; i++) { + char path[512]; + snprintf(path, sizeof(path), "%s/layer_%02d.bin", expert_dir, i); + model->expert_fds[i] = open(path, O_RDONLY); + if (model->expert_fds[i] < 0) { + fprintf(stderr, "WARNING: Cannot open %s: %s\n", path, strerror(errno)); + } + } + + // GDS vs page cache: pread with page cache is faster for sustained generation + // because hot experts stay in RAM (~3ms vs 5.3ms). GDS bypasses page cache. + // Use --gds flag or ENABLE_GDS=1 env var to force GDS (useful if RAM < 32GB). + model->gds_available = 0; + int want_gds = (getenv("ENABLE_GDS") != NULL); + CUfileError_t gds_status; + if (!want_gds) { gds_status.err = (CUfileOpError)999; } + else { gds_status = cuFileDriverOpen(); } + if (gds_status.err == CU_FILE_SUCCESS) { + int gds_ok = 1; + for (int i = 0; i < NUM_LAYERS; i++) { + if (model->expert_fds[i] < 0) continue; + // Re-open with O_DIRECT for GDS + char path[512]; + snprintf(path, sizeof(path), "%s/layer_%02d.bin", expert_dir, i); + int dfd = open(path, O_RDONLY | O_DIRECT); + if (dfd < 0) { gds_ok = 0; break; } + + CUfileDescr_t desc = {}; + desc.handle.fd = dfd; + desc.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; + CUfileError_t s = cuFileHandleRegister(&model->gds_handles[i], &desc); + if (s.err != CU_FILE_SUCCESS) { close(dfd); gds_ok = 0; break; } + } + if (gds_ok) { + // Register expert data buffer for GDS + cuFileBufRegister(model->buf_expert_data, MAX_K * g_expert_size, 0); + model->gds_available = 1; + printf("[init] GDS: enabled (direct SSD→GPU, set ENABLE_GDS=1)\n"); + } else { + printf("[init] Using pread + page cache (best for 32GB+ RAM)\n"); + cuFileDriverClose(); + } + } else { + printf("[init] Using pread + page cache (set ENABLE_GDS=1 to force GDS)\n"); + } + + // ---- VRAM expert cache ---- + // Use most of remaining VRAM for caching hot experts. + // Reserve 512MB for safety, use the rest. + // Set DISABLE_VRAM_CACHE=1 to disable. + { + int skip_cache = (getenv("DISABLE_VRAM_CACHE") != NULL); + size_t free_mem, total_mem; + CHECK_CUDA(cudaMemGetInfo(&free_mem, &total_mem)); + size_t reserve = 1024ULL * 1024 * 1024; // keep 1GB free for safety + size_t cache_bytes = (free_mem > reserve && !skip_cache) ? free_mem - reserve : 0; + model->vram_cache_capacity = (int)(cache_bytes / g_expert_size); + // Cap at total expert count (no point caching more than exist) + int total_experts = NUM_LAYERS * NUM_EXPERTS; + if (model->vram_cache_capacity > total_experts) + model->vram_cache_capacity = total_experts; + if (model->vram_cache_capacity > 0) { + size_t alloc = (size_t)model->vram_cache_capacity * g_expert_size; + CHECK_CUDA(cudaMalloc(&model->vram_cache_pool, alloc)); + model->cache_slots = (decltype(model->cache_slots))calloc( + model->vram_cache_capacity, sizeof(model->cache_slots[0])); + for (int i = 0; i < model->vram_cache_capacity; i++) { + model->cache_slots[i].layer = -1; + model->cache_slots[i].expert_id = -1; + } + memset(model->cache_map, -1, sizeof(model->cache_map)); + model->vram_cache_used = 0; + model->vram_cache_clock = 0; + printf("[init] VRAM expert cache: %d experts (%.1f GB), %.1f%% of total\n", + model->vram_cache_capacity, + alloc / (1024.0*1024*1024), + 100.0 * model->vram_cache_capacity / (NUM_LAYERS * NUM_EXPERTS)); + } else { + printf("[init] VRAM expert cache: disabled%s\n", skip_cache ? " by env" : " (not enough VRAM)"); + } + } + + // Print GPU memory usage + { + size_t free_mem, total_mem; + CHECK_CUDA(cudaMemGetInfo(&free_mem, &total_mem)); + printf("[init] GPU memory: %.2f GB used, %.2f GB free / %.2f GB total\n", + (total_mem - free_mem) / (1024.0*1024*1024), + free_mem / (1024.0*1024*1024), + total_mem / (1024.0*1024*1024)); + } + + return model; +} + +// ============================================================================ +// Embedding lookup (GPU dequant one row) +// ============================================================================ + +static void embed_token(Model *model, int token_id) { + if (g_quant_format == 1) { + // GGUF embedding: use a one-hot matvec to extract the row. + // The embedding tensor is stored as a GGUF quantized matrix [vocab_size, hidden_dim]. + // For a single row extraction we do a 1-element "matvec" trick: + // Actually, for GGUF the embedding is often F32. Use one-hot vector approach. + // Simpler: copy the raw row data to host, dequant if needed. + // For F32 embedding: row is at offset token_id * HIDDEN_DIM * sizeof(float) + int etype = model->qt_embed; + if (etype == 0) { + // F32 embedding: direct copy + float *embed_ptr = (float *)model->embed_w + (size_t)token_id * HIDDEN_DIM; + CHECK_CUDA(cudaMemcpy(model->buf_hidden, embed_ptr, + HIDDEN_DIM * sizeof(float), cudaMemcpyDeviceToDevice)); + } else { + // Quantized embedding: extract row via 1-row matvec + // embed is stored as [vocab_size rows, hidden_dim cols] in quantized blocks. + // A single row = hidden_dim elements = (hidden_dim/QK_K) blocks. + // We extract it by pointing the matvec at the specific row and running + // a 1-row x hidden_dim matvec with a ones vector. + // Simpler: use matvec with out_dim=1, treating the row as a 1-row matrix, + // and input as ones... No, that computes a dot product. + // + // Correct approach: extract raw quantized row to CPU, dequantize, upload. + size_t blocks_per_row, block_size; + if (etype == 12) { blocks_per_row = HIDDEN_DIM / 256; block_size = 144; } // Q4_K + else if (etype == 13) { blocks_per_row = HIDDEN_DIM / 256; block_size = 176; } // Q5_K + else if (etype == 14) { blocks_per_row = HIDDEN_DIM / 256; block_size = 210; } // Q6_K + else { blocks_per_row = HIDDEN_DIM / 256; block_size = 144; } // fallback Q4_K + + size_t row_bytes = blocks_per_row * block_size; + uint8_t *row_ptr = (uint8_t *)model->embed_w + (size_t)token_id * row_bytes; + + // Use a 1-row matvec with identity input to dequantize + // Set up a unit vector [1,1,1,...,1] of size HIDDEN_DIM and do matvec + // with the single row as a [1, HIDDEN_DIM] matrix → output is 1 float (wrong) + // + // Actually the correct trick: treat this as a [HIDDEN_DIM, 1] matrix + // with input [1.0] → output[HIDDEN_DIM]. But quantized blocks operate on + // groups, not individual elements. So we need a dedicated dequant-row kernel. + // + // Simplest correct approach for now: dequant on CPU, upload result. + float h_embed[HIDDEN_DIM]; + uint8_t *h_row = (uint8_t *)malloc(row_bytes); + CHECK_CUDA(cudaMemcpy(h_row, row_ptr, row_bytes, cudaMemcpyDeviceToHost)); + + // Dequantize Q4_K row on CPU + if (etype == 12) { + for (size_t bi = 0; bi < blocks_per_row; bi++) { + const uint8_t *block = h_row + bi * 144; + uint16_t d_raw, dmin_raw; + memcpy(&d_raw, block, 2); + memcpy(&dmin_raw, block + 2, 2); + // fp16 to float (CPU-side conversion) + // IEEE 754 half: sign(1) exp(5) mantissa(10) + auto fp16_to_f32 = [](uint16_t h) -> float { + uint32_t sign = (h >> 15) & 1; + uint32_t exp = (h >> 10) & 0x1F; + uint32_t mant = h & 0x3FF; + if (exp == 0) { + if (mant == 0) return sign ? -0.0f : 0.0f; + // subnormal + float val = ldexpf((float)mant, -24); + return sign ? -val : val; + } + if (exp == 31) return sign ? -INFINITY : INFINITY; + float val = ldexpf((float)(mant + 1024), (int)exp - 25); + return sign ? -val : val; + }; + float d_val = fp16_to_f32(d_raw); + float dmin_val = fp16_to_f32(dmin_raw); + const uint8_t *sc = block + 4; + const uint8_t *qs = block + 16; + + // Q4_K dequant matching GGML dequantize_row_q4_K: + // 4 iterations of 64 elements (2 sub-blocks per iteration) + // Low 32 nibbles use scale[2j], high 32 nibbles use scale[2j+1] + for (int j = 0; j < 4; j++) { + // Get scales for this pair of sub-blocks + int is0 = 2 * j, is1 = 2 * j + 1; + float d1, m1, d2, m2; + // Scale unpacking (same get_scale_min_k4 logic) + if (is0 < 4) { d1 = d_val*(sc[is0]&63); m1 = dmin_val*(sc[is0+4]&63); } + else { d1 = d_val*((sc[is0+4]&0xF)|((sc[is0-4]>>6)<<4)); m1 = dmin_val*((sc[is0+4]>>4)|((sc[is0]>>6)<<4)); } + if (is1 < 4) { d2 = d_val*(sc[is1]&63); m2 = dmin_val*(sc[is1+4]&63); } + else { d2 = d_val*((sc[is1+4]&0xF)|((sc[is1-4]>>6)<<4)); m2 = dmin_val*((sc[is1+4]>>4)|((sc[is1]>>6)<<4)); } + + const uint8_t *q = qs + 32 * j; + float *y = h_embed + bi * 256 + 64 * j; + for (int l = 0; l < 32; l++) { + y[l + 0] = d1 * (float)(q[l] & 0xF) - m1; + y[l + 32] = d2 * (float)(q[l] >> 4) - m2; + } + } + } + } else { + // Fallback: zero embedding (will produce garbage but won't crash) + memset(h_embed, 0, sizeof(h_embed)); + } + free(h_row); + CHECK_CUDA(cudaMemcpy(model->buf_hidden, h_embed, + HIDDEN_DIM * sizeof(float), cudaMemcpyHostToDevice)); + } + } else { + // MLX affine 4-bit embedding + uint32_t packed_cols = HIDDEN_DIM / 8; // 512 + uint32_t num_groups = HIDDEN_DIM / GROUP_SIZE_C; // 64 + + uint32_t *W = model->embed_w + token_id * packed_cols; + uint16_t *S = model->embed_s + token_id * num_groups; + uint16_t *B = model->embed_b + token_id * num_groups; + + // CPU dequant (embedding is a one-time cost per token): + uint32_t h_W[512]; + uint16_t h_S[64], h_B[64]; + CHECK_CUDA(cudaMemcpy(h_W, W, packed_cols * sizeof(uint32_t), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(h_S, S, num_groups * sizeof(uint16_t), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(h_B, B, num_groups * sizeof(uint16_t), cudaMemcpyDeviceToHost)); + + float h_out[HIDDEN_DIM]; + for (uint32_t g = 0; g < num_groups; g++) { + float scale = bf16_to_f32_host(h_S[g]); + float bias = bf16_to_f32_host(h_B[g]); + uint32_t base = g * (GROUP_SIZE_C / 8); + for (uint32_t p = 0; p < GROUP_SIZE_C / 8; p++) { + uint32_t packed = h_W[base + p]; + for (int n = 0; n < 8; n++) { + uint32_t nibble = (packed >> (n * 4)) & 0xF; + h_out[g * GROUP_SIZE_C + p * 8 + n] = (float)nibble * scale + bias; + } + } + } + + CHECK_CUDA(cudaMemcpy(model->buf_hidden, h_out, HIDDEN_DIM * sizeof(float), cudaMemcpyHostToDevice)); + } +} + +// ============================================================================ +// Expert I/O: parallel pread + cudaMemcpy +// ============================================================================ + +typedef struct { + int fd; + void *buf; + size_t size; + off_t offset; +} PreadArg; + +static void *pread_worker(void *arg) { + PreadArg *a = (PreadArg *)arg; + (void)pread(a->fd, a->buf, a->size, a->offset); + return NULL; +} + +// GDS expert loading: direct SSD→GPU +typedef struct { + CUfileHandle_t handle; + void *d_buf; + size_t size; + off_t offset; +} GDSArg; + +static void *gds_worker(void *arg) { + GDSArg *a = (GDSArg *)arg; + cuFileRead(a->handle, a->d_buf, a->size, a->offset, 0); + return NULL; +} + +static void load_experts(Model *model, int layer_idx, const int *expert_ids, int K) { + if (model->gds_available) { + // GDS path: parallel cuFileRead directly to GPU memory + pthread_t threads[MAX_K]; + GDSArg args[MAX_K]; + for (int i = 0; i < K; i++) { + args[i].handle = model->gds_handles[layer_idx]; + args[i].d_buf = (char *)model->buf_expert_data + i * g_expert_size; + args[i].size = g_expert_size; + args[i].offset = (off_t)expert_ids[i] * g_expert_size; + pthread_create(&threads[i], NULL, gds_worker, &args[i]); + } + for (int i = 0; i < K; i++) + pthread_join(threads[i], NULL); + } else { + // Fallback: parallel pread → pinned host → cudaMemcpyAsync + pthread_t threads[MAX_K]; + PreadArg args[MAX_K]; + int fd = model->expert_fds[layer_idx]; + for (int i = 0; i < K; i++) { + args[i].fd = fd; + args[i].buf = model->h_expert_buf[i]; + args[i].size = g_expert_size; + args[i].offset = (off_t)expert_ids[i] * g_expert_size; + pthread_create(&threads[i], NULL, pread_worker, &args[i]); + } + for (int i = 0; i < K; i++) + pthread_join(threads[i], NULL); + // Async copy to GPU + for (int i = 0; i < K; i++) { + CHECK_CUDA(cudaMemcpyAsync( + (char *)model->buf_expert_data + i * g_expert_size, + model->h_expert_buf[i], g_expert_size, + cudaMemcpyHostToDevice, model->stream_transfer)); + } + CHECK_CUDA(cudaStreamSynchronize(model->stream_transfer)); + } +} + +// ============================================================================ +// Expert forward pass (one expert on GPU) +// ============================================================================ + +// Expert component offsets — computed from model dimensions +// Layout: gate(W,S,B) + up(W,S,B) + down(W,S,B) +// W: [out, in/8] uint32, S: [out, in/64] bf16, B: [out, in/64] bf16 +#define EXP_GATE_W_SZ (MOE_INTERMEDIATE * (HIDDEN_DIM / 8) * 4) +#define EXP_GATE_S_SZ (MOE_INTERMEDIATE * (HIDDEN_DIM / GROUP_SIZE_C) * 2) +#define EXP_GATE_B_SZ EXP_GATE_S_SZ +#define EXP_UP_W_SZ EXP_GATE_W_SZ +#define EXP_UP_S_SZ EXP_GATE_S_SZ +#define EXP_UP_B_SZ EXP_GATE_S_SZ +#define EXP_DOWN_W_SZ (HIDDEN_DIM * (MOE_INTERMEDIATE / 8) * 4) +#define EXP_DOWN_S_SZ (HIDDEN_DIM * (MOE_INTERMEDIATE / GROUP_SIZE_C) * 2) +#define EXP_DOWN_B_SZ EXP_DOWN_S_SZ + +#define EXP_GATE_W 0 +#define EXP_GATE_S (EXP_GATE_W + EXP_GATE_W_SZ) +#define EXP_GATE_B (EXP_GATE_S + EXP_GATE_S_SZ) +#define EXP_UP_W (EXP_GATE_B + EXP_GATE_B_SZ) +#define EXP_UP_S (EXP_UP_W + EXP_UP_W_SZ) +#define EXP_UP_B (EXP_UP_S + EXP_UP_S_SZ) +#define EXP_DOWN_W (EXP_UP_B + EXP_UP_B_SZ) +#define EXP_DOWN_S (EXP_DOWN_W + EXP_DOWN_W_SZ) +#define EXP_DOWN_B (EXP_DOWN_S + EXP_DOWN_S_SZ) + +// ============================================================================ +// Format-aware RMS norm — GGUF uses f32 weights, MLX uses bf16 +// ============================================================================ + +static inline void do_rms_norm(const float *x, const void *w, float *out, + uint32_t dim, float eps, cudaStream_t s = 0) { + if (g_quant_format == 1) { + // GGUF: norm weights are F32 + rms_norm<<<1, 256, 0, s>>>((const float *)x, (const float *)w, out, dim, eps); + } else { + // MLX: norm weights are BF16 + launch_rms_norm_bf16(x, (const uint16_t *)w, out, dim, eps, s); + } +} + +// ============================================================================ +// Format-aware matvec wrapper — dispatches MLX or GGUF kernel +// ============================================================================ + +static inline void do_matvec( + const uint32_t *W, const uint16_t *S, const uint16_t *B, + const float *x, float *out, uint32_t out_dim, uint32_t in_dim, + int gguf_type, cudaStream_t stream = 0 +) { + if (g_quant_format == 1) { + launch_dequant_matvec_gguf((const void *)W, x, out, out_dim, in_dim, gguf_type, stream); + } else { + launch_dequant_matvec(W, S, B, x, out, out_dim, in_dim, stream); + } +} + +static void expert_forward(Model *model, int expert_slot, int layer_idx, const float *input, float *output) { + if (g_quant_format == 1) { + uint8_t *base = (uint8_t *)model->buf_expert_data + expert_slot * g_expert_size; + launch_dequant_matvec_gguf((const void *)(base + g_gguf_gate_offset), + input, model->buf_shared_gate, MOE_INTERMEDIATE, HIDDEN_DIM, + g_gguf_gate_type); + launch_dequant_matvec_gguf((const void *)(base + g_gguf_up_offset), + input, model->buf_shared_up, MOE_INTERMEDIATE, HIDDEN_DIM, + g_gguf_up_type); + launch_swiglu(model->buf_shared_gate, model->buf_shared_up, model->buf_shared_gate, + MOE_INTERMEDIATE); + launch_dequant_matvec_gguf((const void *)(base + g_gguf_down_offset), + model->buf_shared_gate, output, HIDDEN_DIM, MOE_INTERMEDIATE, + g_gguf_down_type_per_layer[layer_idx]); + } else { + // MLX affine 4-bit layout + void *base = (char *)model->buf_expert_data + expert_slot * g_expert_size; + + uint32_t *gate_w = (uint32_t *)((char *)base + EXP_GATE_W); + uint16_t *gate_s = (uint16_t *)((char *)base + EXP_GATE_S); + uint16_t *gate_b = (uint16_t *)((char *)base + EXP_GATE_B); + uint32_t *up_w = (uint32_t *)((char *)base + EXP_UP_W); + uint16_t *up_s = (uint16_t *)((char *)base + EXP_UP_S); + uint16_t *up_b = (uint16_t *)((char *)base + EXP_UP_B); + uint32_t *down_w = (uint32_t *)((char *)base + EXP_DOWN_W); + uint16_t *down_s = (uint16_t *)((char *)base + EXP_DOWN_S); + uint16_t *down_b = (uint16_t *)((char *)base + EXP_DOWN_B); + + // gate_proj: [MOE_INTERMEDIATE, HIDDEN_DIM] → buf_shared_gate + launch_dequant_matvec(gate_w, gate_s, gate_b, input, model->buf_shared_gate, + MOE_INTERMEDIATE, HIDDEN_DIM); + // up_proj: [MOE_INTERMEDIATE, HIDDEN_DIM] → buf_shared_up + launch_dequant_matvec(up_w, up_s, up_b, input, model->buf_shared_up, + MOE_INTERMEDIATE, HIDDEN_DIM); + // SwiGLU + launch_swiglu(model->buf_shared_gate, model->buf_shared_up, model->buf_shared_gate, + MOE_INTERMEDIATE); + // down_proj: [HIDDEN_DIM, MOE_INTERMEDIATE] → output + launch_dequant_matvec(down_w, down_s, down_b, model->buf_shared_gate, output, + HIDDEN_DIM, MOE_INTERMEDIATE); + } +} + +// ============================================================================ +// CPU-side routing: softmax + topK +// ============================================================================ + +static void cpu_softmax(float *x, int n) { + float mx = x[0]; + for (int i = 1; i < n; i++) if (x[i] > mx) mx = x[i]; + float sum = 0.0f; + for (int i = 0; i < n; i++) { x[i] = expf(x[i] - mx); sum += x[i]; } + for (int i = 0; i < n; i++) x[i] /= sum; +} + +static void topk(const float *scores, int n, int k, int *indices, float *weights) { + // Simple selection sort for small k + uint8_t *used = (uint8_t *)calloc(n, 1); + for (int j = 0; j < k; j++) { + int best = -1; + float best_val = -1e30f; + for (int i = 0; i < n; i++) { + if (!used[i] && scores[i] > best_val) { + best_val = scores[i]; + best = i; + } + } + indices[j] = best; + weights[j] = best_val; + if (best >= 0) used[best] = 1; + } + free(used); + + // Renormalize weights + float sum = 0.0f; + for (int j = 0; j < k; j++) sum += weights[j]; + if (sum > 0) for (int j = 0; j < k; j++) weights[j] /= sum; +} + +// ============================================================================ +// CPU-side RoPE for full attention +// ============================================================================ + +static void apply_rope(float *q, float *k, int pos) { + int half = ROTARY_DIM / 2; + for (int h = 0; h < NUM_ATTN_HEADS; h++) { + float *qh = q + h * HEAD_DIM; + for (int i = 0; i < half; i++) { + float freq = 1.0f / powf(ROPE_THETA, (float)(2 * i) / ROTARY_DIM); + float angle = (float)pos * freq; + float c = cosf(angle), s = sinf(angle); + float q0 = qh[i], q1 = qh[i + half]; + qh[i] = q0 * c - q1 * s; + qh[i + half] = q0 * s + q1 * c; + } + } + for (int h = 0; h < NUM_KV_HEADS; h++) { + float *kh = k + h * HEAD_DIM; + for (int i = 0; i < half; i++) { + float freq = 1.0f / powf(ROPE_THETA, (float)(2 * i) / ROTARY_DIM); + float angle = (float)pos * freq; + float c = cosf(angle), s = sinf(angle); + float k0 = kh[i], k1 = kh[i + half]; + kh[i] = k0 * c - k1 * s; + kh[i + half] = k0 * s + k1 * c; + } + } +} + +// ============================================================================ +// Per-layer forward pass +// ============================================================================ + +// Expert logging for routing analysis (set EXPERT_LOG=/path to enable) +static FILE *g_expert_log = NULL; + +// Timing accumulator for per-phase breakdown +static int g_timing_enabled = 0; +static int g_dump_layer0 = 0; // set to 1 for GGUF at startup +static struct { + double input_norm, attn_proj, attn_compute, oproj_residual; + double routing, shared_expert, expert_io, expert_compute, combine; + double total; + int count; +} g_layer_timing; + +// One-shot Q5_K kernel verification against Python reference +static void verify_q5k_kernel(Model *model) { + if (g_quant_format != 1) return; + // Layer 0 QKV is Q5_K. Run matvec with ones input, check output. + uint32_t *qkv_w = model->layers[0].qkv_w; + if (!qkv_w) return; + float *d_ones, *d_out; + CHECK_CUDA(cudaMalloc(&d_ones, HIDDEN_DIM * sizeof(float))); + CHECK_CUDA(cudaMalloc(&d_out, LINEAR_CONV_DIM * sizeof(float))); + float *h_ones = (float*)malloc(HIDDEN_DIM * sizeof(float)); + for (int i = 0; i < HIDDEN_DIM; i++) h_ones[i] = 1.0f; + CHECK_CUDA(cudaMemcpy(d_ones, h_ones, HIDDEN_DIM * sizeof(float), cudaMemcpyHostToDevice)); + launch_dequant_matvec_q5k((const uint8_t*)qkv_w, d_ones, d_out, + LINEAR_CONV_DIM, HIDDEN_DIM); + CHECK_CUDA(cudaDeviceSynchronize()); + float h_out[5]; + CHECK_CUDA(cudaMemcpy(h_out, d_out, 5*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[verify] Q5_K kernel (ones input) out[0..4] = %.6f %.6f %.6f %.6f %.6f\n", + h_out[0], h_out[1], h_out[2], h_out[3], h_out[4]); + printf("[verify] Python reference: -0.327011 1.785097 0.084108 0.016506 -1.251458\n"); + + // Also verify Q6_K with shared expert down_proj (layer 0) + uint32_t *sd_w = model->layers[0].sd_w; + if (sd_w && model->layers[0].qt_sd == 14) { + float *d_ones6, *d_out6; + CHECK_CUDA(cudaMalloc(&d_ones6, SHARED_INTERMEDIATE * sizeof(float))); + CHECK_CUDA(cudaMalloc(&d_out6, HIDDEN_DIM * sizeof(float))); + float *h6 = (float*)malloc(SHARED_INTERMEDIATE * sizeof(float)); + for (int i = 0; i < SHARED_INTERMEDIATE; i++) h6[i] = 1.0f; + CHECK_CUDA(cudaMemcpy(d_ones6, h6, SHARED_INTERMEDIATE * sizeof(float), cudaMemcpyHostToDevice)); + launch_dequant_matvec_q6k((const uint8_t*)sd_w, d_ones6, d_out6, HIDDEN_DIM, SHARED_INTERMEDIATE); + CHECK_CUDA(cudaDeviceSynchronize()); + float ho[5]; + CHECK_CUDA(cudaMemcpy(ho, d_out6, 5*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[verify] Q6_K kernel (ones) out[0..4] = %.6f %.6f %.6f %.6f %.6f\n", + ho[0], ho[1], ho[2], ho[3], ho[4]); + printf("[verify] Python reference: -0.736804 -0.150297 -0.063833 0.046720 -0.036440\n"); + free(h6); CHECK_CUDA(cudaFree(d_ones6)); CHECK_CUDA(cudaFree(d_out6)); + } + free(h_ones); + CHECK_CUDA(cudaFree(d_ones)); + CHECK_CUDA(cudaFree(d_out)); +} + +static void layer_forward(Model *model, int layer_idx, int pos, int K) { + auto &L = model->layers[layer_idx]; + double t0, t1; + + if (g_timing_enabled) { CHECK_CUDA(cudaDeviceSynchronize()); t0 = now_ms(); } + + // 1. Input RMS norm + do_rms_norm(model->buf_hidden, L.input_norm_w, model->buf_normed, + HIDDEN_DIM, RMS_NORM_EPS); + + // Save residual + CHECK_CUDA(cudaMemcpy(model->buf_residual, model->buf_hidden, + HIDDEN_DIM * sizeof(float), cudaMemcpyDeviceToDevice)); + + if (g_timing_enabled) { CHECK_CUDA(cudaDeviceSynchronize()); t1 = now_ms(); g_layer_timing.input_norm += t1-t0; t0=t1; } + + // 2. Attention projections + attention compute + if (L.is_full) { + // Full attention path + int q_proj_dim = NUM_ATTN_HEADS * HEAD_DIM * 2; // interleaved Q + gate + int kv_dim = NUM_KV_HEADS * HEAD_DIM; + + if (g_quant_format == 1 && L.full_qkv_w) { + // GGUF: fused QKV tensor → single matvec, then split output + int fused_dim = q_proj_dim + 2 * kv_dim; + do_matvec(L.full_qkv_w, NULL, NULL, model->buf_normed, + model->buf_q_proj, fused_dim, HIDDEN_DIM, + L.qt_full_qkv); + // Split: buf_q_proj[0..q_proj_dim) = Q+gate, + // buf_q_proj[q_proj_dim..q_proj_dim+kv_dim) = K, + // buf_q_proj[q_proj_dim+kv_dim..) = V + CHECK_CUDA(cudaMemcpy(model->buf_k_proj, (float *)model->buf_q_proj + q_proj_dim, + kv_dim * sizeof(float), cudaMemcpyDeviceToDevice)); + CHECK_CUDA(cudaMemcpy(model->buf_v_proj, (float *)model->buf_q_proj + q_proj_dim + kv_dim, + kv_dim * sizeof(float), cudaMemcpyDeviceToDevice)); + } else { + // MLX: separate Q, K, V projections + do_matvec(L.q_w, L.q_s, L.q_b, model->buf_normed, + model->buf_q_proj, q_proj_dim, HIDDEN_DIM, L.qt_q); + do_matvec(L.k_w, L.k_s, L.k_b, model->buf_normed, + model->buf_k_proj, kv_dim, HIDDEN_DIM, L.qt_k); + do_matvec(L.v_w, L.v_s, L.v_b, model->buf_normed, + model->buf_v_proj, kv_dim, HIDDEN_DIM, L.qt_v); + } + CHECK_CUDA(cudaDeviceSynchronize()); + + // Deinterleave Q and Q_gate, apply Q/K norms, RoPE, attention — on CPU + // (CPU is fine since it's only 15 layers and attention is memory-bound at low seq_len) + float h_q_proj[NUM_ATTN_HEADS * HEAD_DIM * 2]; + float h_k[NUM_KV_HEADS * HEAD_DIM]; + float h_v[NUM_KV_HEADS * HEAD_DIM]; + CHECK_CUDA(cudaMemcpy(h_q_proj, model->buf_q_proj, q_proj_dim * sizeof(float), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(h_k, model->buf_k_proj, kv_dim * sizeof(float), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(h_v, model->buf_v_proj, kv_dim * sizeof(float), cudaMemcpyDeviceToHost)); + + // Deinterleave: q_proj is [num_heads, 2*head_dim] → split into q[num_heads, head_dim] + gate + float h_q[NUM_ATTN_HEADS * HEAD_DIM]; + float h_qg[NUM_ATTN_HEADS * HEAD_DIM]; + for (int h = 0; h < NUM_ATTN_HEADS; h++) { + memcpy(h_q + h * HEAD_DIM, h_q_proj + h * 2 * HEAD_DIM, HEAD_DIM * sizeof(float)); + memcpy(h_qg + h * HEAD_DIM, h_q_proj + h * 2 * HEAD_DIM + HEAD_DIM, HEAD_DIM * sizeof(float)); + } + + // Q/K RMS norm with learned weights + uint16_t h_qnorm[HEAD_DIM], h_knorm[HEAD_DIM]; + CHECK_CUDA(cudaMemcpy(h_qnorm, L.q_norm_w, HEAD_DIM * sizeof(uint16_t), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(h_knorm, L.k_norm_w, HEAD_DIM * sizeof(uint16_t), cudaMemcpyDeviceToHost)); + + for (int h = 0; h < NUM_ATTN_HEADS; h++) { + float *qh = h_q + h * HEAD_DIM; + float sum_sq = 0; + for (int d = 0; d < HEAD_DIM; d++) sum_sq += qh[d] * qh[d]; + float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + RMS_NORM_EPS); + for (int d = 0; d < HEAD_DIM; d++) qh[d] *= inv_rms * bf16_to_f32_host(h_qnorm[d]); + } + for (int h = 0; h < NUM_KV_HEADS; h++) { + float *kh = h_k + h * HEAD_DIM; + float sum_sq = 0; + for (int d = 0; d < HEAD_DIM; d++) sum_sq += kh[d] * kh[d]; + float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + RMS_NORM_EPS); + for (int d = 0; d < HEAD_DIM; d++) kh[d] *= inv_rms * bf16_to_f32_host(h_knorm[d]); + } + + // RoPE + apply_rope(h_q, h_k, pos); + + // KV cache update + int fa_idx = (layer_idx + 1) / FULL_ATTN_INTERVAL - 1; + int cache_pos = model->kv_len[layer_idx]; + CHECK_CUDA(cudaMemcpy( + model->kv_k[layer_idx] + cache_pos * kv_dim, + h_k, kv_dim * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy( + model->kv_v[layer_idx] + cache_pos * kv_dim, + h_v, kv_dim * sizeof(float), cudaMemcpyHostToDevice)); + model->kv_len[layer_idx]++; + + // Attention: Q@K^T → softmax → @V on GPU + int seq_len = model->kv_len[layer_idx]; + int heads_per_kv = NUM_ATTN_HEADS / NUM_KV_HEADS; + float scale = 1.0f / sqrtf((float)HEAD_DIM); + + CHECK_CUDA(cudaMemcpy(model->buf_q, h_q, NUM_ATTN_HEADS * HEAD_DIM * sizeof(float), cudaMemcpyHostToDevice)); + + attn_scores<<>>( + model->buf_q, model->kv_k[layer_idx], model->buf_attn_scores, + HEAD_DIM, kv_dim, seq_len, MAX_SEQ_LEN, scale, heads_per_kv, seq_len); + + attn_softmax<<>>( + model->buf_attn_scores, seq_len, MAX_SEQ_LEN); + + int attn_threads = NUM_ATTN_HEADS * HEAD_DIM; + attn_values<<<(attn_threads + 255) / 256, 256>>>( + model->buf_attn_scores, model->kv_v[layer_idx], model->buf_attn_out, + HEAD_DIM, kv_dim, seq_len, MAX_SEQ_LEN, heads_per_kv); + + // Sigmoid gate: attn_out *= sigmoid(q_gate) + CHECK_CUDA(cudaMemcpy(model->buf_q_gate, h_qg, + NUM_ATTN_HEADS * HEAD_DIM * sizeof(float), cudaMemcpyHostToDevice)); + sigmoid_gate<<<(attn_threads + 255) / 256, 256>>>( + model->buf_attn_out, model->buf_q_gate, attn_threads); + + // O projection + int oproj_in = NUM_ATTN_HEADS * HEAD_DIM; + do_matvec(L.o_w, L.o_s, L.o_b, model->buf_attn_out, + model->buf_h_mid, HIDDEN_DIM, oproj_in, L.qt_o); + + } else { + // Linear attention (GatedDeltaNet) path — all on GPU + do_matvec(L.qkv_w, L.qkv_s, L.qkv_b, model->buf_normed, + model->buf_q_proj, LINEAR_CONV_DIM, HIDDEN_DIM, L.qt_qkv); + do_matvec(L.z_w, L.z_s, L.z_b, model->buf_normed, + model->buf_z_proj, LINEAR_TOTAL_VALUE, HIDDEN_DIM, L.qt_z); + do_matvec(L.b_w, L.b_s, L.b_b, model->buf_normed, + model->buf_beta_proj, LINEAR_NUM_V_HEADS, HIDDEN_DIM, L.qt_b); + do_matvec(L.a_w, L.a_s, L.a_b, model->buf_normed, + model->buf_alpha_proj, LINEAR_NUM_V_HEADS, HIDDEN_DIM, L.qt_a); + + // Dump raw QKV before conv1d + if (layer_idx == 0 && g_dump_layer0) { + float d5[5]; + CHECK_CUDA(cudaDeviceSynchronize()); + CHECK_CUDA(cudaMemcpy(d5, model->buf_q_proj, 5*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[ref] L0 raw_QKV[0:5] %12.6f %12.6f %12.6f %12.6f %12.6f\n", d5[0],d5[1],d5[2],d5[3],d5[4]); + CHECK_CUDA(cudaMemcpy(d5, model->buf_q_proj + 4096, 5*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[ref] L0 raw_V[0:5] %12.6f %12.6f %12.6f %12.6f %12.6f\n", d5[0],d5[1],d5[2],d5[3],d5[4]); + } + + // Conv1d step + conv1d_step<<<(LINEAR_CONV_DIM + 255) / 256, 256>>>( + model->conv_state[layer_idx], model->buf_q_proj, + L.conv1d_w, model->buf_conv_output, LINEAR_CONV_DIM); + + // Dump layer 0 intermediates for comparison with Python reference + if (layer_idx == 0 && g_dump_layer0) { + float d5[5]; + CHECK_CUDA(cudaDeviceSynchronize()); + #define DUMP5(name, buf) do { \ + CHECK_CUDA(cudaMemcpy(d5, buf, 5*sizeof(float), cudaMemcpyDeviceToHost)); \ + printf("[ref] L0 %-15s %12.6f %12.6f %12.6f %12.6f %12.6f\n", name, d5[0],d5[1],d5[2],d5[3],d5[4]); \ + } while(0) + DUMP5("conv_Q[0:5]", model->buf_conv_output); + DUMP5("conv_K[0:5]", model->buf_conv_output + LINEAR_TOTAL_KEY); + DUMP5("conv_V[0:5]", model->buf_conv_output + 2 * LINEAR_TOTAL_KEY); + DUMP5("conv_V_h1[0:5]", model->buf_conv_output + 2 * LINEAR_TOTAL_KEY + 128); + DUMP5("z_proj", model->buf_z_proj); + DUMP5("alpha", model->buf_alpha_proj); + DUMP5("beta", model->buf_beta_proj); + #undef DUMP5 + } + + // Normalize Q and K + if (g_quant_format == 1) { + // GGUF: L2 normalization (matches llama.cpp ggml_l2_norm) + l2_norm_qk<<>>( + model->buf_conv_output, + model->buf_conv_output + LINEAR_TOTAL_KEY, + LINEAR_KEY_DIM); + } else { + // MLX: RMS norm with scaling (original 397B behavior) + float inv_scale = 1.0f / sqrtf((float)LINEAR_KEY_DIM); + rms_norm_qk<<>>( + model->buf_conv_output, + model->buf_conv_output + LINEAR_TOTAL_KEY, + LINEAR_KEY_DIM, inv_scale); + } + + if (layer_idx == 0 && g_dump_layer0) { + float d5[5]; + CHECK_CUDA(cudaDeviceSynchronize()); + // Dump Q/K after L2 norm and V (raw, not normalized) + float *h_q128 = (float*)malloc(128*sizeof(float)); + float *h_k128 = (float*)malloc(128*sizeof(float)); + float *h_v5 = (float*)malloc(5*sizeof(float)); + CHECK_CUDA(cudaMemcpy(h_q128, model->buf_conv_output, 128*sizeof(float), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(h_k128, model->buf_conv_output + LINEAR_TOTAL_KEY, 128*sizeof(float), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(h_v5, model->buf_conv_output + 2*LINEAR_TOTAL_KEY, 5*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[ref] L0 q_normed[0:5] %12.6f %12.6f %12.6f %12.6f %12.6f\n", h_q128[0],h_q128[1],h_q128[2],h_q128[3],h_q128[4]); + printf("[ref] L0 k_normed[0:5] %12.6f %12.6f %12.6f %12.6f %12.6f\n", h_k128[0],h_k128[1],h_k128[2],h_k128[3],h_k128[4]); + printf("[ref] L0 V_raw[0:5] %12.6f %12.6f %12.6f %12.6f %12.6f\n", h_v5[0],h_v5[1],h_v5[2],h_v5[3],h_v5[4]); + // Compute |q|^2 and |k|^2 for head 0 + float q_sq = 0, k_sq = 0, qk_dot = 0; + for (int i = 0; i < 128; i++) { q_sq += h_q128[i]*h_q128[i]; k_sq += h_k128[i]*h_k128[i]; qk_dot += h_q128[i]*h_k128[i]; } + printf("[ref] L0 |q|^2=%.6f |k|^2=%.6f q.k=%.6f\n", q_sq, k_sq, qk_dot); + free(h_q128); free(h_k128); free(h_v5); + } + + // Compute decay and beta gate + if (g_quant_format == 1) { + // GGUF: ssm_a and dt_bias are both F32 + compute_decay_beta_gguf<<<1, LINEAR_NUM_V_HEADS>>>( + model->buf_alpha_proj, model->buf_beta_proj, + L.A_log, (const float *)L.dt_bias, + model->buf_g_decay, model->buf_beta_gate); + } else { + compute_decay_beta<<<1, LINEAR_NUM_V_HEADS>>>( + model->buf_alpha_proj, model->buf_beta_proj, + L.A_log, L.dt_bias, + model->buf_g_decay, model->buf_beta_gate); + } + + if (layer_idx == 0 && g_dump_layer0) { + float dd[5], db[5]; + CHECK_CUDA(cudaDeviceSynchronize()); + CHECK_CUDA(cudaMemcpy(dd, model->buf_g_decay, 4*sizeof(float), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(db, model->buf_beta_gate, 4*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[ref] L0 %-15s %12.6f %12.6f %12.6f %12.6f\n", "decay", dd[0],dd[1],dd[2],dd[3]); + printf("[ref] L0 %-15s %12.6f %12.6f %12.6f %12.6f\n", "beta_gate", db[0],db[1],db[2],db[3]); + } + + // GatedDeltaNet recurrence + uint32_t khpv = LINEAR_NUM_V_HEADS / LINEAR_NUM_K_HEADS; + gated_delta_net_step<<>>( + model->delta_state[layer_idx], + model->buf_conv_output, // q [2048] + model->buf_conv_output + LINEAR_TOTAL_KEY, // k [2048] + model->buf_conv_output + 2 * LINEAR_TOTAL_KEY, // v [8192] + model->buf_g_decay, model->buf_beta_gate, + model->buf_delta_output, khpv); + + if (layer_idx == 0 && g_dump_layer0) { + float d5[5]; + CHECK_CUDA(cudaDeviceSynchronize()); + CHECK_CUDA(cudaMemcpy(d5, model->buf_delta_output, 5*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[ref] L0 %-15s %12.6f %12.6f %12.6f %12.6f %12.6f\n", "delta_out", d5[0],d5[1],d5[2],d5[3],d5[4]); + // Save full delta_out for per-head analysis + float *dt = (float*)malloc(LINEAR_TOTAL_VALUE * sizeof(float)); + CHECK_CUDA(cudaMemcpy(dt, model->buf_delta_output, LINEAR_TOTAL_VALUE * sizeof(float), cudaMemcpyDeviceToHost)); + FILE *df = fopen("/tmp/cuda_delta_out.bin", "wb"); + fwrite(dt, sizeof(float), LINEAR_TOTAL_VALUE, df); + fclose(df); free(dt); + } + + // Gated RMS norm + gated_rms_norm<<>>( + model->buf_delta_output, model->buf_z_proj, + L.gated_norm_w, model->buf_attn_out, + LINEAR_VALUE_DIM, RMS_NORM_EPS); + + if (layer_idx == 0 && g_dump_layer0) { + float d5[5]; + CHECK_CUDA(cudaDeviceSynchronize()); + CHECK_CUDA(cudaMemcpy(d5, model->buf_attn_out, 5*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[ref] L0 %-15s %12.6f %12.6f %12.6f %12.6f %12.6f\n", "gated_norm", d5[0],d5[1],d5[2],d5[3],d5[4]); + // Save full gated_norm for comparison + float *gn = (float*)malloc(LINEAR_TOTAL_VALUE * sizeof(float)); + CHECK_CUDA(cudaMemcpy(gn, model->buf_attn_out, LINEAR_TOTAL_VALUE * sizeof(float), cudaMemcpyDeviceToHost)); + FILE *gf = fopen("/tmp/cuda_gated_norm.bin", "wb"); + fwrite(gn, sizeof(float), LINEAR_TOTAL_VALUE, gf); + fclose(gf); free(gn); + printf("[ref] Saved gated_norm (%d floats) to /tmp/cuda_gated_norm.bin\n", LINEAR_TOTAL_VALUE); + } + + // Output projection + do_matvec(L.out_proj_w, L.out_proj_s, L.out_proj_b, + model->buf_attn_out, model->buf_h_mid, + HIDDEN_DIM, LINEAR_TOTAL_VALUE, L.qt_out); + + if (layer_idx == 0 && g_dump_layer0) { + float d5[5]; + CHECK_CUDA(cudaDeviceSynchronize()); + CHECK_CUDA(cudaMemcpy(d5, model->buf_h_mid, 5*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[ref] L0 %-15s %12.6f %12.6f %12.6f %12.6f %12.6f\n", "oproj_out", d5[0],d5[1],d5[2],d5[3],d5[4]); + g_dump_layer0 = 0; // only dump once + } + } + + if (g_timing_enabled) { CHECK_CUDA(cudaDeviceSynchronize()); t1 = now_ms(); g_layer_timing.attn_compute += t1-t0; t0=t1; } + + // 3. Residual + post-attention norm + launch_residual_add(model->buf_residual, model->buf_h_mid, model->buf_h_mid, HIDDEN_DIM); + do_rms_norm(model->buf_h_mid, L.post_attn_norm_w, model->buf_normed, + HIDDEN_DIM, RMS_NORM_EPS); + + if (g_timing_enabled) { CHECK_CUDA(cudaDeviceSynchronize()); t1 = now_ms(); g_layer_timing.oproj_residual += t1-t0; t0=t1; } + + // 4. MoE routing + do_matvec(L.gate_w, L.gate_s, L.gate_b, model->buf_normed, + model->buf_gate_scores, NUM_EXPERTS, HIDDEN_DIM, L.qt_gate); + CHECK_CUDA(cudaDeviceSynchronize()); + + float h_scores[NUM_EXPERTS]; + CHECK_CUDA(cudaMemcpy(h_scores, model->buf_gate_scores, NUM_EXPERTS * sizeof(float), cudaMemcpyDeviceToHost)); + cpu_softmax(h_scores, NUM_EXPERTS); + + int expert_ids[MAX_K]; + float expert_weights[MAX_K]; + topk(h_scores, NUM_EXPERTS, K, expert_ids, expert_weights); + + if (g_expert_log) + fprintf(g_expert_log, "%d %d %d %d %d\n", layer_idx, + expert_ids[0], expert_ids[1], expert_ids[2], expert_ids[3]); + + if (g_timing_enabled) { t1 = now_ms(); g_layer_timing.routing += t1-t0; t0=t1; } + + // 5. Shared expert forward + expert I/O OVERLAP + do_matvec(L.sg_w, L.sg_s, L.sg_b, model->buf_normed, + model->buf_shared_gate, SHARED_INTERMEDIATE, HIDDEN_DIM, L.qt_sg); + do_matvec(L.su_w, L.su_s, L.su_b, model->buf_normed, + model->buf_shared_up, SHARED_INTERMEDIATE, HIDDEN_DIM, L.qt_su); + launch_swiglu(model->buf_shared_gate, model->buf_shared_up, model->buf_shared_gate, + SHARED_INTERMEDIATE); + do_matvec(L.sd_w, L.sd_s, L.sd_b, model->buf_shared_gate, + model->buf_shared_out, HIDDEN_DIM, SHARED_INTERMEDIATE, L.qt_sd); + + // Shared expert gate score (can overlap with I/O) + do_matvec(L.seg_w, L.seg_s, L.seg_b, model->buf_normed, + model->buf_gate_scores, 1, HIDDEN_DIM, L.qt_seg); + + if (g_timing_enabled) { CHECK_CUDA(cudaDeviceSynchronize()); t1 = now_ms(); g_layer_timing.shared_expert += t1-t0; t0=t1; } + + // 6. Load K experts — check VRAM cache first, then SSD + // expert_ptrs[k] points to expert data in VRAM (cache or freshly loaded) + void *expert_ptrs[MAX_K]; + { + int need_ssd[MAX_K]; // indices of experts that need SSD load + int need_ssd_ids[MAX_K]; + int n_ssd = 0; + + model->vram_cache_clock++; + + for (int k = 0; k < K; k++) { + int eid = expert_ids[k]; + int slot = model->cache_map[layer_idx][eid]; + if (slot >= 0 && model->cache_slots[slot].layer == layer_idx && + model->cache_slots[slot].expert_id == eid) { + // Cache hit — point directly at VRAM cache slot + expert_ptrs[k] = (char *)model->vram_cache_pool + (size_t)slot * g_expert_size; + model->cache_slots[slot].last_used = model->vram_cache_clock; + model->cache_slots[slot].access_count++; + } else { + // Cache miss — need SSD load + need_ssd[n_ssd] = k; + need_ssd_ids[n_ssd] = eid; + n_ssd++; + } + } + + if (n_ssd > 0) { + // Load missing experts from SSD + pthread_t threads[MAX_K]; + PreadArg args[MAX_K]; + int fd = model->expert_fds[layer_idx]; + for (int i = 0; i < n_ssd; i++) { + args[i].fd = fd; + args[i].buf = model->h_expert_buf[i]; + args[i].size = g_expert_size; + args[i].offset = (off_t)need_ssd_ids[i] * g_expert_size; + pthread_create(&threads[i], NULL, pread_worker, &args[i]); + } + for (int i = 0; i < n_ssd; i++) + pthread_join(threads[i], NULL); + + // Copy to VRAM cache slots (or buf_expert_data if cache full) + for (int i = 0; i < n_ssd; i++) { + int k = need_ssd[i]; + int eid = need_ssd_ids[i]; + int slot = -1; + + if (model->vram_cache_used < model->vram_cache_capacity) { + // Free slot available + slot = model->vram_cache_used++; + } else if (model->vram_cache_capacity > 0) { + // Evict: frequency-weighted LRU + // Score = access_count * FREQ_WEIGHT + last_used + // Higher score = more valuable = keep longer + // Evict the slot with the lowest score + #define FREQ_WEIGHT 10 // each access = 10 clock ticks of recency + uint64_t min_score = UINT64_MAX; + int min_slot = 0; + for (int s = 0; s < model->vram_cache_capacity; s++) { + uint64_t score = (uint64_t)model->cache_slots[s].access_count * FREQ_WEIGHT + + model->cache_slots[s].last_used; + if (score < min_score) { + min_score = score; + min_slot = s; + } + } + slot = min_slot; + // Remove old entry from cache_map + if (model->cache_slots[slot].layer >= 0) + model->cache_map[model->cache_slots[slot].layer] + [model->cache_slots[slot].expert_id] = -1; + } + + if (slot >= 0) { + void *dst = (char *)model->vram_cache_pool + (size_t)slot * g_expert_size; + // Copy to temp buffer first (for immediate use), then async to cache + void *tmp = (char *)model->buf_expert_data + k * g_expert_size; + CHECK_CUDA(cudaMemcpy(tmp, model->h_expert_buf[i], g_expert_size, + cudaMemcpyHostToDevice)); + // Async copy to cache slot (runs in background) + CHECK_CUDA(cudaMemcpyAsync(dst, tmp, g_expert_size, + cudaMemcpyDeviceToDevice, model->stream_transfer)); + model->cache_slots[slot].layer = layer_idx; + model->cache_slots[slot].expert_id = eid; + model->cache_slots[slot].last_used = model->vram_cache_clock; + model->cache_slots[slot].access_count = 1; + model->cache_map[layer_idx][eid] = slot; + expert_ptrs[k] = tmp; // use temp buffer now, cache fills in background + } else { + // No cache — use temporary buffer + CHECK_CUDA(cudaMemcpy( + (char *)model->buf_expert_data + k * g_expert_size, + model->h_expert_buf[i], g_expert_size, cudaMemcpyHostToDevice)); + expert_ptrs[k] = (char *)model->buf_expert_data + k * g_expert_size; + } + } + } + } + + if (g_timing_enabled) { t1 = now_ms(); g_layer_timing.expert_io += t1-t0; t0=t1; } + + // 7. Expert forward (K experts on GPU, using cached pointers) + for (int k = 0; k < K; k++) { + if (g_quant_format == 1) { + // GGUF expert layout + uint8_t *base = (uint8_t *)expert_ptrs[k]; + launch_dequant_matvec_gguf((const void *)(base + g_gguf_gate_offset), + model->buf_normed, model->buf_shared_gate, MOE_INTERMEDIATE, HIDDEN_DIM, + g_gguf_gate_type); + launch_dequant_matvec_gguf((const void *)(base + g_gguf_up_offset), + model->buf_normed, model->buf_shared_up, MOE_INTERMEDIATE, HIDDEN_DIM, + g_gguf_up_type); + launch_swiglu(model->buf_shared_gate, model->buf_shared_up, model->buf_shared_gate, + MOE_INTERMEDIATE); + launch_dequant_matvec_gguf((const void *)(base + g_gguf_down_offset), + model->buf_shared_gate, model->buf_expert_outs + k * HIDDEN_DIM, + HIDDEN_DIM, MOE_INTERMEDIATE, g_gguf_down_type_per_layer[layer_idx]); + } else { + // MLX expert layout + void *base = expert_ptrs[k]; + + uint32_t *gate_w = (uint32_t *)((char *)base + EXP_GATE_W); + uint16_t *gate_s = (uint16_t *)((char *)base + EXP_GATE_S); + uint16_t *gate_b = (uint16_t *)((char *)base + EXP_GATE_B); + uint32_t *up_w = (uint32_t *)((char *)base + EXP_UP_W); + uint16_t *up_s = (uint16_t *)((char *)base + EXP_UP_S); + uint16_t *up_b = (uint16_t *)((char *)base + EXP_UP_B); + uint32_t *down_w = (uint32_t *)((char *)base + EXP_DOWN_W); + uint16_t *down_s = (uint16_t *)((char *)base + EXP_DOWN_S); + uint16_t *down_b = (uint16_t *)((char *)base + EXP_DOWN_B); + + launch_dequant_matvec(gate_w, gate_s, gate_b, model->buf_normed, + model->buf_shared_gate, MOE_INTERMEDIATE, HIDDEN_DIM); + launch_dequant_matvec(up_w, up_s, up_b, model->buf_normed, + model->buf_shared_up, MOE_INTERMEDIATE, HIDDEN_DIM); + launch_swiglu(model->buf_shared_gate, model->buf_shared_up, model->buf_shared_gate, + MOE_INTERMEDIATE); + launch_dequant_matvec(down_w, down_s, down_b, model->buf_shared_gate, + model->buf_expert_outs + k * HIDDEN_DIM, HIDDEN_DIM, MOE_INTERMEDIATE); + } + } + + // 8. MoE combine + residual (no per-layer malloc) + float h_seg_score; + CHECK_CUDA(cudaMemcpy(&h_seg_score, model->buf_gate_scores, sizeof(float), cudaMemcpyDeviceToHost)); + + CHECK_CUDA(cudaMemcpy(model->buf_expert_weights, expert_weights, + K * sizeof(float), cudaMemcpyHostToDevice)); + + if (g_timing_enabled) { CHECK_CUDA(cudaDeviceSynchronize()); t1 = now_ms(); g_layer_timing.expert_compute += t1-t0; t0=t1; } + + moe_combine_residual<<<(HIDDEN_DIM + 255) / 256, 256>>>( + model->buf_h_mid, model->buf_shared_out, model->buf_hidden, + model->buf_expert_outs, model->buf_expert_weights, h_seg_score, + HIDDEN_DIM, K); + + if (g_timing_enabled) { CHECK_CUDA(cudaDeviceSynchronize()); t1 = now_ms(); g_layer_timing.combine += t1-t0; + g_layer_timing.count++; + } +} + +// ============================================================================ +// Full forward pass: embedding → 60 layers → norm → lm_head → argmax +// ============================================================================ + +static int forward(Model *model, int token_id, int pos, int K) { + // Embedding + embed_token(model, token_id); + + // Reset timing + if (g_timing_enabled) memset(&g_layer_timing, 0, sizeof(g_layer_timing)); + + // 60 layers + for (int i = 0; i < NUM_LAYERS; i++) { + layer_forward(model, i, pos, K); + // Dump hidden state every layer (first token only) + static int layer_dump = 1; + if (layer_dump && g_quant_format == 1) { + float d5[5]; + CHECK_CUDA(cudaDeviceSynchronize()); + CHECK_CUDA(cudaMemcpy(d5, model->buf_hidden, 5*sizeof(float), cudaMemcpyDeviceToHost)); + float mag = 0; for (int j = 0; j < 5; j++) mag += d5[j]*d5[j]; + printf("[ref] L%02d hidden = %10.6f %10.6f %10.6f %10.6f %10.6f mag=%.4f\n", + i, d5[0],d5[1],d5[2],d5[3],d5[4], sqrtf(mag)); + if (i == NUM_LAYERS-1) layer_dump = 0; + } + } + + // Print timing summary + if (g_timing_enabled && g_layer_timing.count > 0) { + int n = g_layer_timing.count; + fprintf(stderr, "[timing] Per-layer avg (%.0f layers): " + "norm=%.2f attn=%.2f oproj=%.2f route=%.2f " + "shared=%.2f io=%.2f expert=%.2f combine=%.2f ms\n", + (double)n, + g_layer_timing.input_norm/n, g_layer_timing.attn_compute/n, + g_layer_timing.oproj_residual/n, g_layer_timing.routing/n, + g_layer_timing.shared_expert/n, g_layer_timing.expert_io/n, + g_layer_timing.expert_compute/n, g_layer_timing.combine/n); + } + + // Dump final hidden state for comparison with llama.cpp + if (g_quant_format == 1) { + static int final_dump = 1; + if (final_dump) { + float d5[5]; + CHECK_CUDA(cudaDeviceSynchronize()); + CHECK_CUDA(cudaMemcpy(d5, model->buf_hidden, 5*sizeof(float), cudaMemcpyDeviceToHost)); + printf("[ref] final_hidden[0:5] = %.6f %.6f %.6f %.6f %.6f\n", d5[0],d5[1],d5[2],d5[3],d5[4]); + printf("[ref] llama.cpp ref: -1.087252 -2.072342 0.351115 3.771449 -0.681070\n"); + final_dump = 0; + } } + + // Final RMS norm + do_rms_norm(model->buf_hidden, model->final_norm_w, model->buf_normed, + HIDDEN_DIM, RMS_NORM_EPS); + // LM head: [VOCAB_SIZE, HIDDEN_DIM] → logits + do_matvec(model->lm_head_w, model->lm_head_s, model->lm_head_b, + model->buf_normed, model->buf_logits, + VOCAB_SIZE, HIDDEN_DIM, model->qt_lm_head); + CHECK_CUDA(cudaDeviceSynchronize()); + + // Copy logits to host and argmax + CHECK_CUDA(cudaMemcpy(model->h_logits, model->buf_logits, + VOCAB_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + + // Debug: save full logits to file for correlation analysis + static int logit_save = 1; + if (logit_save && g_quant_format == 1) { + FILE *lf = fopen("/tmp/cuda_logits_hello.bin", "wb"); + if (lf) { + fwrite(model->h_logits, sizeof(float), VOCAB_SIZE, lf); + fclose(lf); + printf("[ref] Saved %d logits to /tmp/cuda_logits_hello.bin\n", VOCAB_SIZE); + } + logit_save = 0; + } + + // Debug: show top-10 logits for first token + static int logit_dump = 1; + if (logit_dump && g_quant_format == 1) { + int top[10]; float topv[10]; + for (int i = 0; i < 10; i++) { top[i] = i; topv[i] = model->h_logits[i]; } + for (int i = 10; i < VOCAB_SIZE; i++) { + int mn = 0; + for (int j = 1; j < 10; j++) if (topv[j] < topv[mn]) mn = j; + if (model->h_logits[i] > topv[mn]) { top[mn] = i; topv[mn] = model->h_logits[i]; } + } + // Sort by value + for (int i = 0; i < 9; i++) for (int j = i+1; j < 10; j++) + if (topv[j] > topv[i]) { float tv=topv[i]; topv[i]=topv[j]; topv[j]=tv; int ti=top[i]; top[i]=top[j]; top[j]=ti; } + printf("[ref] Top-10 logits:\n"); + for (int i = 0; i < 10; i++) printf(" #%d: token %d = %.4f\n", i+1, top[i], topv[i]); + logit_dump = 0; + } + + int best = 0; + float best_val = model->h_logits[0]; + for (int i = 1; i < VOCAB_SIZE; i++) { + if (model->h_logits[i] > best_val) { + best_val = model->h_logits[i]; + best = i; + } + } + return best; +} + +// ============================================================================ +// HTTP Server — OpenAI-compatible /v1/chat/completions (SSE streaming) +// ============================================================================ + +#include +#include +#include + +static int read_http_request(int fd, char *buf, int bufsz) { + int total = 0; + while (total < bufsz - 1) { + ssize_t r = read(fd, buf + total, 1); + if (r <= 0) return -1; + total++; + if (total >= 4 && buf[total-4]=='\r' && buf[total-3]=='\n' && + buf[total-2]=='\r' && buf[total-1]=='\n') break; + } + buf[total] = '\0'; + // Read body if Content-Length present + const char *cl = strcasestr(buf, "Content-Length:"); + if (cl) { + int content_len = atoi(cl + 15); + if (content_len > 0 && total + content_len < bufsz - 1) { + int got = 0; + while (got < content_len) { + ssize_t r = read(fd, buf + total + got, content_len - got); + if (r <= 0) break; + got += r; + } + total += got; + buf[total] = '\0'; + } + } + return total; +} + +static void http_write_str(int fd, const char *s) { + int len = strlen(s), sent = 0; + while (sent < len) { + ssize_t w = write(fd, s + sent, len - sent); + if (w <= 0) break; + sent += w; + } +} + +static char *extract_last_content(char *buf) { + char *last = NULL, *p = buf; + for (;;) { + p = strstr(p, "\"content\""); + if (!p) break; + p += 9; + while (*p == ' ' || *p == '\t' || *p == ':') p++; + if (*p == '"') { p++; last = p; while (*p && !(*p == '"' && *(p-1) != '\\')) p++; } + } + if (last) { + char *end = last; + while (*end && !(*end == '"' && (end == last || *(end-1) != '\\'))) end++; + *end = '\0'; + // Unescape + char *r = last, *w = last; + while (*r) { + if (*r == '\\' && *(r+1)) { + r++; + switch (*r) { + case 'n': *w++ = '\n'; r++; break; + case 't': *w++ = '\t'; r++; break; + case '"': *w++ = '"'; r++; break; + case '\\': *w++ = '\\'; r++; break; + default: *w++ = '\\'; *w++ = *r++; break; + } + } else *w++ = *r++; + } + *w = '\0'; + } + return last; +} + +// Extract a string value for a given key from JSON body. Returns 0 if not found. +static int extract_string_field(const char *buf, const char *key, char *out, int out_sz) { + char pattern[256]; + snprintf(pattern, sizeof(pattern), "\"%s\"", key); + const char *p = strstr(buf, pattern); + if (!p) return 0; + p += strlen(pattern); + while (*p == ' ' || *p == ':' || *p == '\t') p++; + if (*p != '"') return 0; + p++; + int i = 0; + while (*p && *p != '"' && i < out_sz - 1) out[i++] = *p++; + out[i] = '\0'; + return i > 0; +} + +static int extract_max_tokens(const char *buf, int def) { + const char *p = strstr(buf, "\"max_completion_tokens\""); + if (!p) p = strstr(buf, "\"max_tokens\""); + if (!p) return def; + p = strchr(p, ':'); + return p ? atoi(p + 1) : def; +} + +static int sse_send_delta(int fd, const char *req_id, const char *token_text) { + char chunk[4096], escaped[2048]; + char *w = escaped; + for (const char *r = token_text; *r && w < escaped + sizeof(escaped) - 8; r++) { + switch (*r) { + case '"': *w++ = '\\'; *w++ = '"'; break; + case '\\': *w++ = '\\'; *w++ = '\\'; break; + case '\n': *w++ = '\\'; *w++ = 'n'; break; + case '\r': *w++ = '\\'; *w++ = 'r'; break; + case '\t': *w++ = '\\'; *w++ = 't'; break; + default: *w++ = *r; break; + } + } + *w = '\0'; + int n = snprintf(chunk, sizeof(chunk), + "data: {\"id\":\"%s\",\"object\":\"chat.completion.chunk\"," + "\"choices\":[{\"index\":0,\"delta\":{\"content\":\"%s\"},\"finish_reason\":null}]}\n\n", + req_id, escaped); + ssize_t wr = write(fd, chunk, n); + return (wr <= 0) ? -1 : 0; +} + +static void sse_send_done(int fd, const char *req_id) { + char chunk[1024]; + int n = snprintf(chunk, sizeof(chunk), + "data: {\"id\":\"%s\",\"object\":\"chat.completion.chunk\"," + "\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n" + "data: [DONE]\n\n", req_id); + http_write_str(fd, chunk); +} + +#define EOS_TOKEN_1 248044 // <|endoftext|> +#define EOS_TOKEN_2 248046 // <|im_end|> +#define IM_START 248045 // <|im_start|> +#define THINK_START 248068 // +#define THINK_END 248069 // + +// ============================================================================ +// Tool calling support — extract tools from request, format for Qwen, parse output +// ============================================================================ + +// Extract the "tools" JSON array from the request body (returns malloc'd string or NULL) +static char *extract_tools_json(const char *body) { + const char *p = strstr(body, "\"tools\""); + if (!p) return NULL; + p = strchr(p + 7, '['); + if (!p) return NULL; + // Find matching ] + int depth = 1; + const char *start = p; + p++; + while (*p && depth > 0) { + if (*p == '[') depth++; + else if (*p == ']') depth--; + p++; + } + if (depth != 0) return NULL; + size_t len = p - start; + char *result = (char *)malloc(len + 1); + memcpy(result, start, len); + result[len] = '\0'; + return result; +} + +// Build a full ChatML prompt from the OpenAI messages array + tools. +// Returns malloc'd string ready for tokenization. +// Build a per-request prompt from OpenAI messages array + tools. +// System prompt is already in KV cache — this only generates the user turn(s). +static char *build_chat_prompt(const char *body, const char *tools_json) { + size_t bufsize = strlen(body) * 2 + (tools_json ? strlen(tools_json) * 2 : 0) + 65536; + char *prompt = (char *)calloc(1, bufsize); + char *w = prompt; + + // If tools provided, inject as a system addendum before user messages + if (tools_json) { + w += sprintf(w, "<|im_start|>system\n# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n\n%s\n\n\n" + "For each function call, return a json object with function name and arguments within " + " XML tags:\n\n" + "{\"name\": \"\", \"arguments\": {}}\n" + "<|im_end|>\n", tools_json); + } + + // Parse messages array — find each role/content pair + const char *msgs = strstr(body, "\"messages\""); + if (!msgs) { w += sprintf(w, "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"); return prompt; } + + const char *arr = strchr(msgs, '['); + if (!arr) { w += sprintf(w, "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"); return prompt; } + + // Simple message parser: find each {"role":"...", "content":"..."} pair + const char *p = arr + 1; + while (*p) { + // Find next "role" + const char *role_key = strstr(p, "\"role\""); + if (!role_key) break; + const char *role_val = strchr(role_key + 6, '"'); + if (!role_val) break; + role_val++; // skip quote + const char *role_end = strchr(role_val, '"'); + if (!role_end) break; + + char role[32] = {}; + size_t rlen = role_end - role_val; + if (rlen >= sizeof(role)) rlen = sizeof(role) - 1; + memcpy(role, role_val, rlen); + + // Find content for this message + const char *content_key = strstr(role_end, "\"content\""); + if (!content_key) { p = role_end + 1; continue; } + + // Skip to value + const char *cv = content_key + 9; + while (*cv == ' ' || *cv == ':' || *cv == '\t') cv++; + + if (*cv == '"') { + cv++; // skip opening quote + // Find end quote (handle escapes) + const char *ce = cv; + while (*ce && !(*ce == '"' && *(ce-1) != '\\')) ce++; + + // Write ChatML turn + if (strcmp(role, "system") == 0) { + // System message already handled above, skip + } else if (strcmp(role, "tool") == 0) { + // Tool result — find tool_call_id and name + w += sprintf(w, "<|im_start|>user\n[Tool result]: "); + size_t clen = ce - cv; + memcpy(w, cv, clen); w += clen; + w += sprintf(w, "<|im_end|>\n"); + } else { + w += sprintf(w, "<|im_start|>%s\n", role); + // Copy content, unescaping JSON escapes + const char *r = cv; + while (r < ce) { + if (*r == '\\' && r + 1 < ce) { + r++; + switch (*r) { + case 'n': *w++ = '\n'; break; + case 't': *w++ = '\t'; break; + case '"': *w++ = '"'; break; + case '\\': *w++ = '\\'; break; + default: *w++ = '\\'; *w++ = *r; break; + } + r++; + } else { + *w++ = *r++; + } + } + w += sprintf(w, "<|im_end|>\n"); + } + p = ce + 1; + } else if (*cv == 'n') { + // null content (e.g., assistant tool call message) + if (strcmp(role, "assistant") == 0) { + // Check for tool_calls in this message + const char *tc = strstr(role_end, "\"tool_calls\""); + if (tc) { + w += sprintf(w, "<|im_start|>assistant\n\n"); + // Extract function name and arguments + const char *fname = strstr(tc, "\"name\""); + if (fname) { + const char *fv = strchr(fname + 6, '"'); + if (fv) { + fv++; + const char *fe = strchr(fv, '"'); + if (fe) { + w += sprintf(w, "{\"name\": \""); + memcpy(w, fv, fe - fv); w += (fe - fv); + w += sprintf(w, "\", \"arguments\": "); + } + } + } + const char *fargs = strstr(tc, "\"arguments\""); + if (fargs) { + const char *av = strchr(fargs + 11, '"'); + if (av) { + av++; + const char *ae = av; + while (*ae && !(*ae == '"' && *(ae-1) != '\\')) ae++; + // Unescape and write + const char *r = av; + while (r < ae) { + if (*r == '\\' && r + 1 < ae) { + r++; + switch (*r) { + case 'n': *w++ = '\n'; break; + case '"': *w++ = '"'; break; + case '\\': *w++ = '\\'; break; + default: *w++ = *r; break; + } + r++; + } else *w++ = *r++; + } + } + } + w += sprintf(w, "}\n<|im_end|>\n"); + } + } + p = cv + 4; + } else { + p = cv + 1; + } + } + + // End with assistant prompt + w += sprintf(w, "<|im_start|>assistant\n"); + return prompt; +} + +// Send a tool_call SSE chunk (OpenAI format) +static int sse_send_tool_call(int fd, const char *req_id, const char *call_id, + const char *func_name, const char *arguments) { + char chunk[8192], esc_args[4096]; + // Escape arguments for JSON + char *w = esc_args; + for (const char *r = arguments; *r && w < esc_args + sizeof(esc_args) - 8; r++) { + switch (*r) { + case '"': *w++ = '\\'; *w++ = '"'; break; + case '\\': *w++ = '\\'; *w++ = '\\'; break; + case '\n': *w++ = '\\'; *w++ = 'n'; break; + default: *w++ = *r; break; + } + } + *w = '\0'; + + int n = snprintf(chunk, sizeof(chunk), + "data: {\"id\":\"%s\",\"object\":\"chat.completion.chunk\"," + "\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"%s\"," + "\"type\":\"function\",\"function\":{\"name\":\"%s\",\"arguments\":\"%s\"}}]}," + "\"finish_reason\":null}]}\n\n", + req_id, call_id, func_name, esc_args); + ssize_t wr = write(fd, chunk, n); + return (wr <= 0) ? -1 : 0; +} + +static void sse_send_tool_done(int fd, const char *req_id) { + char chunk[1024]; + int n = snprintf(chunk, sizeof(chunk), + "data: {\"id\":\"%s\",\"object\":\"chat.completion.chunk\"," + "\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n" + "data: [DONE]\n\n", req_id); + http_write_str(fd, chunk); +} + +// Parse a tool call from generated text. Looks for ... pattern. +// Returns 1 if found, fills name and arguments buffers. +static int parse_tool_call(const char *text, char *name, int name_sz, char *args, int args_sz) { + const char *start = strstr(text, ""); + if (!start) return 0; + start += 11; // skip tag + const char *end = strstr(start, ""); + + // Find "name" in the JSON + const char *np = strstr(start, "\"name\""); + if (!np || (end && np > end)) return 0; + np = strchr(np + 6, '"'); + if (!np) return 0; + np++; + const char *ne = strchr(np, '"'); + if (!ne) return 0; + int nlen = ne - np; + if (nlen >= name_sz) nlen = name_sz - 1; + memcpy(name, np, nlen); + name[nlen] = '\0'; + + // Find "arguments" in the JSON + const char *ap = strstr(start, "\"arguments\""); + if (!ap || (end && ap > end)) { args[0] = '{'; args[1] = '}'; args[2] = '\0'; return 1; } + ap = strchr(ap + 11, '{'); + if (!ap) { args[0] = '{'; args[1] = '}'; args[2] = '\0'; return 1; } + // Find matching } + int depth = 1; + const char *aps = ap + 1; + while (*aps && depth > 0) { + if (*aps == '{') depth++; + else if (*aps == '}') depth--; + aps++; + } + int alen = aps - ap; + if (alen >= args_sz) alen = args_sz - 1; + memcpy(args, ap, alen); + args[alen] = '\0'; + return 1; +} + +// ============================================================================ +// Anthropic Messages API — SSE helpers +// ============================================================================ + +static void anth_send_message_start(int fd, const char *msg_id, const char *model_name) { + char buf[2048]; + int n = snprintf(buf, sizeof(buf), + "event: message_start\n" + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"%s\",\"type\":\"message\"," + "\"role\":\"assistant\",\"content\":[],\"model\":\"%s\"," + "\"stop_reason\":null,\"stop_sequence\":null," + "\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n", + msg_id, model_name); + http_write_str(fd, buf); +} + +static void anth_send_content_block_start(int fd, int index, const char *type) { + char buf[512]; + if (strcmp(type, "text") == 0) { + snprintf(buf, sizeof(buf), + "event: content_block_start\n" + "data: {\"type\":\"content_block_start\",\"index\":%d," + "\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n", index); + } + http_write_str(fd, buf); +} + +static void anth_send_content_block_start_tool(int fd, int index, const char *tool_id, + const char *func_name) { + char buf[1024]; + snprintf(buf, sizeof(buf), + "event: content_block_start\n" + "data: {\"type\":\"content_block_start\",\"index\":%d," + "\"content_block\":{\"type\":\"tool_use\",\"id\":\"%s\",\"name\":\"%s\",\"input\":{}}}\n\n", + index, tool_id, func_name); + http_write_str(fd, buf); +} + +static int anth_send_text_delta(int fd, int index, const char *text) { + char chunk[4096], escaped[2048]; + char *w = escaped; + for (const char *r = text; *r && w < escaped + sizeof(escaped) - 8; r++) { + switch (*r) { + case '"': *w++ = '\\'; *w++ = '"'; break; + case '\\': *w++ = '\\'; *w++ = '\\'; break; + case '\n': *w++ = '\\'; *w++ = 'n'; break; + case '\r': *w++ = '\\'; *w++ = 'r'; break; + case '\t': *w++ = '\\'; *w++ = 't'; break; + default: *w++ = *r; break; + } + } + *w = '\0'; + int n = snprintf(chunk, sizeof(chunk), + "event: content_block_delta\n" + "data: {\"type\":\"content_block_delta\",\"index\":%d," + "\"delta\":{\"type\":\"text_delta\",\"text\":\"%s\"}}\n\n", + index, escaped); + ssize_t wr = write(fd, chunk, n); + return (wr <= 0) ? -1 : 0; +} + +static int anth_send_tool_input_delta(int fd, int index, const char *json_delta) { + char chunk[8192], escaped[4096]; + char *w = escaped; + for (const char *r = json_delta; *r && w < escaped + sizeof(escaped) - 8; r++) { + switch (*r) { + case '"': *w++ = '\\'; *w++ = '"'; break; + case '\\': *w++ = '\\'; *w++ = '\\'; break; + case '\n': *w++ = '\\'; *w++ = 'n'; break; + default: *w++ = *r; break; + } + } + *w = '\0'; + int n = snprintf(chunk, sizeof(chunk), + "event: content_block_delta\n" + "data: {\"type\":\"content_block_delta\",\"index\":%d," + "\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"%s\"}}\n\n", + index, escaped); + ssize_t wr = write(fd, chunk, n); + return (wr <= 0) ? -1 : 0; +} + +static void anth_send_content_block_stop(int fd, int index) { + char buf[256]; + snprintf(buf, sizeof(buf), + "event: content_block_stop\n" + "data: {\"type\":\"content_block_stop\",\"index\":%d}\n\n", index); + http_write_str(fd, buf); +} + +static void anth_send_message_delta(int fd, const char *stop_reason, int output_tokens) { + char buf[512]; + snprintf(buf, sizeof(buf), + "event: message_delta\n" + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"%s\",\"stop_sequence\":null}," + "\"usage\":{\"output_tokens\":%d}}\n\n", + stop_reason, output_tokens); + http_write_str(fd, buf); +} + +static void anth_send_message_stop(int fd) { + http_write_str(fd, + "event: message_stop\n" + "data: {\"type\":\"message_stop\"}\n\n"); +} + +// Build per-request ChatML prompt from Anthropic Messages API request. +// System prompt is already in KV cache — this only generates user turn(s) + tools. +static char *build_anthropic_prompt(const char *body, const char *system_prompt) { + size_t bufsize = strlen(body) * 2 + 65536; + char *prompt = (char *)calloc(1, bufsize); + char *w = prompt; + + // If tools provided, inject as system addendum + char *tools_json = extract_tools_json(body); + if (tools_json) { + w += sprintf(w, "<|im_start|>system\n# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n\n" + "%s\n\n\n" + "For each function call, return a json object with function name and arguments within " + " XML tags:\n\n" + "{\"name\": \"\", \"arguments\": {}}\n" + "<|im_end|>\n", tools_json); + free(tools_json); + } + + // Parse messages array + const char *msgs = strstr(body, "\"messages\""); + if (!msgs) { w += sprintf(w, "<|im_start|>assistant\n"); return prompt; } + const char *arr = strchr(msgs, '['); + if (!arr) { w += sprintf(w, "<|im_start|>assistant\n"); return prompt; } + + // Iterate messages — Anthropic format: + // {"role": "user"/"assistant", "content": "string" or [{"type":"text","text":"..."},{"type":"tool_result",...}]} + const char *p = arr + 1; + while (*p) { + const char *role_key = strstr(p, "\"role\""); + if (!role_key) break; + const char *role_val = strchr(role_key + 6, '"'); + if (!role_val) break; + role_val++; + const char *role_end = strchr(role_val, '"'); + if (!role_end) break; + + char role[32] = {}; + size_t rlen = role_end - role_val; + if (rlen >= sizeof(role)) rlen = sizeof(role) - 1; + memcpy(role, role_val, rlen); + + // Find content + const char *content_key = strstr(role_end, "\"content\""); + if (!content_key) { p = role_end + 1; continue; } + const char *cv = content_key + 9; + while (*cv == ' ' || *cv == ':' || *cv == '\t') cv++; + + if (*cv == '"') { + // Simple string content + cv++; + const char *ce = cv; + while (*ce && !(*ce == '"' && *(ce-1) != '\\')) ce++; + + w += sprintf(w, "<|im_start|>%s\n", role); + const char *r = cv; + while (r < ce) { + if (*r == '\\' && r + 1 < ce) { + r++; + switch (*r) { case 'n': *w++ = '\n'; break; case 't': *w++ = '\t'; break; + case '"': *w++ = '"'; break; case '\\': *w++ = '\\'; break; + default: *w++ = '\\'; *w++ = *r; break; } + r++; + } else *w++ = *r++; + } + w += sprintf(w, "<|im_end|>\n"); + p = ce + 1; + } else if (*cv == '[') { + // Array content — may contain text blocks and tool_result blocks + w += sprintf(w, "<|im_start|>%s\n", role); + // Find matching ] + int depth = 1; + const char *as = cv + 1; + while (*as && depth > 0) { + if (*as == '[') depth++; + else if (*as == ']') depth--; + if (depth > 0) as++; + } + // Scan for text and tool_result blocks within the array + const char *scan = cv + 1; + while (scan < as) { + const char *type_key = strstr(scan, "\"type\""); + if (!type_key || type_key >= as) break; + const char *tv = strchr(type_key + 6, '"'); + if (!tv || tv >= as) break; + tv++; + if (strncmp(tv, "text\"", 5) == 0) { + // Find "text" field + const char *tk = strstr(tv, "\"text\""); + if (tk && tk < as) { + const char *tval = strchr(tk + 6, '"'); + if (tval) { tval++; + const char *te = tval; + while (*te && !(*te == '"' && *(te-1) != '\\')) te++; + const char *r = tval; + while (r < te) { + if (*r == '\\' && r + 1 < te) { r++; + switch (*r) { case 'n': *w++ = '\n'; break; case '"': *w++ = '"'; break; + case '\\': *w++ = '\\'; break; default: *w++ = *r; break; } + r++; + } else *w++ = *r++; + } + scan = te + 1; + continue; + } + } + } else if (strncmp(tv, "tool_use\"", 9) == 0) { + // Tool use block from assistant — add as + w += sprintf(w, "\n"); + const char *nk = strstr(tv, "\"name\""); + if (nk && nk < as) { + const char *nv = strchr(nk + 6, '"'); if (nv) { nv++; + const char *ne = strchr(nv, '"'); + if (ne) { w += sprintf(w, "{\"name\": \""); memcpy(w, nv, ne-nv); w += (ne-nv); w += sprintf(w, "\", \"arguments\": "); } + } + } + const char *ik = strstr(tv, "\"input\""); + if (ik && ik < as) { + const char *iv = strchr(ik + 7, '{'); + if (iv) { int d = 1; const char *ie = iv + 1; + while (*ie && d > 0) { if (*ie == '{') d++; else if (*ie == '}') d--; ie++; } + memcpy(w, iv, ie - iv); w += (ie - iv); + } + } + w += sprintf(w, "}\n"); + scan = tv + 9; + continue; + } else if (strncmp(tv, "tool_result\"", 12) == 0) { + // Tool result — extract content + w += sprintf(w, "[Tool result]: "); + const char *ck = strstr(tv, "\"content\""); + if (ck && ck < as) { + const char *cval = strchr(ck + 9, '"'); + if (cval) { cval++; + const char *ce = cval; + while (*ce && !(*ce == '"' && *(ce-1) != '\\')) ce++; + memcpy(w, cval, ce - cval); w += (ce - cval); + scan = ce + 1; + continue; + } + } + } + scan = tv + 5; + } + w += sprintf(w, "<|im_end|>\n"); + p = as + 1; + } else { + p = cv + 1; + } + } + + w += sprintf(w, "<|im_start|>assistant\n"); + return prompt; +} + +static void serve_loop(Model *model, char **vocab_strings, bpe_tokenizer *tokenizer, + int port, int K) { + signal(SIGPIPE, SIG_IGN); + + int server_fd = socket(AF_INET, SOCK_STREAM, 0); + if (server_fd < 0) { perror("socket"); return; } + int opt = 1; + setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + + struct sockaddr_in addr = {}; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(port); + + if (bind(server_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + perror("bind"); close(server_fd); return; + } + if (listen(server_fd, 8) < 0) { + perror("listen"); close(server_fd); return; + } + + printf("[serve] Listening on http://0.0.0.0:%d\n", port); + printf("[serve] Endpoints:\n"); + printf("[serve] POST /v1/chat/completions (OpenAI format)\n"); + printf("[serve] POST /v1/messages (Anthropic format)\n"); + printf("[serve] GET /v1/models\n"); + printf("[serve] GET /health\n"); + fflush(stdout); + + size_t delta_sz = LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM * LINEAR_KEY_DIM * sizeof(float); + size_t conv_sz = (CONV_KERNEL_SIZE - 1) * LINEAR_CONV_DIM * sizeof(float); + int kv_dim = NUM_KV_HEADS * HEAD_DIM; + + // ---- System prompt pre-cache ---- + // Tokenize and prefill the system prompt once at startup. + // Snapshot the resulting state so each request restores from here + // instead of starting from scratch (~8s saved per request). + const char *sys_text = "You are a helpful assistant."; + { + const char *home = getenv("HOME"); + if (home) { + char path[1024]; + snprintf(path, sizeof(path), "%s/.flash-moe/system.md", home); + FILE *f = fopen(path, "r"); + if (f) { + fseek(f, 0, SEEK_END); long sz = ftell(f); fseek(f, 0, SEEK_SET); + char *buf = (char *)malloc(sz + 1); + buf[fread(buf, 1, sz, f)] = 0; + fclose(f); + sys_text = buf; + fprintf(stderr, "[serve] Custom system prompt from %s (%ld bytes)\n", path, sz); + } + } + } + + // Build system prompt in ChatML format + char *sys_chatml = (char *)malloc(strlen(sys_text) + 256); + sprintf(sys_chatml, "<|im_start|>system\n%s<|im_end|>\n", sys_text); + + uint32_t sys_ids[4096]; + int sys_ntokens = bpe_encode(tokenizer, sys_chatml, sys_ids, 4096); + free(sys_chatml); + + fprintf(stderr, "[serve] System prompt: %d tokens, prefilling...\n", sys_ntokens); + double t_prefill = now_ms(); + for (int i = 0; i < sys_ntokens; i++) + forward(model, sys_ids[i], i, K); + int sys_pos = sys_ntokens; + fprintf(stderr, "[serve] System prompt cached in %.0f ms\n", now_ms() - t_prefill); + + // Snapshot KV caches + delta-net + conv states after system prompt + void *snap_kv_k[NUM_LAYERS] = {}, *snap_kv_v[NUM_LAYERS] = {}; + int snap_kv_len[NUM_LAYERS] = {}; + void *snap_delta[NUM_LAYERS] = {}, *snap_conv[NUM_LAYERS] = {}; + + for (int i = 0; i < NUM_LAYERS; i++) { + if (model->layers[i].is_full) { + size_t sz = sys_pos * kv_dim * sizeof(float); + snap_kv_k[i] = malloc(sz); snap_kv_v[i] = malloc(sz); + CHECK_CUDA(cudaMemcpy(snap_kv_k[i], model->kv_k[i], sz, cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(snap_kv_v[i], model->kv_v[i], sz, cudaMemcpyDeviceToHost)); + snap_kv_len[i] = model->kv_len[i]; + } else { + snap_delta[i] = malloc(delta_sz); snap_conv[i] = malloc(conv_sz); + CHECK_CUDA(cudaMemcpy(snap_delta[i], model->delta_state[i], delta_sz, cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(snap_conv[i], model->conv_state[i], conv_sz, cudaMemcpyDeviceToHost)); + } + } + fprintf(stderr, "[serve] State snapshot saved (%d layers)\n", NUM_LAYERS); + + uint64_t req_counter = 0; + + // Session tracking — keep KV cache across requests in the same session + char active_session[128] = {}; + int session_pos = 0; // RoPE position after last generation + + fprintf(stderr, "[serve] Ready\n"); + + static const char *SSE_HEADERS = + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/event-stream\r\n" + "Cache-Control: no-cache\r\n" + "Connection: close\r\n" + "Access-Control-Allow-Origin: *\r\n" + "\r\n"; + static const char *CORS_RESPONSE = + "HTTP/1.1 204 No Content\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Methods: GET, POST, OPTIONS\r\n" + "Access-Control-Allow-Headers: Content-Type, Authorization\r\n" + "Access-Control-Max-Age: 86400\r\n" + "\r\n"; + + for (;;) { + struct sockaddr_in client_addr; + socklen_t client_len = sizeof(client_addr); + int client_fd = accept(server_fd, (struct sockaddr *)&client_addr, &client_len); + if (client_fd < 0) { perror("accept"); continue; } + + char *reqbuf = (char *)malloc(1024 * 1024); + int reqlen = read_http_request(client_fd, reqbuf, 1024 * 1024); + if (reqlen <= 0) { free(reqbuf); close(client_fd); continue; } + + char method[16] = {}, path_buf[256] = {}; + sscanf(reqbuf, "%15s %255s", method, path_buf); + + // CORS preflight + if (strcmp(method, "OPTIONS") == 0) { + http_write_str(client_fd, CORS_RESPONSE); + free(reqbuf); close(client_fd); continue; + } + + // GET /health + if (strcmp(method, "GET") == 0 && strcmp(path_buf, "/health") == 0) { + http_write_str(client_fd, + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + "Access-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n" + "{\"status\":\"ok\",\"model\":\"qwen3.5-397b-a17b-cuda\"}\n"); + free(reqbuf); close(client_fd); continue; + } + + // GET /v1/models + if (strcmp(method, "GET") == 0 && strcmp(path_buf, "/v1/models") == 0) { + http_write_str(client_fd, + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + "Access-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n" + "{\"object\":\"list\",\"data\":[{\"id\":\"qwen3.5-397b-a17b\"," + "\"object\":\"model\",\"owned_by\":\"local\"}]}\n"); + free(reqbuf); close(client_fd); continue; + } + + // POST /v1/chat/completions + if (strcmp(method, "POST") == 0 && strcmp(path_buf, "/v1/chat/completions") == 0) { + char *body = strstr(reqbuf, "\r\n\r\n"); + if (!body) { + http_write_str(client_fd, "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n{\"error\":\"no body\"}\n"); + free(reqbuf); close(client_fd); continue; + } + body += 4; + + int max_gen = extract_max_tokens(body, 4096); + if (max_gen > 32768) max_gen = 32768; + + // Extract tools and session_id + char *tools_json = extract_tools_json(body); + char req_session[128] = {}; + extract_string_field(body, "session_id", req_session, sizeof(req_session)); + + char request_id[64]; + snprintf(request_id, sizeof(request_id), "chatcmpl-%llu", (unsigned long long)++req_counter); + + // Determine if this is a continuation of the active session + int is_continuation = (req_session[0] && active_session[0] && + strcmp(req_session, active_session) == 0); + + fprintf(stderr, "[serve] %s max_tokens=%d tools=%s session=%s%s\n", + request_id, max_gen, tools_json ? "yes" : "no", + req_session[0] ? req_session : "(none)", + is_continuation ? " [CONTINUE]" : " [NEW]"); + + int pos; + if (is_continuation) { + // Continue from existing state — no restore needed + pos = session_pos; + } else { + // New session — restore from system prompt snapshot + for (int i = 0; i < NUM_LAYERS; i++) { + if (model->layers[i].is_full) { + size_t sz = snap_kv_len[i] * kv_dim * sizeof(float); + if (sz > 0) { + CHECK_CUDA(cudaMemcpy(model->kv_k[i], snap_kv_k[i], sz, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(model->kv_v[i], snap_kv_v[i], sz, cudaMemcpyHostToDevice)); + } + model->kv_len[i] = snap_kv_len[i]; + } else { + CHECK_CUDA(cudaMemcpy(model->delta_state[i], snap_delta[i], delta_sz, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(model->conv_state[i], snap_conv[i], conv_sz, cudaMemcpyHostToDevice)); + } + } + pos = sys_pos; + if (req_session[0]) { + strncpy(active_session, req_session, sizeof(active_session) - 1); + } else { + active_session[0] = '\0'; + } + } + + // Build per-request prompt (user turn + tools only, system prompt already cached) + char *prompt = build_chat_prompt(body, tools_json); + if (tools_json) free(tools_json); + + uint32_t turn_ids[16384]; + int turn_ntokens = bpe_encode(tokenizer, prompt, turn_ids, 16384); + free(prompt); + + if (turn_ntokens <= 0) { + http_write_str(client_fd, "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n{\"error\":\"tokenization failed\"}\n"); + free(reqbuf); close(client_fd); continue; + } + + fprintf(stderr, "[serve] %s prompt=%d tokens, pos=%d\n", request_id, turn_ntokens, pos); + + // Send SSE headers + http_write_str(client_fd, SSE_HEADERS); + + // Prefill user turn tokens — last forward() return = first generated token + int next_token = 0; + for (int i = 0; i < turn_ntokens; i++) { + next_token = forward(model, turn_ids[i], pos++, K); + } + + double t_gen = now_ms(); + int gen_count = 0; + int client_ok = 1; + + // Buffer for detecting tool calls in output + char gen_buffer[65536] = {}; + int gen_buf_len = 0; + int in_tool_call = 0; + int tool_call_count = 0; + + // Special token IDs to suppress from output + // These are Qwen3.5 special token IDs that should not appear as content + int suppress_tokens[] = { + EOS_TOKEN_1, // <|endoftext|> + IM_START, // <|im_start|> + EOS_TOKEN_2, // <|im_end|> + }; + int n_suppress = sizeof(suppress_tokens) / sizeof(suppress_tokens[0]); + + for (int gen = 0; gen < max_gen && client_ok; gen++) { + // Stop on EOS tokens + if (next_token == EOS_TOKEN_1 || next_token == EOS_TOKEN_2) break; + + // Check if this is a suppressed special token + int is_special = 0; + for (int s = 0; s < n_suppress; s++) { + if (next_token == suppress_tokens[s]) { is_special = 1; break; } + } + if (is_special) { + // Special token — don't output, just continue generating + gen_count++; + next_token = forward(model, next_token, pos++, K); + continue; + } + + // Decode token + char decoded[1024] = {}; + if (vocab_strings[next_token]) + bpe_decode_token(vocab_strings[next_token], decoded, sizeof(decoded)); + + // Accumulate in buffer for tool call detection + if (gen_buf_len + (int)strlen(decoded) < (int)sizeof(gen_buffer) - 1) { + strcpy(gen_buffer + gen_buf_len, decoded); + gen_buf_len += strlen(decoded); + } + + // Check for tool call start + if (!in_tool_call && strstr(gen_buffer, "")) { + in_tool_call = 1; + // Flush any content before that was already sent + // (the "" text itself was accumulated but not sent) + } + + // Stop if decoded text contains EOS markers + if (strstr(decoded, "<|im_end|>") || strstr(decoded, "<|endoftext|>")) break; + + // Filter special text patterns from content + int is_filtered = ( + strstr(decoded, "<|im_start|>") || + strstr(decoded, "<|im_end|>") || + strstr(decoded, "<|endoftext|>") || + strcmp(decoded, "") == 0 || + strcmp(decoded, "") == 0 || + strcmp(decoded, "user") == 0 || // stray role tokens + strcmp(decoded, "assistant") == 0 || // stray role tokens + strcmp(decoded, "system") == 0 // stray role tokens + ); + + // If not in a tool call and not filtered, stream content + if (!in_tool_call && decoded[0] && !is_filtered) { + if (sse_send_delta(client_fd, request_id, decoded) < 0) { + client_ok = 0; break; + } + } + + // Check for tool call end + if (in_tool_call && strstr(gen_buffer, "")) { + // Parse and send the tool call + char func_name[256] = {}, func_args[4096] = {}; + if (parse_tool_call(gen_buffer, func_name, sizeof(func_name), + func_args, sizeof(func_args))) { + char call_id[64]; + snprintf(call_id, sizeof(call_id), "call_%d", ++tool_call_count); + sse_send_tool_call(client_fd, request_id, call_id, func_name, func_args); + fprintf(stderr, "[serve] %s tool_call: %s(%s)\n", + request_id, func_name, func_args); + } + // Stop generation after tool call — the client needs to + // execute the tool and send results back in a new request + break; + } + + gen_count++; + next_token = forward(model, next_token, pos++, K); + } + + if (client_ok) { + if (tool_call_count > 0) { + sse_send_tool_done(client_fd, request_id); + } else { + sse_send_done(client_fd, request_id); + } + } + + // Save session position for continuation + session_pos = pos; + + double gen_ms = now_ms() - t_gen; + fprintf(stderr, "[serve] %s generated %d tokens in %.0fms (%.2f tok/s)%s pos=%d\n", + request_id, gen_count, gen_ms, + gen_count > 0 ? gen_count / (gen_ms / 1000.0) : 0.0, + tool_call_count > 0 ? " [tool_calls]" : "", pos); + + free(reqbuf); close(client_fd); continue; + } + + // POST /v1/messages (Anthropic Messages API) + if (strcmp(method, "POST") == 0 && strcmp(path_buf, "/v1/messages") == 0) { + char *body = strstr(reqbuf, "\r\n\r\n"); + if (!body) { + http_write_str(client_fd, "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n{\"error\":\"no body\"}\n"); + free(reqbuf); close(client_fd); continue; + } + body += 4; + + int max_gen = extract_max_tokens(body, 4096); + if (max_gen > 32768) max_gen = 32768; + + char req_session[128] = {}; + extract_string_field(body, "session_id", req_session, sizeof(req_session)); + // Also check x-session-id header for Anthropic clients + if (!req_session[0]) { + const char *hdr = strcasestr(reqbuf, "x-session-id:"); + if (hdr) { hdr += 13; while (*hdr == ' ') hdr++; + int i = 0; while (*hdr && *hdr != '\r' && *hdr != '\n' && i < 127) req_session[i++] = *hdr++; + req_session[i] = '\0'; + } + } + + char request_id[64]; + snprintf(request_id, sizeof(request_id), "msg_%llu", (unsigned long long)++req_counter); + + int is_continuation = (req_session[0] && active_session[0] && + strcmp(req_session, active_session) == 0); + + fprintf(stderr, "[serve] %s (anthropic) max_tokens=%d session=%s%s\n", + request_id, max_gen, + req_session[0] ? req_session : "(none)", + is_continuation ? " [CONTINUE]" : " [NEW]"); + + int pos; + if (is_continuation) { + pos = session_pos; + } else { + // Restore from system prompt snapshot + for (int i = 0; i < NUM_LAYERS; i++) { + if (model->layers[i].is_full) { + size_t sz = snap_kv_len[i] * kv_dim * sizeof(float); + if (sz > 0) { + CHECK_CUDA(cudaMemcpy(model->kv_k[i], snap_kv_k[i], sz, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(model->kv_v[i], snap_kv_v[i], sz, cudaMemcpyHostToDevice)); + } + model->kv_len[i] = snap_kv_len[i]; + } else { + CHECK_CUDA(cudaMemcpy(model->delta_state[i], snap_delta[i], delta_sz, cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(model->conv_state[i], snap_conv[i], conv_sz, cudaMemcpyHostToDevice)); + } + } + pos = sys_pos; + if (req_session[0]) + strncpy(active_session, req_session, sizeof(active_session) - 1); + else + active_session[0] = '\0'; + } + + // Build prompt from Anthropic format (user turn only, system prompt cached) + char *prompt = build_anthropic_prompt(body, "You are a helpful assistant."); + uint32_t turn_ids[16384]; + int turn_ntokens = bpe_encode(tokenizer, prompt, turn_ids, 16384); + free(prompt); + + if (turn_ntokens <= 0) { + http_write_str(client_fd, "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n{\"type\":\"error\",\"error\":{\"type\":\"invalid_request_error\",\"message\":\"tokenization failed\"}}\n"); + free(reqbuf); close(client_fd); continue; + } + + fprintf(stderr, "[serve] %s prompt=%d tokens, pos=%d\n", request_id, turn_ntokens, pos); + + // Send SSE headers + static const char *ANTH_SSE_HEADERS = + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/event-stream\r\n" + "Cache-Control: no-cache\r\n" + "Connection: close\r\n" + "Access-Control-Allow-Origin: *\r\n" + "\r\n"; + http_write_str(client_fd, ANTH_SSE_HEADERS); + + // message_start + anth_send_message_start(client_fd, request_id, "qwen3.5-397b-a17b"); + + // Prefill + int next_token = 0; + for (int i = 0; i < turn_ntokens; i++) + next_token = forward(model, turn_ids[i], pos++, K); + + // Start text content block + int block_index = 0; + anth_send_content_block_start(client_fd, block_index, "text"); + + double t_gen = now_ms(); + int gen_count = 0; + int client_ok = 1; + char gen_buffer[65536] = {}; + int gen_buf_len = 0; + int in_tool_call = 0; + int tool_call_count = 0; + const char *stop_reason = "end_turn"; + + // Special tokens to suppress + int suppress_tokens[] = { 151643, 151644, 151645, 151646, 151647, 151648, + 151649, 151650, 151651, 151652, 151653, 151654 }; + int n_suppress = sizeof(suppress_tokens) / sizeof(suppress_tokens[0]); + + for (int gen = 0; gen < max_gen && client_ok; gen++) { + if (next_token == EOS_TOKEN_1 || next_token == EOS_TOKEN_2) break; + + int is_special = 0; + for (int s = 0; s < n_suppress; s++) + if (next_token == suppress_tokens[s]) { is_special = 1; break; } + if (is_special) { + gen_count++; + next_token = forward(model, next_token, pos++, K); + continue; + } + + char decoded[1024] = {}; + if (vocab_strings[next_token]) + bpe_decode_token(vocab_strings[next_token], decoded, sizeof(decoded)); + + if (strstr(decoded, "<|im_end|>") || strstr(decoded, "<|endoftext|>")) break; + + // Accumulate for tool call detection + if (gen_buf_len + (int)strlen(decoded) < (int)sizeof(gen_buffer) - 1) { + strcpy(gen_buffer + gen_buf_len, decoded); + gen_buf_len += strlen(decoded); + } + + if (!in_tool_call && strstr(gen_buffer, "")) + in_tool_call = 1; + + // Stream text content + if (!in_tool_call && decoded[0]) { + int is_filtered = ( + strstr(decoded, "<|im_start|>") || strstr(decoded, "<|im_end|>") || + strstr(decoded, "<|endoftext|>") || + strcmp(decoded, "") == 0 || strcmp(decoded, "") == 0 || + strcmp(decoded, "user") == 0 || strcmp(decoded, "assistant") == 0 || + strcmp(decoded, "system") == 0 + ); + if (!is_filtered) { + if (anth_send_text_delta(client_fd, block_index, decoded) < 0) { + client_ok = 0; break; + } + } + } + + // Tool call detected + if (in_tool_call && strstr(gen_buffer, "")) { + char func_name[256] = {}, func_args[4096] = {}; + if (parse_tool_call(gen_buffer, func_name, sizeof(func_name), + func_args, sizeof(func_args))) { + // Close text block, open tool_use block + anth_send_content_block_stop(client_fd, block_index); + block_index++; + + char tool_id[64]; + snprintf(tool_id, sizeof(tool_id), "toolu_%d", ++tool_call_count); + anth_send_content_block_start_tool(client_fd, block_index, tool_id, func_name); + anth_send_tool_input_delta(client_fd, block_index, func_args); + anth_send_content_block_stop(client_fd, block_index); + + fprintf(stderr, "[serve] %s tool_use: %s(%s)\n", + request_id, func_name, func_args); + stop_reason = "tool_use"; + } + break; + } + + gen_count++; + next_token = forward(model, next_token, pos++, K); + } + + if (client_ok) { + if (tool_call_count == 0) + anth_send_content_block_stop(client_fd, block_index); + anth_send_message_delta(client_fd, stop_reason, gen_count); + anth_send_message_stop(client_fd); + } + + // Save session position + session_pos = pos; + + double gen_ms = now_ms() - t_gen; + fprintf(stderr, "[serve] %s generated %d tokens in %.0fms (%.2f tok/s)%s pos=%d\n", + request_id, gen_count, gen_ms, + gen_count > 0 ? gen_count / (gen_ms / 1000.0) : 0.0, + tool_call_count > 0 ? " [tool_use]" : "", pos); + + free(reqbuf); close(client_fd); continue; + } + + // Unknown endpoint + http_write_str(client_fd, "HTTP/1.1 404 Not Found\r\nConnection: close\r\n\r\n{\"error\":\"not found\"}\n"); + free(reqbuf); close(client_fd); + } +} + +// ============================================================================ +// Main +// ============================================================================ + +int main(int argc, char **argv) { + setbuf(stdout, NULL); // unbuffered stdout for serve mode + { const char *elog = getenv("EXPERT_LOG"); + if (elog) { g_expert_log = fopen(elog, "w"); } } + const char *weights_path = "model_weights.bin"; + const char *manifest_path = "model_weights.json"; + const char *vocab_path = "vocab.bin"; + const char *tokenizer_path = "tokenizer.bin"; + const char *expert_dir = "packed_experts"; + const char *prompt_text = NULL; + int serve_port = 0; + int max_tokens = 20; + int K = 4; + int timing = 0; + + static struct option long_options[] = { + {"weights", required_argument, 0, 'w'}, + {"manifest", required_argument, 0, 'j'}, + {"vocab", required_argument, 0, 'v'}, + {"tokenizer", required_argument, 0, 'T'}, + {"experts", required_argument, 0, 'e'}, + {"prompt", required_argument, 0, 'P'}, + {"tokens", required_argument, 0, 't'}, + {"k", required_argument, 0, 'k'}, + {"serve", required_argument, 0, 'S'}, + {"timing", no_argument, 0, 'M'}, + {"help", no_argument, 0, 'h'}, + {0, 0, 0, 0} + }; + + int c; + while ((c = getopt_long(argc, argv, "w:j:v:T:e:P:t:k:S:Mh", long_options, NULL)) != -1) { + switch (c) { + case 'w': weights_path = optarg; break; + case 'j': manifest_path = optarg; break; + case 'v': vocab_path = optarg; break; + case 'T': tokenizer_path = optarg; break; + case 'e': expert_dir = optarg; break; + case 'P': prompt_text = optarg; break; + case 't': max_tokens = atoi(optarg); break; + case 'k': K = atoi(optarg); break; + case 'S': serve_port = atoi(optarg); break; + case 'M': timing = 1; g_timing_enabled = 1; break; + case 'h': + printf("Usage: %s --prompt TEXT [options]\n", argv[0]); + printf(" --weights PATH model_weights.bin\n"); + printf(" --manifest PATH model_weights.json\n"); + printf(" --vocab PATH vocab.bin\n"); + printf(" --tokenizer PATH tokenizer.bin\n"); + printf(" --experts PATH packed_experts directory\n"); + printf(" --prompt TEXT input prompt\n"); + printf(" --tokens N max tokens (default: 20)\n"); + printf(" --k N active experts (default: 4)\n"); + printf(" --serve PORT HTTP server (OpenAI-compatible API)\n"); + printf(" --timing per-layer timing\n"); + return 0; + default: return 1; + } + } + + if (!prompt_text && serve_port == 0) { + fprintf(stderr, "Error: --prompt or --serve required\n"); + return 1; + } + // Initialize CUDA + cudaDeviceProp prop; + CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + printf("GPU: %s, VRAM: %zu MB, SM: %d\n", + prop.name, prop.totalGlobalMem / (1024*1024), prop.multiProcessorCount); + + // Load weights + WeightFile *wf = open_weights(weights_path, manifest_path); + if (!wf) return 1; + + // Detect GGUF format from manifest + if ((g_manifest_json && strstr(g_manifest_json, "\"quant_format\": \"gguf\"")) || + (g_manifest_json && strstr(g_manifest_json, "\"quant_format\":\"gguf\""))) { + g_quant_format = 1; + printf("[init] GGUF weight format detected\n"); + g_dump_layer0 = 1; + } + + // Read expert layout for GGUF + if (g_quant_format == 1) { + char layout_path[512]; + snprintf(layout_path, sizeof(layout_path), "%s/layout.json", expert_dir); + FILE *lf = fopen(layout_path, "r"); + if (lf) { + fseek(lf, 0, SEEK_END); + long lsz = ftell(lf); + fseek(lf, 0, SEEK_SET); + char *ljson = (char *)malloc(lsz + 1); + ljson[fread(ljson, 1, lsz, lf)] = '\0'; + fclose(lf); + + // Parse expert_size + const char *es = strstr(ljson, "\"expert_size\""); + if (es) { const char *c = strchr(es, ':'); if (c) g_expert_size = strtoul(c + 1, NULL, 10); } + + // Parse components array for offsets and types + // Format: "components": [{"name": "gate_exps", "offset": 0, "size": N, "gguf_type": 12}, ...] + const char *comp = strstr(ljson, "\"components\""); + if (comp) { + // gate + const char *gate_c = strstr(comp, "\"gate_exps\""); + if (!gate_c) gate_c = strstr(comp, "gate"); + if (gate_c) { + const char *off_k = strstr(gate_c, "\"offset\""); + if (off_k) { const char *c2 = strchr(off_k, ':'); if (c2) g_gguf_gate_offset = strtoul(c2 + 1, NULL, 10); } + const char *sz_k = strstr(gate_c, "\"size\""); + if (sz_k) { const char *c2 = strchr(sz_k, ':'); if (c2) g_gguf_gate_size = strtoul(c2 + 1, NULL, 10); } + const char *gt_k = strstr(gate_c, "\"gguf_type\""); + if (gt_k) { const char *c2 = strchr(gt_k, ':'); if (c2) g_gguf_gate_type = atoi(c2 + 1); } + } + // up (search after gate) + const char *up_c = strstr(gate_c ? gate_c + 1 : comp, "\"up_exps\""); + if (!up_c) up_c = strstr(gate_c ? gate_c + 1 : comp, "up"); + if (up_c) { + const char *off_k = strstr(up_c, "\"offset\""); + if (off_k) { const char *c2 = strchr(off_k, ':'); if (c2) g_gguf_up_offset = strtoul(c2 + 1, NULL, 10); } + const char *sz_k = strstr(up_c, "\"size\""); + if (sz_k) { const char *c2 = strchr(sz_k, ':'); if (c2) g_gguf_up_size = strtoul(c2 + 1, NULL, 10); } + const char *gt_k = strstr(up_c, "\"gguf_type\""); + if (gt_k) { const char *c2 = strchr(gt_k, ':'); if (c2) g_gguf_up_type = atoi(c2 + 1); } + } + // down (search after up) + const char *down_c = strstr(up_c ? up_c + 1 : comp, "\"down_exps\""); + if (!down_c) down_c = strstr(up_c ? up_c + 1 : comp, "down"); + if (down_c) { + const char *off_k = strstr(down_c, "\"offset\""); + if (off_k) { const char *c2 = strchr(off_k, ':'); if (c2) g_gguf_down_offset = strtoul(c2 + 1, NULL, 10); } + const char *sz_k = strstr(down_c, "\"size\""); + if (sz_k) { const char *c2 = strchr(sz_k, ':'); if (c2) g_gguf_down_size = strtoul(c2 + 1, NULL, 10); } + const char *gt_k = strstr(down_c, "\"gguf_type\""); + if (gt_k) { const char *c2 = strchr(gt_k, ':'); if (c2) g_gguf_down_type = atoi(c2 + 1); } + } + } + + printf("[init] GGUF expert layout: size=%zu, gate@%zu(%d) up@%zu(%d) down@%zu(%d)\n", + g_expert_size, + g_gguf_gate_offset, g_gguf_gate_type, + g_gguf_up_offset, g_gguf_up_type, + g_gguf_down_offset, g_gguf_down_type); + + // Parse per-layer down types from "layer_info" array + // Format: "layer_info": [{"down_type": 14, ...}, {"down_type": 12, ...}, ...] + const char *li = strstr(ljson, "\"layer_info\""); + if (li) { + int mixed = 0; + for (int i = 0; i < NUM_LAYERS && i < 256; i++) { + g_gguf_down_type_per_layer[i] = g_gguf_down_type; // default + // Find the i-th "down_type" entry + li = strstr(li + 1, "\"down_type\""); + if (li) { + const char *c2 = strchr(li, ':'); + if (c2) { + int t = atoi(c2 + 1); + g_gguf_down_type_per_layer[i] = t; + if (t != g_gguf_down_type) mixed = 1; + } + } + } + if (mixed) { + int q4k = 0, q6k = 0; + for (int i = 0; i < NUM_LAYERS; i++) { + if (g_gguf_down_type_per_layer[i] == 12) q4k++; + else q6k++; + } + printf("[init] Mixed expert quant: %d layers Q4_K down, %d layers Q6_K down\n", + q4k, q6k); + } + } else { + // No layer_info: use uniform type + for (int i = 0; i < NUM_LAYERS && i < 256; i++) + g_gguf_down_type_per_layer[i] = g_gguf_down_type; + } + free(ljson); + } else { + fprintf(stderr, "WARNING: GGUF format but no %s found\n", layout_path); + } + } + + // Load vocab + // (vocab.bin format: u32 num_entries, u32 max_id, then per entry: u16 len + bytes) + FILE *vf = fopen(vocab_path, "rb"); + if (!vf) { perror(vocab_path); return 1; } + uint32_t vocab_n, vocab_max; + fread(&vocab_n, 4, 1, vf); + fread(&vocab_max, 4, 1, vf); + char **vocab_strings = (char **)calloc(vocab_n, sizeof(char *)); + for (uint32_t i = 0; i < vocab_n; i++) { + uint16_t len; + fread(&len, 2, 1, vf); + if (len > 0) { + vocab_strings[i] = (char *)malloc(len + 1); + fread(vocab_strings[i], 1, len, vf); + vocab_strings[i][len] = '\0'; + } + } + fclose(vf); + printf("[vocab] %u tokens\n", vocab_n); + + // Load tokenizer + bpe_tokenizer tokenizer; + if (bpe_load(&tokenizer, tokenizer_path) < 0) { + fprintf(stderr, "Cannot load tokenizer %s\n", tokenizer_path); + return 1; + } + printf("[tokenizer] Loaded (%d vocab, %d merges)\n", tokenizer.vocab_size, tokenizer.num_merges); + + // Initialize model + Model *model = model_init(wf, expert_dir, K); + if (!model) return 1; + + // Serve mode + if (serve_port > 0) { + serve_loop(model, vocab_strings, &tokenizer, serve_port, K); + return 0; + } + + if (!prompt_text) { fprintf(stderr, "Error: --prompt required\n"); return 1; } + + // Tokenize prompt + uint32_t token_ids_buf[4096]; + int num_tokens = bpe_encode(&tokenizer, prompt_text, token_ids_buf, 4096); + if (num_tokens < 0) { fprintf(stderr, "Tokenization failed\n"); return 1; } + printf("[prompt] \"%s\" → %d tokens:", prompt_text, num_tokens); + for (int i = 0; i < num_tokens && i < 20; i++) printf(" %u", token_ids_buf[i]); + if (num_tokens > 20) printf(" ..."); + printf("\n"); + + printf("\n[generating] %d tokens, K=%d experts\n", max_tokens, K); + double gen_start = now_ms(); + + // Process prompt tokens (prefill) + for (int i = 0; i < num_tokens; i++) { + double t0 = now_ms(); + int next = forward(model, token_ids_buf[i], i, K); + double elapsed = now_ms() - t0; + if (timing) printf("[prefill %d/%d] token=%d, %.1f ms\n", i+1, num_tokens, token_ids_buf[i], elapsed); + if (i == num_tokens - 1) { + // Print first generated token + if (vocab_strings[next]) print_token(vocab_strings[next]); + fflush(stdout); + // Continue generating + int prev = next; + for (int t = 0; t < max_tokens - 1; t++) { + double tt0 = now_ms(); + next = forward(model, prev, num_tokens + t, K); + double telapsed = now_ms() - tt0; + if (vocab_strings[next]) print_token(vocab_strings[next]); + fflush(stdout); + if (timing) printf(" [%.1fms]", telapsed); + prev = next; + // Stop on EOS + if (next == 151643 || next == 151645) break; // <|endoftext|>, <|im_end|> + } + } + } + + double gen_elapsed = now_ms() - gen_start; + printf("\n\n[done] %.1f ms total, %.1f ms/token, %.2f tok/s\n", + gen_elapsed, gen_elapsed / max_tokens, max_tokens / (gen_elapsed / 1000.0)); + + bpe_free(&tokenizer); + return 0; +} diff --git a/cuda_infer/kernels.cuh b/cuda_infer/kernels.cuh new file mode 100644 index 0000000..286c86f --- /dev/null +++ b/cuda_infer/kernels.cuh @@ -0,0 +1,1175 @@ +/* + * kernels.cuh — CUDA compute kernels for Flash-MoE inference + * + * Port of shaders.metal for NVIDIA GPUs (RTX 4090 target). + * All kernels operate on the same quantization format: + * - 4-bit affine quantization, group_size=64 + * - Weights: uint32 holding 8 x 4-bit values + * - Per-group scale and bias in bfloat16 + * + * Kernel list: + * 1. dequant_matvec_4bit_fma — FMA-optimized 4-bit dequant matvec + * 2. swiglu_fused — SiLU(gate) * up + * 3. rms_norm — Fused sum-of-squares + normalize (single kernel) + * 4. rms_norm_bf16 — RMS norm with bf16 weights + * 5. residual_add — a + b + * 6. attn_scores — Q @ K^T (batched over heads) + * 7. attn_softmax — Softmax over seq_len per head + * 8. attn_values — scores @ V (batched over heads) + * 9. sigmoid_gate — x *= sigmoid(gate) + * 10. gated_delta_net_step — GatedDeltaNet recurrence + * 11. conv1d_step — Depthwise conv1d (kernel=4) + SiLU + * 12. rms_norm_qk — Per-head RMS norm for Q and K + * 13. compute_decay_beta — GatedDeltaNet decay and beta gate + * 14. gated_rms_norm — RMS norm with SiLU gate and bf16 weights + * 15. moe_combine_residual — Weighted expert sum + shared expert + residual + */ + +#pragma once +#include +#include +#include + +#define GROUP_SIZE 64 + +// ============================================================================ +// BFloat16 helper +// ============================================================================ + +__device__ __forceinline__ float bf16_to_f32(uint16_t bf16) { + return __uint_as_float((uint32_t)bf16 << 16); +} + +// ============================================================================ +// Warp reduction helpers +// ============================================================================ + +__device__ __forceinline__ float warp_reduce_sum(float val) { + for (int offset = 16; offset > 0; offset >>= 1) + val += __shfl_down_sync(0xFFFFFFFF, val, offset); + return val; +} + +__device__ __forceinline__ float warp_reduce_max(float val) { + for (int offset = 16; offset > 0; offset >>= 1) + val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset)); + return val; +} + +// ============================================================================ +// 1. 4-bit FMA dequant matvec +// ============================================================================ +// blockDim = (32, ROWS_PER_BLOCK), gridDim = ceil(out_dim / ROWS_PER_BLOCK) +// Shared memory: in_dim * sizeof(float) + +#define ROWS_PER_BLOCK 8 + +__global__ void dequant_matvec_4bit_fma( + const uint32_t* __restrict__ W_packed, + const uint16_t* __restrict__ scales, + const uint16_t* __restrict__ biases, + const float* __restrict__ x, + float* __restrict__ out, + uint32_t out_dim, + uint32_t in_dim +) { + extern __shared__ float x_shared[]; + const uint32_t lane = threadIdx.x; + const uint32_t warp_id = threadIdx.y; + const uint32_t row = blockIdx.x * ROWS_PER_BLOCK + warp_id; + const uint32_t packed_cols = in_dim / 8; + const uint32_t num_groups = in_dim / GROUP_SIZE; + const uint32_t packed_per_group = GROUP_SIZE / 8; + + // Cooperative load + const uint32_t tid = warp_id * 32 + lane; + const uint32_t total = blockDim.x * blockDim.y; + for (uint32_t i = tid; i < in_dim; i += total) + x_shared[i] = x[i]; + __syncthreads(); + + if (row >= out_dim) return; + + const uint32_t* w_row = W_packed + row * packed_cols; + const uint16_t* s_row = scales + row * num_groups; + const uint16_t* b_row = biases + row * num_groups; + + float acc = 0.0f; + for (uint32_t col = lane; col < packed_cols; col += 32) { + uint32_t g = col / packed_per_group; + float scale = bf16_to_f32(s_row[g]); + float bias = bf16_to_f32(b_row[g]); + uint32_t packed = w_row[col]; + uint32_t xb = col * 8; + + float sx0 = scale * x_shared[xb+0]; float bx0 = bias * x_shared[xb+0]; + float sx1 = scale * x_shared[xb+1]; float bx1 = bias * x_shared[xb+1]; + float sx2 = scale * x_shared[xb+2]; float bx2 = bias * x_shared[xb+2]; + float sx3 = scale * x_shared[xb+3]; float bx3 = bias * x_shared[xb+3]; + float sx4 = scale * x_shared[xb+4]; float bx4 = bias * x_shared[xb+4]; + float sx5 = scale * x_shared[xb+5]; float bx5 = bias * x_shared[xb+5]; + float sx6 = scale * x_shared[xb+6]; float bx6 = bias * x_shared[xb+6]; + float sx7 = scale * x_shared[xb+7]; float bx7 = bias * x_shared[xb+7]; + + acc += __fmaf_rn((float)((packed >> 0) & 0xF), sx0, bx0); + acc += __fmaf_rn((float)((packed >> 4) & 0xF), sx1, bx1); + acc += __fmaf_rn((float)((packed >> 8) & 0xF), sx2, bx2); + acc += __fmaf_rn((float)((packed >> 12) & 0xF), sx3, bx3); + acc += __fmaf_rn((float)((packed >> 16) & 0xF), sx4, bx4); + acc += __fmaf_rn((float)((packed >> 20) & 0xF), sx5, bx5); + acc += __fmaf_rn((float)((packed >> 24) & 0xF), sx6, bx6); + acc += __fmaf_rn((float)((packed >> 28) & 0xF), sx7, bx7); + } + + acc = warp_reduce_sum(acc); + if (lane == 0) out[row] = acc; +} + +// ============================================================================ +// 1b. 4-bit FMA dequant matvec with uint4 vectorized loads +// ============================================================================ +// Loads 4 × uint32 (128 bits) per instruction instead of 1 × uint32. +// Each uint4 = 32 nibbles = 32 input elements processed per load. +// Reduces instruction count and improves memory throughput. + +__global__ void dequant_matvec_4bit_fma_vec4( + const uint32_t* __restrict__ W_packed, + const uint16_t* __restrict__ scales, + const uint16_t* __restrict__ biases, + const float* __restrict__ x, + float* __restrict__ out, + uint32_t out_dim, + uint32_t in_dim +) { + extern __shared__ float x_shared[]; + const uint32_t lane = threadIdx.x; + const uint32_t warp_id = threadIdx.y; + const uint32_t row = blockIdx.x * ROWS_PER_BLOCK + warp_id; + // All divisions by powers of 2 → shifts. No runtime division. + const uint32_t packed_cols = in_dim >> 3; // in_dim / 8 + const uint32_t num_groups = in_dim >> 6; // in_dim / 64 (GROUP_SIZE=64) + const uint32_t vec4_cols = packed_cols >> 2; // packed_cols / 4 + + // Cooperative load — all threads must participate before barrier + const uint32_t tid = warp_id * 32 + lane; + for (uint32_t i = tid; i < in_dim; i += (32 * ROWS_PER_BLOCK)) + x_shared[i] = x[i]; + __syncthreads(); + + if (row >= out_dim) return; + + const uint4* w_row_v = (const uint4*)(W_packed + row * packed_cols); + const uint16_t* s_row = scales + row * num_groups; + const uint16_t* b_row = biases + row * num_groups; + + float acc = 0.0f; + + for (uint32_t vi = lane; vi < vec4_cols; vi += 32) { + uint4 packed4 = __ldg(w_row_v + vi); // read-through L1 cache + uint32_t base_col = vi << 2; + uint32_t x_base = base_col << 3; // base_col * 8 + + #pragma unroll + for (uint32_t w = 0; w < 4; w++) { + uint32_t packed = ((const uint32_t*)&packed4)[w]; + // group index: (base_col + w) / 8 = (base_col + w) >> 3 + // (packed_per_group = GROUP_SIZE/8 = 8) + uint32_t g = (base_col + w) >> 3; + float scale = bf16_to_f32(__ldg(s_row + g)); + float bias = bf16_to_f32(__ldg(b_row + g)); + uint32_t xb = x_base + (w << 3); // w * 8 + + float sx0 = scale * x_shared[xb+0]; float bx0 = bias * x_shared[xb+0]; + float sx1 = scale * x_shared[xb+1]; float bx1 = bias * x_shared[xb+1]; + float sx2 = scale * x_shared[xb+2]; float bx2 = bias * x_shared[xb+2]; + float sx3 = scale * x_shared[xb+3]; float bx3 = bias * x_shared[xb+3]; + float sx4 = scale * x_shared[xb+4]; float bx4 = bias * x_shared[xb+4]; + float sx5 = scale * x_shared[xb+5]; float bx5 = bias * x_shared[xb+5]; + float sx6 = scale * x_shared[xb+6]; float bx6 = bias * x_shared[xb+6]; + float sx7 = scale * x_shared[xb+7]; float bx7 = bias * x_shared[xb+7]; + + acc += __fmaf_rn((float)((packed >> 0) & 0xF), sx0, bx0); + acc += __fmaf_rn((float)((packed >> 4) & 0xF), sx1, bx1); + acc += __fmaf_rn((float)((packed >> 8) & 0xF), sx2, bx2); + acc += __fmaf_rn((float)((packed >> 12) & 0xF), sx3, bx3); + acc += __fmaf_rn((float)((packed >> 16) & 0xF), sx4, bx4); + acc += __fmaf_rn((float)((packed >> 20) & 0xF), sx5, bx5); + acc += __fmaf_rn((float)((packed >> 24) & 0xF), sx6, bx6); + acc += __fmaf_rn((float)((packed >> 28) & 0xF), sx7, bx7); + } + } + + acc = warp_reduce_sum(acc); + if (lane == 0) out[row] = acc; +} + +// ============================================================================ +// 1c. GGML Q4_K dequant matvec — native GGUF format support +// ============================================================================ +// Q4_K super-block: 256 elements, 144 bytes (4.5 bits/weight) +// d (fp16): super-block scale for quantized scales +// dmin (fp16): super-block scale for quantized mins +// scales[12]: packed 6-bit per-sub-block scales and mins (8 sub-blocks of 32) +// qs[128]: 4-bit quantized values (256 values, 2 per byte) +// Dequant: value = d * sub_scale * nibble - dmin * sub_min + +#define QK_K 256 +#define Q4_K_BLOCK_SIZE 144 // 2+2+12+128 bytes + +// Unpack 6-bit scale and min from the packed scales array +__device__ __forceinline__ void get_scale_min_k4(int j, const uint8_t *q, + uint8_t *sc, uint8_t *mn) { + if (j < 4) { + *sc = q[j] & 63; + *mn = q[j + 4] & 63; + } else { + *sc = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *mn = (q[j + 4] >> 4) | ((q[j] >> 6) << 4); + } +} + +// Q4_K matvec with Q8_K input quantization (matches llama.cpp vec_dot_q4_K_q8_K) +// Shared memory layout: [in_dim floats for x] + [in_dim int8 for q8] + [in_dim/256 floats for q8_scales] +__global__ void dequant_matvec_q4k( + const uint8_t* __restrict__ W_q4k, + const float* __restrict__ x, + float* __restrict__ out, + uint32_t out_dim, + uint32_t in_dim +) { + extern __shared__ float x_shared[]; + int8_t *q8_shared = (int8_t *)(x_shared + in_dim); + float *q8_scales = (float *)(q8_shared + in_dim); + + const uint32_t lane = threadIdx.x; + const uint32_t warp_id = threadIdx.y; + const uint32_t row = blockIdx.x * ROWS_PER_BLOCK + warp_id; + const uint32_t blocks_per_row = in_dim >> 8; // in_dim / 256 + const uint32_t tid = warp_id * 32 + lane; + + // Load x into shared memory + for (uint32_t i = tid; i < in_dim; i += (32 * ROWS_PER_BLOCK)) + x_shared[i] = x[i]; + __syncthreads(); + + // Quantize x to Q8_K: per-256-element blocks + for (uint32_t bi = tid; bi < blocks_per_row; bi += (32 * ROWS_PER_BLOCK)) { + float max_abs = 0.0f; + for (uint32_t i = 0; i < 256; i++) { + float v = fabsf(x_shared[bi * 256 + i]); + if (v > max_abs) max_abs = v; + } + float scale = max_abs / 127.0f; + float inv_scale = (scale > 0) ? (1.0f / scale) : 0.0f; + q8_scales[bi] = scale; + for (uint32_t i = 0; i < 256; i++) { + float v = x_shared[bi * 256 + i] * inv_scale; + int q = __float2int_rn(v); + q8_shared[bi * 256 + i] = (int8_t)(q < -128 ? -128 : (q > 127 ? 127 : q)); + } + } + __syncthreads(); + + if (row >= out_dim) return; + + const uint8_t *row_data = W_q4k + (size_t)row * blocks_per_row * Q4_K_BLOCK_SIZE; + float acc = 0.0f; + + for (uint32_t bi = lane; bi < blocks_per_row; bi += 32) { + const uint8_t *block = row_data + bi * Q4_K_BLOCK_SIZE; + const uint8_t *qs = block + 16; + const uint8_t *sc_bytes = block + 4; + float d_w = __half2float(__ldg((const __half *)(block))); + float dmin_w = __half2float(__ldg((const __half *)(block + 2))); + float d_q8 = q8_scales[bi]; + const int8_t *q8 = q8_shared + bi * 256; + + // Dequant Q4_K into aux8 (matching vec_dot_q4_K_q8_K_generic) + // 4 pairs of 64: low nibbles then high nibbles + int8_t aux8[256]; + int a = 0, q4_off = 0; + for (int j = 0; j < 4; j++) { + for (int l = 0; l < 32; l++) aux8[a + l] = (int8_t)(__ldg(&qs[q4_off + l]) & 0xF); + a += 32; + for (int l = 0; l < 32; l++) aux8[a + l] = (int8_t)(__ldg(&qs[q4_off + l]) >> 4); + a += 32; + q4_off += 32; + } + + // Unpack scales using utmp/kmask approach (matching vec_dot) + uint32_t utmp[4]; + utmp[0] = __ldg((const uint32_t *)(sc_bytes + 0)); + utmp[1] = __ldg((const uint32_t *)(sc_bytes + 4)); + utmp[2] = __ldg((const uint32_t *)(sc_bytes + 8)); + utmp[3] = ((utmp[2] >> 4) & 0x0f0f0f0fu) | (((utmp[1] >> 6) & 0x03030303u) << 4); + uint32_t uaux = utmp[1] & 0x3f3f3f3fu; + utmp[1] = (utmp[2] & 0x0f0f0f0fu) | (((utmp[0] >> 6) & 0x03030303u) << 4); + utmp[2] = uaux; + utmp[0] &= 0x3f3f3f3fu; + const uint8_t *scales = (const uint8_t *)&utmp[0]; + const uint8_t *mins = (const uint8_t *)&utmp[2]; + + // Compute bsums (sum of q8 per 16-element group) + int sumi = 0; + for (int j = 0; j < 16; j++) { + int bsum = 0; + for (int l = 0; l < 16; l++) bsum += q8[j * 16 + l]; + sumi += bsum * (int)mins[j / 2]; + } + + // Integer dot product (matching vec_dot inner loop) + int32_t aux32 = 0; + int ai = 0, is = 0; + for (int j = 0; j < 8; j++) { + int sc = scales[is++]; + for (int l = 0; l < 32; l++) + aux32 += sc * (int)q8[ai + l] * (int)aux8[ai + l]; + ai += 32; + } + + acc += d_w * d_q8 * (float)aux32 - dmin_w * d_q8 * (float)sumi; + } + + acc = warp_reduce_sum(acc); + if (lane == 0) out[row] = acc; +} + +// Launch helper for Q4_K format +static inline void launch_dequant_matvec_q4k( + const uint8_t* W, const float* x, float* out, + uint32_t out_dim, uint32_t in_dim, cudaStream_t stream = 0 +) { + dim3 block(32, ROWS_PER_BLOCK); + dim3 grid((out_dim + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK); + // Shared memory: float x[in_dim] + int8 q8[in_dim] + float scales[in_dim/256] + size_t smem = in_dim * sizeof(float) + in_dim * sizeof(int8_t) + (in_dim / 256) * sizeof(float); + dequant_matvec_q4k<<>>(W, x, out, out_dim, in_dim); +} + +// ============================================================================ +// Q5_K dequantized matrix-vector multiply (GGML format) +// ============================================================================ +// Q5_K block (176 bytes per 256 elements, 5.5 bits/weight): +// d (fp16): super-block scale +// dmin (fp16): super-block min scale +// scales[12]: packed 6-bit per-sub-block scales and mins (same as Q4_K) +// qh[32]: high bits (5th bit) for each of 256 elements, packed 8 per byte +// qs[128]: low 4 bits, packed 2 per byte (same as Q4_K) +// Dequant: value = d * sub_scale * q5_value - dmin * sub_min +// where q5_value = (low_nibble) | (high_bit << 4), range 0-31 + +#define Q5_K_BLOCK_SIZE 176 // 2+2+12+32+128 bytes + +// Q5_K matvec with Q8_K input quantization (matches llama.cpp vec_dot_q5_K_q8_K) +// Shared memory layout: [in_dim floats for x] + [in_dim int8 for q8] + [in_dim/256 floats for q8_scales] +__global__ void dequant_matvec_q5k( + const uint8_t* __restrict__ W_q5k, + const float* __restrict__ x, + float* __restrict__ out, + uint32_t out_dim, + uint32_t in_dim +) { + extern __shared__ float x_shared[]; + // Layout shared mem: [in_dim float] then [in_dim int8] then [in_dim/256 float scales] + int8_t *q8_shared = (int8_t *)(x_shared + in_dim); + float *q8_scales = (float *)(q8_shared + in_dim); + + const uint32_t lane = threadIdx.x; + const uint32_t warp_id = threadIdx.y; + const uint32_t row = blockIdx.x * ROWS_PER_BLOCK + warp_id; + const uint32_t blocks_per_row = in_dim >> 8; + const uint32_t tid = warp_id * 32 + lane; + + // Load x into shared memory + for (uint32_t i = tid; i < in_dim; i += (32 * ROWS_PER_BLOCK)) + x_shared[i] = x[i]; + __syncthreads(); + + // Quantize x to Q8_K: per-256-element blocks + for (uint32_t bi = tid; bi < blocks_per_row; bi += (32 * ROWS_PER_BLOCK)) { + float max_abs = 0.0f; + for (uint32_t i = 0; i < 256; i++) { + float v = fabsf(x_shared[bi * 256 + i]); + if (v > max_abs) max_abs = v; + } + float scale = max_abs / 127.0f; + float inv_scale = (scale > 0) ? (1.0f / scale) : 0.0f; + q8_scales[bi] = scale; + for (uint32_t i = 0; i < 256; i++) { + float v = x_shared[bi * 256 + i] * inv_scale; + int q = __float2int_rn(v); + q8_shared[bi * 256 + i] = (int8_t)(q < -128 ? -128 : (q > 127 ? 127 : q)); + } + } + __syncthreads(); + + if (row >= out_dim) return; + + const uint8_t *row_data = W_q5k + (size_t)row * blocks_per_row * Q5_K_BLOCK_SIZE; + float acc = 0.0f; + + for (uint32_t bi = lane; bi < blocks_per_row; bi += 32) { + const uint8_t *block = row_data + bi * Q5_K_BLOCK_SIZE; + const uint8_t *qs = block + 48; + const uint8_t *qh = block + 16; + const uint8_t *sc_bytes = block + 4; + float d_w = __half2float(__ldg((const __half *)(block))); + float dmin_w = __half2float(__ldg((const __half *)(block + 2))); + float d_q8 = q8_scales[bi]; + const int8_t *q8 = q8_shared + bi * 256; + + // Dequant Q5K into aux8 (matching vec_dot_q5_K_q8_K_generic) + int8_t aux8[256]; + uint8_t m = 1; + int a = 0, q4_off = 0; + for (int j = 0; j < 4; j++) { + for (int l = 0; l < 32; l++) aux8[a + l] = qs[q4_off + l] & 0xF; + for (int l = 0; l < 32; l++) aux8[a + l] += (__ldg(&qh[l]) & m) ? 16 : 0; + a += 32; m <<= 1; + for (int l = 0; l < 32; l++) aux8[a + l] = qs[q4_off + l] >> 4; + for (int l = 0; l < 32; l++) aux8[a + l] += (__ldg(&qh[l]) & m) ? 16 : 0; + a += 32; m <<= 1; + q4_off += 32; + } + + // Unpack scales using utmp/kmask approach (matching vec_dot) + uint32_t utmp[4]; + utmp[0] = __ldg((const uint32_t *)(sc_bytes + 0)); + utmp[1] = __ldg((const uint32_t *)(sc_bytes + 4)); + utmp[2] = __ldg((const uint32_t *)(sc_bytes + 8)); + utmp[3] = ((utmp[2] >> 4) & 0x0f0f0f0fu) | (((utmp[1] >> 6) & 0x03030303u) << 4); + uint32_t uaux = utmp[1] & 0x3f3f3f3fu; + utmp[1] = (utmp[2] & 0x0f0f0f0fu) | (((utmp[0] >> 6) & 0x03030303u) << 4); + utmp[2] = uaux; + utmp[0] &= 0x3f3f3f3fu; + const uint8_t *scales = (const uint8_t *)&utmp[0]; + const uint8_t *mins = (const uint8_t *)&utmp[2]; + + // Compute bsums (sum of q8 per 16-element group) + int sumi = 0; + for (int j = 0; j < 16; j++) { + int bsum = 0; + for (int l = 0; l < 16; l++) bsum += q8[j * 16 + l]; + sumi += bsum * (int)mins[j / 2]; + } + + // Integer dot product (matching vec_dot inner loop) + int32_t aux32 = 0; + int ai = 0, is = 0; + for (int j = 0; j < 8; j++) { + int sc = scales[is++]; + for (int l = 0; l < 32; l++) + aux32 += sc * (int)q8[ai + l] * (int)aux8[ai + l]; + ai += 32; + } + + acc += d_w * d_q8 * (float)aux32 - dmin_w * d_q8 * (float)sumi; + } + + acc = warp_reduce_sum(acc); + if (lane == 0) out[row] = acc; +} + +static inline void launch_dequant_matvec_q5k( + const uint8_t* W, const float* x, float* out, + uint32_t out_dim, uint32_t in_dim, cudaStream_t stream = 0 +) { + dim3 block(32, ROWS_PER_BLOCK); + dim3 grid((out_dim + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK); + // Shared memory: float x[in_dim] + int8 q8[in_dim] + float scales[in_dim/256] + size_t smem = in_dim * sizeof(float) + in_dim * sizeof(int8_t) + (in_dim / 256) * sizeof(float); + dequant_matvec_q5k<<>>(W, x, out, out_dim, in_dim); +} + +// ============================================================================ +// 2. SwiGLU: out[i] = SiLU(gate[i]) * up[i] +// ============================================================================ + +__global__ void swiglu_fused( + const float* __restrict__ gate, + const float* __restrict__ up, + float* __restrict__ out, + uint32_t dim +) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= dim) return; + float g = gate[i]; + out[i] = (g / (1.0f + expf(-g))) * up[i]; +} + +// ============================================================================ +// 3. RMS Norm (fused: sum_sq + normalize in one kernel) +// ============================================================================ +// blockDim = 256 (or 1024), gridDim = 1 +// Shared memory: 32 * sizeof(float) + +__global__ void rms_norm( + const float* __restrict__ x, + const float* __restrict__ weight, // f32 weights + float* __restrict__ out, + uint32_t dim, + float eps +) { + __shared__ float shared[32]; + float acc = 0.0f; + for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { + float v = x[i]; + acc += v * v; + } + // Warp reduce + acc = warp_reduce_sum(acc); + uint32_t wid = threadIdx.x / 32; + uint32_t lane = threadIdx.x % 32; + if (lane == 0) shared[wid] = acc; + __syncthreads(); + + if (wid == 0) { + acc = (lane < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; + acc = warp_reduce_sum(acc); + if (lane == 0) shared[0] = acc; + } + __syncthreads(); + + float rms = rsqrtf(shared[0] / (float)dim + eps); + for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) + out[i] = x[i] * rms * weight[i]; +} + +// ============================================================================ +// 4. RMS Norm with bf16 weights +// ============================================================================ + +__global__ void rms_norm_bf16( + const float* __restrict__ x, + const uint16_t* __restrict__ weight, + float* __restrict__ out, + uint32_t dim, + float eps +) { + __shared__ float shared[32]; + float acc = 0.0f; + for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) + acc += x[i] * x[i]; + + acc = warp_reduce_sum(acc); + uint32_t wid = threadIdx.x / 32; + uint32_t lane = threadIdx.x % 32; + if (lane == 0) shared[wid] = acc; + __syncthreads(); + if (wid == 0) { + acc = (lane < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; + acc = warp_reduce_sum(acc); + if (lane == 0) shared[0] = acc; + } + __syncthreads(); + + float rms = rsqrtf(shared[0] / (float)dim + eps); + for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) + out[i] = x[i] * rms * bf16_to_f32(weight[i]); +} + +// ============================================================================ +// 5. Residual add: out = a + b +// ============================================================================ + +__global__ void residual_add( + const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ out, + uint32_t dim +) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < dim) out[i] = a[i] + b[i]; +} + +// ============================================================================ +// 6. Attention scores: Q @ K^T (batched over heads) +// ============================================================================ +// Grid: (seq_len * num_heads), Block: 256 +// GQA: heads_per_kv query heads share one KV head + +__global__ void attn_scores( + const float* __restrict__ Q, // [num_heads, head_dim] + const float* __restrict__ K_cache, // [max_seq, kv_dim] + float* __restrict__ scores, // [num_heads, seq_stride] + uint32_t head_dim, uint32_t kv_dim, uint32_t seq_len, + uint32_t seq_stride, float scale, uint32_t heads_per_kv, uint32_t num_seq_tgs +) { + __shared__ float shared[32]; + uint32_t pos = blockIdx.x % num_seq_tgs; + uint32_t h = blockIdx.x / num_seq_tgs; + if (pos >= seq_len) return; + + uint32_t kv_h = h / heads_per_kv; + const float* qh = Q + h * head_dim; + const float* kp = K_cache + pos * kv_dim + kv_h * head_dim; + + float acc = 0.0f; + for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) + acc += qh[d] * kp[d]; + + acc = warp_reduce_sum(acc); + uint32_t wid = threadIdx.x / 32, lane = threadIdx.x % 32; + if (lane == 0) shared[wid] = acc; + __syncthreads(); + if (wid == 0) { + acc = (lane < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; + acc = warp_reduce_sum(acc); + if (lane == 0) scores[h * seq_stride + pos] = acc * scale; + } +} + +// ============================================================================ +// 7. Softmax over seq_len per head +// ============================================================================ + +__global__ void attn_softmax( + float* __restrict__ scores, // [num_heads, seq_stride] + uint32_t seq_len, uint32_t seq_stride +) { + __shared__ float s_max, s_sum; + float* s = scores + blockIdx.x * seq_stride; + + // Max + float local_max = -1e30f; + for (uint32_t i = threadIdx.x; i < seq_len; i += blockDim.x) + local_max = fmaxf(local_max, s[i]); + + __shared__ float shared[32]; + float wmax = warp_reduce_max(local_max); + uint32_t wid = threadIdx.x / 32, lane = threadIdx.x % 32; + if (lane == 0) shared[wid] = wmax; + __syncthreads(); + if (wid == 0) { + wmax = (lane < (blockDim.x + 31) / 32) ? shared[lane] : -1e30f; + wmax = warp_reduce_max(wmax); + if (lane == 0) s_max = wmax; + } + __syncthreads(); + + // Exp + sum + float local_sum = 0.0f; + for (uint32_t i = threadIdx.x; i < seq_len; i += blockDim.x) { + float v = expf(s[i] - s_max); + s[i] = v; + local_sum += v; + } + float wsum = warp_reduce_sum(local_sum); + if (lane == 0) shared[wid] = wsum; + __syncthreads(); + if (wid == 0) { + wsum = (lane < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; + wsum = warp_reduce_sum(wsum); + if (lane == 0) s_sum = wsum; + } + __syncthreads(); + + // Normalize + float inv = 1.0f / s_sum; + for (uint32_t i = threadIdx.x; i < seq_len; i += blockDim.x) + s[i] *= inv; +} + +// ============================================================================ +// 8. Attention values: scores @ V +// ============================================================================ + +__global__ void attn_values( + const float* __restrict__ scores, // [num_heads, seq_stride] + const float* __restrict__ V_cache, // [max_seq, kv_dim] + float* __restrict__ out, // [num_heads, head_dim] + uint32_t head_dim, uint32_t kv_dim, uint32_t seq_len, + uint32_t seq_stride, uint32_t heads_per_kv +) { + uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + uint32_t d = tid % head_dim; + uint32_t h = tid / head_dim; + uint32_t kv_h = h / heads_per_kv; + const float* sc = scores + h * seq_stride; + + float acc = 0.0f; + for (uint32_t p = 0; p < seq_len; p++) + acc += sc[p] * V_cache[p * kv_dim + kv_h * head_dim + d]; + out[h * head_dim + d] = acc; +} + +// ============================================================================ +// 9. Sigmoid gate: x *= sigmoid(gate) +// ============================================================================ + +__global__ void sigmoid_gate( + float* __restrict__ x_out, + const float* __restrict__ gate, + uint32_t dim +) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < dim) { + float g = 1.0f / (1.0f + expf(-gate[i])); + x_out[i] *= g; + } +} + +// ============================================================================ +// 10. GatedDeltaNet step (single token, all heads) +// ============================================================================ +// Grid: 64 (v-heads), Block: 128 (value_dim) + +__global__ void gated_delta_net_step( + float* __restrict__ state, // [64 * 128 * 128] + const float* __restrict__ q, // [2048] + const float* __restrict__ k, // [2048] + const float* __restrict__ v, // [8192] + const float* __restrict__ g_decay, // [64] + const float* __restrict__ beta_gate, // [64] + float* __restrict__ output, // [8192] + uint32_t k_heads_per_v // = 4 +) { + uint32_t head_id = blockIdx.x; + uint32_t vi = threadIdx.x; + uint32_t kh = head_id / k_heads_per_v; + float g = g_decay[head_id]; + float beta = beta_gate[head_id]; + + uint32_t state_base = head_id * 128 * 128 + vi * 128; + uint32_t k_base = kh * 128; + uint32_t v_base = head_id * 128; + + // Decay + memory read + float kv_mem = 0.0f; + for (uint32_t ki = 0; ki < 128; ki++) { + float s = state[state_base + ki] * g; + state[state_base + ki] = s; + kv_mem += s * k[k_base + ki]; + } + + // Delta update + float delta = (v[v_base + vi] - kv_mem) * beta; + for (uint32_t ki = 0; ki < 128; ki++) + state[state_base + ki] += k[k_base + ki] * delta; + + // Output + float out_val = 0.0f; + for (uint32_t ki = 0; ki < 128; ki++) + out_val += state[state_base + ki] * q[k_base + ki]; + output[v_base + vi] = out_val; +} + +// ============================================================================ +// 11. Conv1d step (kernel=4, depthwise, with SiLU) +// ============================================================================ + +__global__ void conv1d_step( + float* __restrict__ conv_state, // [3 * conv_dim] + const float* __restrict__ input, // [conv_dim] + const uint16_t* __restrict__ weights, // [conv_dim * 4] bf16 + float* __restrict__ output, // [conv_dim] + uint32_t conv_dim +) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= conv_dim) return; + + uint32_t wb = idx * 4; + float acc = conv_state[0 * conv_dim + idx] * bf16_to_f32(weights[wb + 0]) + + conv_state[1 * conv_dim + idx] * bf16_to_f32(weights[wb + 1]) + + conv_state[2 * conv_dim + idx] * bf16_to_f32(weights[wb + 2]); + float inp = input[idx]; + acc += inp * bf16_to_f32(weights[wb + 3]); + + output[idx] = acc / (1.0f + expf(-acc)); // SiLU + + // Shift history + conv_state[0 * conv_dim + idx] = conv_state[1 * conv_dim + idx]; + conv_state[1 * conv_dim + idx] = conv_state[2 * conv_dim + idx]; + conv_state[2 * conv_dim + idx] = inp; +} + +// ============================================================================ +// 12a. Per-head L2 norm for Q and K (GGUF models — matches llama.cpp ggml_l2_norm) +// ============================================================================ + +__global__ void l2_norm_qk( + float* __restrict__ q, + float* __restrict__ k, + uint32_t key_dim +) { + uint32_t head = blockIdx.x; + uint32_t tid = threadIdx.x; + uint32_t base = head * key_dim; + + __shared__ float buf[128]; + + // Q L2 norm: q = q / ||q|| + float qv = (tid < key_dim) ? q[base + tid] : 0.0f; + buf[tid] = qv * qv; + __syncthreads(); + __shared__ float q_sum; + if (tid == 0) { float s = 0; for (uint32_t i = 0; i < key_dim; i++) s += buf[i]; q_sum = s; } + __syncthreads(); + if (tid < key_dim) + q[base + tid] = qv * rsqrtf(q_sum + 1e-6f); + + // K L2 norm: k = k / ||k|| + float kv = (tid < key_dim) ? k[base + tid] : 0.0f; + buf[tid] = kv * kv; + __syncthreads(); + __shared__ float k_sum; + if (tid == 0) { float s = 0; for (uint32_t i = 0; i < key_dim; i++) s += buf[i]; k_sum = s; } + __syncthreads(); + if (tid < key_dim) + k[base + tid] = kv * rsqrtf(k_sum + 1e-6f); +} + +// ============================================================================ +// 12b. Per-head RMS norm for Q and K (MLX models — original 397B behavior) +// ============================================================================ + +__global__ void rms_norm_qk( + float* __restrict__ q, + float* __restrict__ k, + uint32_t key_dim, float inv_scale +) { + uint32_t head = blockIdx.x; + uint32_t tid = threadIdx.x; + uint32_t base = head * key_dim; + + // Q norm + __shared__ float q_sum; + __shared__ float buf[128]; + float qv = (tid < key_dim) ? q[base + tid] : 0.0f; + buf[tid] = qv * qv; + __syncthreads(); + if (tid == 0) { float s = 0; for (uint32_t i = 0; i < key_dim; i++) s += buf[i]; q_sum = s; } + __syncthreads(); + if (tid < key_dim) + q[base + tid] = qv * rsqrtf(q_sum / (float)key_dim + 1e-6f) * inv_scale * inv_scale; + + // K norm + __shared__ float k_sum; + float kv = (tid < key_dim) ? k[base + tid] : 0.0f; + buf[tid] = kv * kv; + __syncthreads(); + if (tid == 0) { float s = 0; for (uint32_t i = 0; i < key_dim; i++) s += buf[i]; k_sum = s; } + __syncthreads(); + if (tid < key_dim) + k[base + tid] = kv * rsqrtf(k_sum / (float)key_dim + 1e-6f) * inv_scale; +} + +// ============================================================================ +// 13. Compute decay and beta gate for GatedDeltaNet +// ============================================================================ + +// MLX version: A_log stores log(A), compute exp(A_log) first, dt_bias is bf16 +__global__ void compute_decay_beta( + const float* __restrict__ alpha_out, + const float* __restrict__ beta_out, + const float* __restrict__ A_log, + const uint16_t* __restrict__ dt_bias, + float* __restrict__ g_decay, + float* __restrict__ beta_gate +) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + float a_val = alpha_out[idx]; + float dt_b = bf16_to_f32(dt_bias[idx]); + float A_val = expf(A_log[idx]); + float sp = logf(1.0f + expf(a_val + dt_b)); + g_decay[idx] = expf(-A_val * sp); + beta_gate[idx] = 1.0f / (1.0f + expf(-beta_out[idx])); +} + +// GGUF version: ssm_a and dt_bias are both F32 +__global__ void compute_decay_beta_gguf( + const float* __restrict__ alpha_out, + const float* __restrict__ beta_out, + const float* __restrict__ ssm_a, // negative values, used directly + const float* __restrict__ dt_bias, // F32 (not converted to bf16) + float* __restrict__ g_decay, + float* __restrict__ beta_gate +) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + float a_val = alpha_out[idx]; + float dt_b = dt_bias[idx]; + float sp = logf(1.0f + expf(a_val + dt_b)); // softplus(alpha + dt_bias) + g_decay[idx] = expf(ssm_a[idx] * sp); // exp(ssm_a * softplus) — ssm_a is negative + beta_gate[idx] = 1.0f / (1.0f + expf(-beta_out[idx])); // sigmoid(beta) +} + +// ============================================================================ +// 14. Gated RMS norm: rms_norm(values) * SiLU(z) * weight +// ============================================================================ + +__global__ void gated_rms_norm( + const float* __restrict__ values, + const float* __restrict__ z, + const uint16_t* __restrict__ weight, + float* __restrict__ output, + uint32_t value_dim, float eps +) { + uint32_t head = blockIdx.x; + uint32_t tid = threadIdx.x; + uint32_t base = head * value_dim; + + __shared__ float buf[128]; + float val = (tid < value_dim) ? values[base + tid] : 0.0f; + buf[tid] = val * val; + __syncthreads(); + if (tid == 0) { float s = 0; for (uint32_t i = 0; i < value_dim; i++) s += buf[i]; buf[0] = s; } + __syncthreads(); + float inv_rms = rsqrtf(buf[0] / (float)value_dim + eps); + + if (tid < value_dim) { + float normed = val * inv_rms; + float zv = z[base + tid]; + float gate = zv / (1.0f + expf(-zv)); // SiLU + output[base + tid] = normed * gate * bf16_to_f32(weight[tid]); + } +} + +// ============================================================================ +// 15. MoE combine + residual + shared expert gate +// ============================================================================ +// out[i] = h_mid[i] + sum_k(weight[k] * expert_out[k*dim+i]) + sigmoid(shared_gate) * shared[i] + +__global__ void moe_combine_residual( + const float* __restrict__ h_mid, + const float* __restrict__ shared_out, + float* __restrict__ hidden_out, + const float* __restrict__ expert_outs, // [K * dim] concatenated + const float* __restrict__ weights, // [K] + float shared_gate_score, + uint32_t dim, uint32_t K +) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= dim) return; + + float sg = 1.0f / (1.0f + expf(-shared_gate_score)); + float moe = 0.0f; + for (uint32_t k = 0; k < K; k++) + moe += weights[k] * expert_outs[k * dim + i]; + + hidden_out[i] = h_mid[i] + moe + sg * shared_out[i]; +} + +// ============================================================================ +// Launch helpers +// ============================================================================ + +static inline void launch_dequant_matvec( + const uint32_t* W, const uint16_t* scales, const uint16_t* biases, + const float* x, float* out, uint32_t out_dim, uint32_t in_dim, + cudaStream_t stream = 0 +) { + dim3 block(32, ROWS_PER_BLOCK); + dim3 grid((out_dim + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK); + size_t smem = in_dim * sizeof(float); + dequant_matvec_4bit_fma_vec4<<>>(W, scales, biases, x, out, out_dim, in_dim); +} + +// ============================================================================ +// Q6_K dequantized matrix-vector multiply (GGML format) +// ============================================================================ +// Q6_K block (210 bytes per 256 elements, 6.5625 bits/weight): +// ql[128]: low 4 bits of each 6-bit value (2 per byte, low nibble first) +// qh[64]: high 2 bits of each 6-bit value (4 values per byte) +// scales[16]: int8 per-sub-block scales (16 sub-blocks of 16 elements) +// d (fp16): super-block scale +// Dequant: value = d * scale * (q6_value - 32) where q6_value is 0-63 (6-bit unsigned) + +#define Q6_K_BLOCK_SIZE 210 // 128+64+16+2 bytes + +// Q6_K matvec with Q8_K input quantization (matches llama.cpp vec_dot_q6_K_q8_K) +// Shared memory layout: [in_dim floats for x] + [in_dim int8 for q8] + [in_dim/256 floats for q8_scales] +__global__ void dequant_matvec_q6k( + const uint8_t* __restrict__ W_q6k, + const float* __restrict__ x, + float* __restrict__ out, + uint32_t out_dim, + uint32_t in_dim +) { + extern __shared__ float x_shared[]; + int8_t *q8_shared = (int8_t *)(x_shared + in_dim); + float *q8_scales = (float *)(q8_shared + in_dim); + + const uint32_t lane = threadIdx.x; + const uint32_t warp_id = threadIdx.y; + const uint32_t row = blockIdx.x * ROWS_PER_BLOCK + warp_id; + const uint32_t blocks_per_row = in_dim >> 8; // in_dim / 256 + const uint32_t tid = warp_id * 32 + lane; + + // Load x into shared memory + for (uint32_t i = tid; i < in_dim; i += (32 * ROWS_PER_BLOCK)) + x_shared[i] = x[i]; + __syncthreads(); + + // Quantize x to Q8_K: per-256-element blocks + for (uint32_t bi = tid; bi < blocks_per_row; bi += (32 * ROWS_PER_BLOCK)) { + float max_abs = 0.0f; + for (uint32_t i = 0; i < 256; i++) { + float v = fabsf(x_shared[bi * 256 + i]); + if (v > max_abs) max_abs = v; + } + float scale = max_abs / 127.0f; + float inv_scale = (scale > 0) ? (1.0f / scale) : 0.0f; + q8_scales[bi] = scale; + for (uint32_t i = 0; i < 256; i++) { + float v = x_shared[bi * 256 + i] * inv_scale; + int q = __float2int_rn(v); + q8_shared[bi * 256 + i] = (int8_t)(q < -128 ? -128 : (q > 127 ? 127 : q)); + } + } + __syncthreads(); + + if (row >= out_dim) return; + + const uint8_t *row_data = W_q6k + (size_t)row * blocks_per_row * Q6_K_BLOCK_SIZE; + float acc = 0.0f; + + for (uint32_t bi = lane; bi < blocks_per_row; bi += 32) { + const uint8_t *block = row_data + bi * Q6_K_BLOCK_SIZE; + const uint8_t *ql = block; + const uint8_t *qh = block + 128; + const int8_t *sc = (const int8_t *)(block + 192); + float d_w = __half2float(__ldg((const __half *)(block + 208))); + float d_q8 = q8_scales[bi]; + const int8_t *q8 = q8_shared + bi * 256; + + // Dequant Q6_K into aux8 (matching vec_dot_q6_K_q8_K_generic) + // 2 halves of 128 elements: ql advances by 64, qh by 32 per half + // qh shifts are always 0,2,4,6 (reset each half) + int8_t aux8[256]; + { + const uint8_t *q4 = ql; + const uint8_t *qh_p = qh; + int8_t *a = aux8; + for (int j = 0; j < 256; j += 128) { + for (int l = 0; l < 32; l++) { + uint8_t qlv0 = __ldg(&q4[l]); + uint8_t qlv32 = __ldg(&q4[l + 32]); + uint8_t qhv = __ldg(&qh_p[l]); + a[l + 0] = (int8_t)(((qlv0 & 0xF) | (((qhv >> 0) & 3) << 4)) - 32); + a[l + 32] = (int8_t)(((qlv32 & 0xF) | (((qhv >> 2) & 3) << 4)) - 32); + a[l + 64] = (int8_t)(((qlv0 >> 4) | (((qhv >> 4) & 3) << 4)) - 32); + a[l + 96] = (int8_t)(((qlv32 >> 4) | (((qhv >> 6) & 3) << 4)) - 32); + } + a += 128; + q4 += 64; + qh_p += 32; + } + } + + // Integer dot product (matching vec_dot_q6_K_q8_K inner loop) + // 16 sub-blocks of 16 elements, each with int8 scale + int32_t aux32 = 0; + int ai = 0, is = 0; + for (int j = 0; j < 16; j++) { + int scale = sc[is++]; + for (int l = 0; l < 16; l++) + aux32 += scale * (int)q8[ai + l] * (int)aux8[ai + l]; + ai += 16; + } + + acc += d_w * d_q8 * (float)aux32; + } + + acc = warp_reduce_sum(acc); + if (lane == 0) out[row] = acc; +} + +static inline void launch_dequant_matvec_q6k( + const uint8_t* W, const float* x, float* out, + uint32_t out_dim, uint32_t in_dim, cudaStream_t stream = 0 +) { + dim3 block(32, ROWS_PER_BLOCK); + dim3 grid((out_dim + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK); + // Shared memory: float x[in_dim] + int8 q8[in_dim] + float scales[in_dim/256] + size_t smem = in_dim * sizeof(float) + in_dim * sizeof(int8_t) + (in_dim / 256) * sizeof(float); + dequant_matvec_q6k<<>>(W, x, out, out_dim, in_dim); +} + +// ============================================================================ +// F32 matrix-vector multiply (for unquantized weights like norms, small tensors) +// ============================================================================ + +__global__ void matvec_f32( + const float* __restrict__ W, + const float* __restrict__ x, + float* __restrict__ out, + uint32_t out_dim, + uint32_t in_dim +) { + extern __shared__ float x_shared[]; + const uint32_t lane = threadIdx.x; + const uint32_t warp_id = threadIdx.y; + const uint32_t row = blockIdx.x * ROWS_PER_BLOCK + warp_id; + + const uint32_t tid = warp_id * 32 + lane; + for (uint32_t i = tid; i < in_dim; i += (32 * ROWS_PER_BLOCK)) + x_shared[i] = x[i]; + __syncthreads(); + + if (row >= out_dim) return; + + const float *row_data = W + (size_t)row * in_dim; + float acc = 0.0f; + + for (uint32_t i = lane; i < in_dim; i += 32) + acc += row_data[i] * x_shared[i]; + + acc = warp_reduce_sum(acc); + if (lane == 0) out[row] = acc; +} + +static inline void launch_matvec_f32( + const float* W, const float* x, float* out, + uint32_t out_dim, uint32_t in_dim, cudaStream_t stream = 0 +) { + dim3 block(32, ROWS_PER_BLOCK); + dim3 grid((out_dim + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK); + size_t smem = in_dim * sizeof(float); + matvec_f32<<>>(W, x, out, out_dim, in_dim); +} + +// ============================================================================ +// Format-aware matvec dispatch for GGUF support +// ============================================================================ +// GGML quant type IDs (from ggml-common.h) +#define GGUF_TYPE_F32 0 +#define GGUF_TYPE_F16 1 +#define GGUF_TYPE_Q4_K 12 +#define GGUF_TYPE_Q5_K 13 +#define GGUF_TYPE_Q6_K 14 + +static inline void launch_dequant_matvec_gguf( + const void* W, const float* x, float* out, + uint32_t out_dim, uint32_t in_dim, int gguf_type, cudaStream_t stream = 0 +) { + switch (gguf_type) { + case GGUF_TYPE_Q4_K: + launch_dequant_matvec_q4k((const uint8_t*)W, x, out, out_dim, in_dim, stream); + break; + case GGUF_TYPE_Q5_K: + launch_dequant_matvec_q5k((const uint8_t*)W, x, out, out_dim, in_dim, stream); + break; + case GGUF_TYPE_Q6_K: + launch_dequant_matvec_q6k((const uint8_t*)W, x, out, out_dim, in_dim, stream); + break; + case GGUF_TYPE_F32: + launch_matvec_f32((const float*)W, x, out, out_dim, in_dim, stream); + break; + default: + // Unsupported type — fall back to Q4_K as best guess + launch_dequant_matvec_q4k((const uint8_t*)W, x, out, out_dim, in_dim, stream); + break; + } +} + +static inline void launch_swiglu(const float* gate, const float* up, float* out, uint32_t dim, cudaStream_t s = 0) { + swiglu_fused<<<(dim+255)/256, 256, 0, s>>>(gate, up, out, dim); +} + +static inline void launch_rms_norm_bf16(const float* x, const uint16_t* w, float* out, uint32_t dim, float eps, cudaStream_t s = 0) { + rms_norm_bf16<<<1, 256, 0, s>>>(x, w, out, dim, eps); +} + +static inline void launch_residual_add(const float* a, const float* b, float* out, uint32_t dim, cudaStream_t s = 0) { + residual_add<<<(dim+255)/256, 256, 0, s>>>(a, b, out, dim); +} diff --git a/cuda_infer/tokenizer_impl.c b/cuda_infer/tokenizer_impl.c new file mode 100644 index 0000000..182bad0 --- /dev/null +++ b/cuda_infer/tokenizer_impl.c @@ -0,0 +1,2 @@ +#define TOKENIZER_IMPL +#include "../metal_infer/tokenizer.h" diff --git a/gguf_extract.py b/gguf_extract.py new file mode 100644 index 0000000..9cd90ff --- /dev/null +++ b/gguf_extract.py @@ -0,0 +1,545 @@ +#!/usr/bin/env python3 +"""Extract weights from a GGUF MoE model for Flash-MoE inference. + +Reads a .gguf file and produces: + 1. model_weights.bin — non-expert weights (contiguous binary, 64-byte aligned) + 2. model_weights.json — manifest with tensor offsets, shapes, dtypes, and model config + 3. packed_experts/layer_XX.bin — per-layer expert binaries + 4. vocab.bin — vocabulary for token decoding + +Usage: + python gguf_extract.py --gguf model.gguf --output ./model_dir + python gguf_extract.py --gguf model.gguf --output ./model_dir --dry-run +""" + +import argparse +import json +import os +import struct +import sys +import time +from pathlib import Path + +# ============================================================================ +# GGUF constants +# ============================================================================ +GGUF_MAGIC = 0x46554747 +GGUF_DEFAULT_ALIGNMENT = 32 + +# Quant type → (block_elements, block_bytes) +QUANT_SIZES = { + 0: (1, 4), # F32 + 1: (1, 2), # F16 + 2: (32, 18), # Q4_0 + 3: (32, 20), # Q4_1 + 6: (32, 22), # Q5_0 + 7: (32, 24), # Q5_1 + 8: (32, 34), # Q8_0 + 9: (32, 40), # Q8_1 + 10: (256, 84), # Q2_K + 11: (256, 110), # Q3_K + 12: (256, 144), # Q4_K + 13: (256, 176), # Q5_K + 14: (256, 210), # Q6_K + 15: (256, 292), # Q8_K + 24: (1, 1), # I8 + 25: (1, 2), # I16 + 26: (1, 4), # I32 + 27: (1, 8), # I64 + 28: (1, 8), # F64 + 30: (1, 2), # BF16 +} + +QUANT_NAMES = { + 0: "F32", 1: "F16", 2: "Q4_0", 3: "Q4_1", 6: "Q5_0", 7: "Q5_1", + 8: "Q8_0", 9: "Q8_1", 10: "Q2_K", 11: "Q3_K", 12: "Q4_K", + 13: "Q5_K", 14: "Q6_K", 15: "Q8_K", 30: "BF16", +} + + +def tensor_nbytes(n_elements, quant_type): + block_elems, block_bytes = QUANT_SIZES[quant_type] + return (n_elements // block_elems) * block_bytes + + +# ============================================================================ +# GGUF → engine tensor name mapping +# ============================================================================ +# full_attention_interval=4: layers 3, 7, 11, ... are full attention (0-indexed) + +def map_tensor_name(gguf_name, full_attn_interval=4): + """Map a GGUF tensor name to the engine's expected name. + + Returns the mapped name, or the original name if no mapping applies. + """ + # Global tensors + if gguf_name == 'token_embd.weight': + return 'model.embed_tokens.weight' + if gguf_name == 'output.weight': + return 'lm_head.weight' + if gguf_name == 'output_norm.weight': + return 'model.norm.weight' + + # Per-layer tensors: blk.{L}.xxx + if not gguf_name.startswith('blk.'): + return gguf_name + + parts = gguf_name.split('.') + layer = int(parts[1]) + rest = '.'.join(parts[2:]) # e.g. "attn_norm.weight" + is_full = ((layer + 1) % full_attn_interval == 0) + + # Common mappings (both layer types) + mapping = { + # Norms + 'attn_norm.weight': f'model.layers.{layer}.input_layernorm.weight', + 'post_attention_norm.weight': f'model.layers.{layer}.post_attention_layernorm.weight', + # MoE routing + shared expert + 'ffn_gate_inp.weight': f'model.layers.{layer}.mlp.gate.weight', + 'ffn_gate_shexp.weight': f'model.layers.{layer}.mlp.shared_expert.gate_proj.weight', + 'ffn_up_shexp.weight': f'model.layers.{layer}.mlp.shared_expert.up_proj.weight', + 'ffn_down_shexp.weight': f'model.layers.{layer}.mlp.shared_expert.down_proj.weight', + 'ffn_gate_inp_shexp.weight': f'model.layers.{layer}.mlp.shared_expert_gate.weight', + } + + if rest in mapping: + return mapping[rest] + + if is_full: + # Full attention layers have SEPARATE Q/K/V (not fused) + full_mapping = { + 'attn_q.weight': f'model.layers.{layer}.self_attn.q_proj.weight', + 'attn_k.weight': f'model.layers.{layer}.self_attn.k_proj.weight', + 'attn_v.weight': f'model.layers.{layer}.self_attn.v_proj.weight', + 'attn_output.weight': f'model.layers.{layer}.self_attn.o_proj.weight', + 'attn_q_norm.weight': f'model.layers.{layer}.self_attn.q_norm.weight', + 'attn_k_norm.weight': f'model.layers.{layer}.self_attn.k_norm.weight', + } + if rest in full_mapping: + return full_mapping[rest] + else: + # Linear attention layers have fused QKV + SSM parameters + linear_mapping = { + 'attn_qkv.weight': f'model.layers.{layer}.linear_attn.in_proj_qkv.weight', + 'attn_output.weight': f'model.layers.{layer}.self_attn.o_proj.weight', + # attn_gate in linear layers is the Z/output gate (maps to in_proj_z) + 'attn_gate.weight': f'model.layers.{layer}.linear_attn.in_proj_z.weight', + 'attn_q_norm.weight': f'model.layers.{layer}.self_attn.q_norm.weight', + 'attn_k_norm.weight': f'model.layers.{layer}.self_attn.k_norm.weight', + # GatedDeltaNet / SSM + 'ssm_conv1d.weight': f'model.layers.{layer}.linear_attn.conv1d.weight', + 'ssm_a': f'model.layers.{layer}.linear_attn.A_log', + 'ssm_dt.bias': f'model.layers.{layer}.linear_attn.dt_bias', + 'ssm_norm.weight': f'model.layers.{layer}.linear_attn.norm.weight', + 'ssm_out.weight': f'model.layers.{layer}.linear_attn.out_proj.weight', + 'ssm_alpha.weight': f'model.layers.{layer}.linear_attn.in_proj_a.weight', + 'ssm_beta.weight': f'model.layers.{layer}.linear_attn.in_proj_b.weight', + } + if rest in linear_mapping: + return linear_mapping[rest] + + return gguf_name + + +# ============================================================================ +# GGUF parser +# ============================================================================ +class GGUFReader: + def __init__(self, path): + self.path = path + self.fd = open(path, 'rb') + self.metadata = {} + self.tensors = [] # list of {name, dims, type, offset, nbytes} + self.tensor_data_start = 0 + self.alignment = GGUF_DEFAULT_ALIGNMENT + self._parse() + + def _read_str(self): + n = struct.unpack('= 2, f"Unsupported GGUF version {version}" + tensor_count, kv_count = struct.unpack('