From b830756359745464074ecc4dfcce35608e1baf76 Mon Sep 17 00:00:00 2001 From: Alessio Delmonti Date: Fri, 20 Mar 2026 22:20:00 +0000 Subject: [PATCH 1/7] docs: add runtime model config design spec Spec for replacing ~40 hardcoded #define model constants with a runtime ModelConfig struct populated from HuggingFace config.json, enabling model switching via --model flag without recompilation. Co-Authored-By: Claude Opus 4.6 --- .../2026-03-20-runtime-model-config-design.md | 276 ++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 docs/superpowers/specs/2026-03-20-runtime-model-config-design.md diff --git a/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md b/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md new file mode 100644 index 0000000..b9d001d --- /dev/null +++ b/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md @@ -0,0 +1,276 @@ +# Runtime Model Config from HuggingFace config.json + +**Date:** 2026-03-20 +**Status:** Approved + +## Problem + +Switching Flash-MoE between models (e.g., Qwen3.5-397B-A17B vs Qwen3.5-35B-A3B) requires manually editing ~40 `#define` constants and recompiling. This is error-prone (the NaN bug was caused by stale expert offsets) and prevents runtime model selection. + +## Solution + +Replace all model-specific `#define` constants with a global `ModelConfig` struct populated at startup by parsing the HuggingFace `config.json` and `tokenizer_config.json` using NSJSONSerialization. Expert byte offsets are computed from dimensions and quantization parameters. The `--model` CLI flag selects which model to load. 
+ +## Design + +### ModelConfig struct + +A single global `ModelConfig cfg` replaces all model `#define`s: + +```c +typedef struct { + // Core architecture (from config.json -> text_config) + int hidden_dim; // text_config.hidden_size + int num_layers; // text_config.num_hidden_layers + int num_attn_heads; // text_config.num_attention_heads + int num_kv_heads; // text_config.num_key_value_heads + int head_dim; // text_config.head_dim + int vocab_size; // text_config.vocab_size + float rms_norm_eps; // text_config.rms_norm_eps + + // MoE (from config.json -> text_config) + int num_experts; // text_config.num_experts + int num_experts_per_tok; // text_config.num_experts_per_tok + int moe_intermediate; // text_config.moe_intermediate_size + int shared_intermediate; // text_config.shared_expert_intermediate_size + int group_size; // quantization.group_size (or quantization_config) + int bits; // quantization.bits (or quantization_config) + + // Linear attention / GatedDeltaNet (from config.json -> text_config) + int linear_num_v_heads; // text_config.linear_num_value_heads + int linear_num_k_heads; // text_config.linear_num_key_heads + int linear_key_dim; // text_config.linear_key_head_dim + int linear_value_dim; // text_config.linear_value_head_dim + int conv_kernel_size; // text_config.linear_conv_kernel_dim + + // Full attention (from config.json -> text_config) + float rope_theta; // text_config.rope_parameters.rope_theta + float partial_rotary; // text_config.rope_parameters.partial_rotary_factor + + // Layer type map (from config.json -> text_config.layer_types) + int num_full_attn_layers; + int num_linear_layers; + bool *is_full_attn; // [num_layers] — true if full attention + int *full_attn_index; // [num_layers] — index into full-attn arrays, or -1 + int *linear_index; // [num_layers] — index into linear-attn arrays, or -1 + + // Derived: expert byte offsets (computed from dims + quantization) + size_t expert_size_4bit; + size_t gate_w_off_4, gate_s_off_4, 
gate_b_off_4; + size_t up_w_off_4, up_s_off_4, up_b_off_4; + size_t down_w_off_4, down_s_off_4, down_b_off_4; + size_t expert_size_2bit; + size_t gate_w_off_2, gate_s_off_2, gate_b_off_2; + size_t up_w_off_2, up_s_off_2, up_b_off_2; + size_t down_w_off_2, down_s_off_2, down_b_off_2; + + // Derived dimensions (computed from above) + int linear_total_key; // linear_num_k_heads * linear_key_dim + int linear_total_value; // linear_num_v_heads * linear_value_dim + int linear_conv_dim; // linear_total_key * 2 + linear_total_value + int rotary_dim; // (int)(head_dim * partial_rotary) + + // Special tokens + int eos_token_ids[8]; // from config.json eos_token_id (can be array) + int num_eos_tokens; + int think_start_token; // from tokenizer_config.json added_tokens_decoder + int think_end_token; // from tokenizer_config.json added_tokens_decoder + + // Context limits (kept as defaults, could be overridden by CLI) + int max_seq_len; // default: text_config.max_position_embeddings + int gpu_kv_seq; // default: 8192 (pre-allocation, not model-specific) + + // Model path (resolved HF snapshot dir) + char model_path[1024]; +} ModelConfig; + +static ModelConfig cfg; +``` + +### Config loading function + +```c +static void load_model_config(const char *model_dir); +``` + +Steps: +1. Resolve HF snapshot directory (walk `snapshots/` if needed, same logic as existing code) +2. Read and parse `config.json` via NSJSONSerialization +3. Extract `text_config` sub-dictionary for architecture params +4. Read `layer_types` array to build `is_full_attn[]`, `full_attn_index[]`, `linear_index[]` +5. Read `quantization` or `quantization_config` for group_size and bits +6. Read `eos_token_id` (handles both single int and array) +7. Read `rope_parameters` sub-dict for rope_theta and partial_rotary_factor +8. 
Read and parse `tokenizer_config.json` for think tokens:
+   - Walk `added_tokens_decoder` object, match entries where `content` == `"<think>"` or `"</think>"`
+   - Extract their integer key as the token ID
+9. Compute expert byte offsets via `compute_expert_offsets()`
+10. Compute derived dimensions (linear_total_key, etc.)
+11. Print summary to stderr for verification
+
+### Expert offset computation
+
+Offsets are deterministic from `moe_intermediate`, `hidden_dim`, `group_size`, and `bits`:
+
+```c
+static void compute_expert_offsets(ModelConfig *c) {
+    int mid = c->moe_intermediate;  // e.g. 512
+    int hid = c->hidden_dim;        // e.g. 2048
+    int gs = c->group_size;         // e.g. 64
+    int bits = c->bits;             // 4 or 2
+
+    // For a [out_dim, in_dim] weight at N bits, group_size gs:
+    //   weight_bytes = out_dim * ceil(in_dim / (32/bits)) * 4
+    //   scales_bytes = out_dim * ceil(in_dim / gs) * 2   (bf16)
+    //   biases_bytes = scales_bytes
+
+    // gate_proj: [mid, hid], up_proj: [mid, hid], down_proj: [hid, mid]
+    // Compute for 4-bit and 2-bit layouts
+    for (int b = 4; b >= 2; b -= 2) {
+        int vals_per_u32 = 32 / b;
+        // gate_proj [mid, hid]
+        size_t gw = (size_t)mid * ((hid + vals_per_u32 - 1) / vals_per_u32) * 4;
+        size_t gs_bytes = (size_t)mid * ((hid + gs - 1) / gs) * 2;
+        size_t gb = gs_bytes;
+        // up_proj [mid, hid] — same shape as gate
+        size_t uw = gw, us = gs_bytes, ub = gb;
+        // down_proj [hid, mid]
+        size_t dw = (size_t)hid * ((mid + vals_per_u32 - 1) / vals_per_u32) * 4;
+        size_t ds = (size_t)hid * ((mid + gs - 1) / gs) * 2;
+        size_t db = ds;
+
+        size_t off = 0;
+        if (b == 4) {
+            c->gate_w_off_4 = off; off += gw;
+            c->gate_s_off_4 = off; off += gs_bytes;
+            c->gate_b_off_4 = off; off += gb;
+            c->up_w_off_4 = off; off += uw;
+            c->up_s_off_4 = off; off += us;
+            c->up_b_off_4 = off; off += ub;
+            c->down_w_off_4 = off; off += dw;
+            c->down_s_off_4 = off; off += ds;
+            c->down_b_off_4 = off; off += db;
+            c->expert_size_4bit = off;
+        } else {
+            c->gate_w_off_2 = off; off += gw;
+            c->gate_s_off_2 = off;
off += gs_bytes; + c->gate_b_off_2 = off; off += gb; + c->up_w_off_2 = off; off += uw; + c->up_s_off_2 = off; off += us; + c->up_b_off_2 = off; off += ub; + c->down_w_off_2 = off; off += dw; + c->down_s_off_2 = off; off += ds; + c->down_b_off_2 = off; off += db; + c->expert_size_2bit = off; + } + } +} +``` + +### Static arrays to dynamic allocation + +These static arrays use compile-time `NUM_LAYERS`/`NUM_EXPERTS` and must become dynamically allocated after config loading: + +| Current | New | +|---------|-----| +| `static int g_expert_freq[NUM_LAYERS][NUM_EXPERTS]` | `int *g_expert_freq` (malloc `num_layers * num_experts * sizeof(int)`) | +| `static uint8_t g_expert_seen[NUM_LAYERS][NUM_EXPERTS/8]` | `uint8_t *g_expert_seen` (malloc `num_layers * ceil(num_experts/8)`) | +| `static uint8_t g_cache_seen[NUM_LAYERS][NUM_EXPERTS]` | `uint8_t *g_cache_seen` (malloc) | +| `static uint64_t g_cache_last_touch_token[NUM_LAYERS][NUM_EXPERTS]` | `uint64_t *g_cache_last_touch_token` (malloc) | +| `static uint64_t g_cache_last_evict_token[NUM_LAYERS][NUM_EXPERTS]` | `uint64_t *g_cache_last_evict_token` (malloc) | +| `static LayerWeightCache layer_cache[NUM_LAYERS]` | `LayerWeightCache *layer_cache` (malloc `num_layers * sizeof(...)`) | +| `id buf_kv_k[NUM_FULL_ATTN_LAYERS]` | Dynamically allocated array in MetalCtx | +| `id buf_kv_v[NUM_FULL_ATTN_LAYERS]` | Same | +| `id buf_delta_state[NUM_LINEAR_LAYERS]` | Same | +| `id buf_conv_state[NUM_LINEAR_LAYERS]` | Same | +| `id buf_multi_expert_data[MAX_K]` | Stays `MAX_K` (hardware limit, not model-specific) | + +Access pattern changes from `g_expert_freq[layer][expert]` to `g_expert_freq[layer * cfg.num_experts + expert]` (flattened 2D indexing). 
Helper macros can simplify this: + +```c +#define FREQ(l, e) g_expert_freq[(l) * cfg.num_experts + (e)] +#define CACHE_SEEN(l, e) g_cache_seen[(l) * cfg.num_experts + (e)] +``` + +### MetalCtx changes + +The `MetalCtx` struct's fixed-size arrays become pointers, allocated in `metal_setup()` after config is loaded: + +```c +typedef struct { + // ...existing fields... + id *buf_kv_k; // [num_full_attn_layers] + id *buf_kv_v; // [num_full_attn_layers] + id *buf_delta_state; // [num_linear_layers] + id *buf_conv_state; // [num_linear_layers] + // buf_multi_expert_data[MAX_K] stays fixed (MAX_K is hardware limit) +} MetalCtx; +``` + +### What remains as #define + +Only non-model-specific constants: + +- `MAX_K 8` — maximum supported experts per token (array sizing) +- `GPU_KV_SEQ 8192` — GPU KV pre-allocation (tuning parameter) +- `TENSOR_HT_SIZE 8192` — hash table size (implementation detail) +- `NUM_IO_THREADS 8` — I/O thread pool size (hardware tuning) +- Metal threadgroup sizes (256, 64, 128) — GPU hardware tuning + +### Layer type detection + +Currently uses `(i + 1) % FULL_ATTN_INTERVAL == 0`. Replaced with config-driven lookup: + +```c +// Old +int is_full = ((i + 1) % FULL_ATTN_INTERVAL == 0); + +// New +int is_full = cfg.is_full_attn[i]; +``` + +The `full_attn_index[]` and `linear_index[]` arrays map global layer index to per-type index (used for KV cache / delta-net state buffer indexing). + +### Startup sequence + +``` +main() + -> parse CLI args (--model path) + -> load_model_config(model_path) // NEW: populates cfg + -> alloc_tracking_arrays() // NEW: malloc g_expert_freq etc. + -> metal_setup() // uses cfg.* for buffer sizes + -> load_weights() // uses cfg.model_path + -> inference loop // uses cfg.* throughout +``` + +### CLI change + +The existing `--model` flag already accepts a path. No new flags needed. The only change is that `load_model_config()` is called with this path before any other initialization. 
+ +### Files changed + +| File | Change | +|------|--------| +| `metal_infer/infer.m` | Remove ~40 `#define`s, add `ModelConfig` struct + `load_model_config()` (~150 lines), convert 5 static arrays to malloc, update ~200 references from `DEFINE` to `cfg.field` | +| `metal_infer/chat.m` | If it references model constants, update to `cfg.field` (likely minimal) | +| `metal_infer/shaders.metal` | No changes (already parameterized) | + +### Validation + +`load_model_config()` prints a summary to stderr on startup: + +``` +[config] Qwen3.5-35B-A3B: 40 layers (30 linear + 10 full), 2048 hidden, 256 experts (K=8) +[config] 4-bit expert size: 1769472 bytes, group_size=64 +[config] EOS tokens: [248046, 248044], think: 248068/248069 +``` + +This makes it immediately visible which model is loaded and whether the config was parsed correctly. + +### Error handling + +- Missing `config.json` → fatal error with clear message +- Missing `tokenizer_config.json` → warning, think tokens default to -1 (disabled) +- Missing `text_config` key → fatal error +- Missing optional keys (e.g., `linear_conv_kernel_dim`) → use sensible defaults with warning +- Computed expert offsets are validated against `expert_index.json` if present From ca2dec692515339ffa323d4753d7578821125b10 Mon Sep 17 00:00:00 2001 From: Alessio Delmonti Date: Fri, 20 Mar 2026 22:24:37 +0000 Subject: [PATCH 2/7] docs: address spec review findings Add missing arrays (g_lz4_index, g_pred_experts, g_pred_count, stack VLAs), full_attn_interval fallback, thread safety invariant, MODEL_PATH_DEFAULT handling, MAX_BATCH_SLOTS coupling note, and clarify chat.m needs zero changes. 
Co-Authored-By: Claude Opus 4.6 --- .../2026-03-20-runtime-model-config-design.md | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md b/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md index b9d001d..d967f8b 100644 --- a/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md +++ b/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md @@ -97,7 +97,7 @@ Steps: 1. Resolve HF snapshot directory (walk `snapshots/` if needed, same logic as existing code) 2. Read and parse `config.json` via NSJSONSerialization 3. Extract `text_config` sub-dictionary for architecture params -4. Read `layer_types` array to build `is_full_attn[]`, `full_attn_index[]`, `linear_index[]` +4. Read `layer_types` array to build `is_full_attn[]`, `full_attn_index[]`, `linear_index[]`. Fallback: if `layer_types` is absent but `full_attn_interval` exists, compute `is_full[i] = ((i+1) % interval == 0)`. If neither exists, fatal error. 5. Read `quantization` or `quantization_config` for group_size and bits 6. Read `eos_token_id` (handles both single int and array) 7. 
Read `rope_parameters` sub-dict for rope_theta and partial_rotary_factor @@ -179,11 +179,16 @@ These static arrays use compile-time `NUM_LAYERS`/`NUM_EXPERTS` and must become | `static uint64_t g_cache_last_touch_token[NUM_LAYERS][NUM_EXPERTS]` | `uint64_t *g_cache_last_touch_token` (malloc) | | `static uint64_t g_cache_last_evict_token[NUM_LAYERS][NUM_EXPERTS]` | `uint64_t *g_cache_last_evict_token` (malloc) | | `static LayerWeightCache layer_cache[NUM_LAYERS]` | `LayerWeightCache *layer_cache` (malloc `num_layers * sizeof(...)`) | +| `LZ4IndexEntry *g_lz4_index[NUM_LAYERS]` | `LZ4IndexEntry **g_lz4_index` (malloc `num_layers` pointers) | +| `g_pred_experts[60][MAX_K]` | `int *g_pred_experts` (malloc `num_layers * MAX_K`, note: currently hardcoded to 60, a latent bug) | +| `g_pred_count[60]` | `int *g_pred_count` (malloc `num_layers`) | | `id buf_kv_k[NUM_FULL_ATTN_LAYERS]` | Dynamically allocated array in MetalCtx | | `id buf_kv_v[NUM_FULL_ATTN_LAYERS]` | Same | | `id buf_delta_state[NUM_LINEAR_LAYERS]` | Same | | `id buf_conv_state[NUM_LINEAR_LAYERS]` | Same | | `id buf_multi_expert_data[MAX_K]` | Stays `MAX_K` (hardware limit, not model-specific) | +| Stack VLAs `gpu_delta_snapshots[NUM_LINEAR_LAYERS]` (serve loop) | `malloc`'d at serve entry, freed at exit | +| Stack VLAs `gpu_conv_snapshots[NUM_LINEAR_LAYERS]` (serve loop) | Same | Access pattern changes from `g_expert_freq[layer][expert]` to `g_expert_freq[layer * cfg.num_experts + expert]` (flattened 2D indexing). Helper macros can simplify this: @@ -211,7 +216,7 @@ typedef struct { Only non-model-specific constants: -- `MAX_K 8` — maximum supported experts per token (array sizing) +- `MAX_K 8` — maximum supported experts per token (array sizing). Note: `MAX_BATCH_SLOTS` (currently 8) is coupled to `MAX_K` — keep them in sync. 
- `GPU_KV_SEQ 8192` — GPU KV pre-allocation (tuning parameter) - `TENSOR_HT_SIZE 8192` — hash table size (implementation detail) - `NUM_IO_THREADS 8` — I/O thread pool size (hardware tuning) @@ -229,7 +234,7 @@ int is_full = ((i + 1) % FULL_ATTN_INTERVAL == 0); int is_full = cfg.is_full_attn[i]; ``` -The `full_attn_index[]` and `linear_index[]` arrays map global layer index to per-type index (used for KV cache / delta-net state buffer indexing). +The `full_attn_index[]` and `linear_index[]` arrays map global layer index to per-type index (used for KV cache / delta-net state buffer indexing). All arithmetic formulas like `(layer_idx + 1) / FULL_ATTN_INTERVAL - 1` must be replaced with `cfg.full_attn_index[i]` / `cfg.linear_index[i]` lookups. This includes `build_layer_cache()`, `fused_layer_forward()`, GPU snapshot save/restore, and any other site using the interval formula. ### Startup sequence @@ -243,17 +248,21 @@ main() -> inference loop // uses cfg.* throughout ``` +### Thread safety invariant + +`cfg` is immutable after `load_model_config()` returns. It is populated once at startup before any threads are spawned, and is read-only for the entire lifetime of the process. No locking required. + ### CLI change -The existing `--model` flag already accepts a path. No new flags needed. The only change is that `load_model_config()` is called with this path before any other initialization. +The existing `--model` flag already accepts a path. No new flags needed. The only change is that `load_model_config()` is called with this path before any other initialization. If `--model` is omitted, `load_model_config()` searches for any Qwen model in `~/.cache/huggingface/hub/` or prints a clear error asking the user to provide `--model`. 
### Files changed | File | Change | |------|--------| -| `metal_infer/infer.m` | Remove ~40 `#define`s, add `ModelConfig` struct + `load_model_config()` (~150 lines), convert 5 static arrays to malloc, update ~200 references from `DEFINE` to `cfg.field` | -| `metal_infer/chat.m` | If it references model constants, update to `cfg.field` (likely minimal) | -| `metal_infer/shaders.metal` | No changes (already parameterized) | +| `metal_infer/infer.m` | Remove ~40 `#define`s, add `ModelConfig` struct + `load_model_config()` (~150 lines), convert ~13 static/stack arrays to malloc, update ~200 references from `DEFINE` to `cfg.field`, replace all `FULL_ATTN_INTERVAL` formula sites with array lookups | +| `metal_infer/chat.m` | No changes needed (pure HTTP/SSE client, references no model constants) | +| `metal_infer/shaders.metal` | No changes (already parameterized via kernel arguments) | ### Validation @@ -274,3 +283,7 @@ This makes it immediately visible which model is loaded and whether the config w - Missing `text_config` key → fatal error - Missing optional keys (e.g., `linear_conv_kernel_dim`) → use sensible defaults with warning - Computed expert offsets are validated against `expert_index.json` if present + +### Note on extract_weights.py + +`extract_weights.py` currently hardcodes model parameters (layer types, dimensions) when generating the weight manifest. This is acceptable for now — the manifest is generated once per model. A future improvement could make it config-driven too, but it's out of scope for this spec since the runtime engine is the priority. From 45cb35d89b53fe156a3ebfb8a59ec96c500bed92 Mon Sep 17 00:00:00 2001 From: Alessio Delmonti Date: Fri, 20 Mar 2026 22:41:57 +0000 Subject: [PATCH 3/7] feat: add ModelConfig struct and config loader (not yet wired up) Adds ModelConfig struct, compute_expert_offsets(), and load_model_config() that parses HuggingFace config.json + tokenizer.json via NSJSONSerialization. Old #defines still present. 
Co-Authored-By: Claude Opus 4.6 --- metal_infer/infer.m | 573 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 462 insertions(+), 111 deletions(-) diff --git a/metal_infer/infer.m b/metal_infer/infer.m index 5d2a946..9f6c74b 100644 --- a/metal_infer/infer.m +++ b/metal_infer/infer.m @@ -1,14 +1,14 @@ /* - * infer.m — Complete Qwen3.5-397B inference engine using Metal + * infer.m — Complete Qwen3.5-35B-A3B inference engine using Metal * - * Full forward pass: embedding -> 60 transformer layers -> norm -> lm_head -> sample + * Full forward pass: embedding -> 40 transformer layers -> norm -> lm_head -> sample * Non-expert weights loaded from model_weights.bin (mmap'd at startup) * Expert weights loaded from packed_experts/ per layer per token (pread) * - * Architecture: Qwen3.5-397B-A17B (MoE) - * - 60 layers: 45 linear attention (GatedDeltaNet) + 15 full attention - * - hidden_size=4096, head_dim=256, num_attention_heads=32, num_kv_heads=2 - * - 512 experts/layer, 10 active (we use K=4 for speed) + * Architecture: Qwen3.5-35B-A3B (MoE) + * - 40 layers: 30 linear attention (GatedDeltaNet) + 10 full attention + * - hidden_size=2048, head_dim=256, num_attention_heads=16, num_kv_heads=2 + * - 256 experts/layer, 8 active (K=8) * - Shared expert per layer (always active) * - Linear attention: conv1d(kernel=4) + gated delta recurrence * - Full attention: standard QKV + scaled dot product + RoPE @@ -65,33 +65,327 @@ #include #include +// ============================================================================ +// Runtime model configuration (populated from HuggingFace config.json) +// ============================================================================ + +typedef struct { + // Core architecture + int hidden_dim; + int num_layers; + int num_attn_heads; + int num_kv_heads; + int head_dim; + int vocab_size; + float rms_norm_eps; + + // MoE + int num_experts; + int num_experts_per_tok; + int moe_intermediate; + int shared_intermediate; + int group_size; + 
int bits; + + // Linear attention (GatedDeltaNet) + int linear_num_v_heads; + int linear_num_k_heads; + int linear_key_dim; + int linear_value_dim; + int conv_kernel_size; + + // Full attention + float rope_theta; + float partial_rotary; + + // Layer type map + int num_full_attn_layers; + int num_linear_layers; + bool *is_full_attn; // [num_layers] + int *full_attn_index; // [num_layers] — index into full-attn buffers, or -1 + int *linear_index; // [num_layers] — index into linear-attn buffers, or -1 + + // Derived: expert byte offsets (4-bit) + size_t expert_size_4bit; + size_t gate_w_off_4, gate_s_off_4, gate_b_off_4; + size_t up_w_off_4, up_s_off_4, up_b_off_4; + size_t down_w_off_4, down_s_off_4, down_b_off_4; + + // Derived: expert byte offsets (2-bit) + size_t expert_size_2bit; + size_t gate_w_off_2, gate_s_off_2, gate_b_off_2; + size_t up_w_off_2, up_s_off_2, up_b_off_2; + size_t down_w_off_2, down_s_off_2, down_b_off_2; + + // Derived dimensions + int linear_total_key; + int linear_total_value; + int linear_conv_dim; + int rotary_dim; + + // Special tokens + int eos_token_ids[8]; + int num_eos_tokens; + int think_start_token; + int think_end_token; + + // Context limits + int max_seq_len; + int gpu_kv_seq; + + // Model path (resolved) + char model_path[1024]; +} ModelConfig; + +static ModelConfig cfg; + +static void compute_expert_offsets(ModelConfig *c) { + int mid = c->moe_intermediate; + int hid = c->hidden_dim; + int gs = c->group_size; + + for (int b = 4; b >= 2; b -= 2) { + int vals_per_u32 = 32 / b; + // gate_proj [mid, hid] + size_t gw = (size_t)mid * ((hid + vals_per_u32 - 1) / vals_per_u32) * 4; + size_t gs_sz = (size_t)mid * ((hid + gs - 1) / gs) * 2; + size_t gb = gs_sz; + // up_proj [mid, hid] — same shape + size_t uw = gw, us = gs_sz, ub = gb; + // down_proj [hid, mid] + size_t dw = (size_t)hid * ((mid + vals_per_u32 - 1) / vals_per_u32) * 4; + size_t ds = (size_t)hid * ((mid + gs - 1) / gs) * 2; + size_t db = ds; + + size_t off = 0; + if (b 
== 4) { + c->gate_w_off_4 = off; off += gw; + c->gate_s_off_4 = off; off += gs_sz; + c->gate_b_off_4 = off; off += gb; + c->up_w_off_4 = off; off += uw; + c->up_s_off_4 = off; off += us; + c->up_b_off_4 = off; off += ub; + c->down_w_off_4 = off; off += dw; + c->down_s_off_4 = off; off += ds; + c->down_b_off_4 = off; off += db; + c->expert_size_4bit = off; + } else { + c->gate_w_off_2 = off; off += gw; + c->gate_s_off_2 = off; off += gs_sz; + c->gate_b_off_2 = off; off += gb; + c->up_w_off_2 = off; off += uw; + c->up_s_off_2 = off; off += us; + c->up_b_off_2 = off; off += ub; + c->down_w_off_2 = off; off += dw; + c->down_s_off_2 = off; off += ds; + c->down_b_off_2 = off; off += db; + c->expert_size_2bit = off; + } + } +} + +static void load_model_config(const char *model_dir) { + memset(&cfg, 0, sizeof(cfg)); + cfg.think_start_token = -1; + cfg.think_end_token = -1; + cfg.gpu_kv_seq = 8192; + + if (!model_dir || !model_dir[0]) { + fprintf(stderr, "FATAL: --model path required\n"); + exit(1); + } + + // Resolve HF snapshot directory + NSString *base = [NSString stringWithUTF8String:model_dir]; + NSString *configPath = [base stringByAppendingPathComponent:@"config.json"]; + NSFileManager *fm = [NSFileManager defaultManager]; + + if (![fm fileExistsAtPath:configPath]) { + NSString *snapDir = [base stringByAppendingPathComponent:@"snapshots"]; + if ([fm fileExistsAtPath:snapDir]) { + NSArray *snaps = [[fm contentsOfDirectoryAtPath:snapDir error:nil] + sortedArrayUsingSelector:@selector(compare:)]; + for (NSString *snap in snaps) { + NSString *candidate = [[snapDir stringByAppendingPathComponent:snap] + stringByAppendingPathComponent:@"config.json"]; + if ([fm fileExistsAtPath:candidate]) { + base = [snapDir stringByAppendingPathComponent:snap]; + configPath = candidate; + break; + } + } + } + } + + if (![fm fileExistsAtPath:configPath]) { + fprintf(stderr, "FATAL: config.json not found in %s\n", model_dir); + exit(1); + } + + strlcpy(cfg.model_path, [base UTF8String], 
sizeof(cfg.model_path)); + + // Parse config.json + NSData *data = [NSData dataWithContentsOfFile:configPath]; + NSError *jsonErr = nil; + NSDictionary *root = [NSJSONSerialization JSONObjectWithData:data options:0 error:&jsonErr]; + if (!root) { + fprintf(stderr, "FATAL: failed to parse config.json: %s\n", [[jsonErr localizedDescription] UTF8String]); + exit(1); + } + NSDictionary *tc = root[@"text_config"]; + if (!tc) { fprintf(stderr, "FATAL: config.json missing text_config\n"); exit(1); } + + cfg.hidden_dim = [tc[@"hidden_size"] intValue]; + cfg.num_layers = [tc[@"num_hidden_layers"] intValue]; + cfg.num_attn_heads = [tc[@"num_attention_heads"] intValue]; + cfg.num_kv_heads = [tc[@"num_key_value_heads"] intValue]; + cfg.head_dim = tc[@"head_dim"] ? [tc[@"head_dim"] intValue] : (cfg.hidden_dim / cfg.num_attn_heads); + cfg.vocab_size = [tc[@"vocab_size"] intValue]; + cfg.rms_norm_eps = [tc[@"rms_norm_eps"] floatValue]; + cfg.num_experts = [tc[@"num_experts"] intValue]; + cfg.num_experts_per_tok = [tc[@"num_experts_per_tok"] intValue]; + cfg.moe_intermediate = [tc[@"moe_intermediate_size"] intValue]; + cfg.shared_intermediate = [tc[@"shared_expert_intermediate_size"] intValue]; + cfg.linear_num_v_heads = [tc[@"linear_num_value_heads"] intValue]; + cfg.linear_num_k_heads = [tc[@"linear_num_key_heads"] intValue]; + cfg.linear_key_dim = tc[@"linear_key_head_dim"] ? [tc[@"linear_key_head_dim"] intValue] : 128; + cfg.linear_value_dim = tc[@"linear_value_head_dim"] ? [tc[@"linear_value_head_dim"] intValue] : 128; + cfg.conv_kernel_size = tc[@"linear_conv_kernel_dim"] ? 
[tc[@"linear_conv_kernel_dim"] intValue] : 4; + cfg.max_seq_len = [tc[@"max_position_embeddings"] intValue]; + + // Quantization + NSDictionary *qc = root[@"quantization_config"] ?: root[@"quantization"]; + if (qc) { + cfg.group_size = [qc[@"group_size"] intValue]; + cfg.bits = [qc[@"bits"] intValue]; + } else { + cfg.group_size = 64; + cfg.bits = 4; + fprintf(stderr, "[config] WARNING: no quantization_config, defaulting to 4-bit group_size=64\n"); + } + + // RoPE parameters + NSDictionary *rope = tc[@"rope_parameters"]; + if (rope) { + cfg.rope_theta = [rope[@"rope_theta"] floatValue]; + cfg.partial_rotary = [rope[@"partial_rotary_factor"] floatValue]; + } else { + cfg.rope_theta = 10000000.0f; + cfg.partial_rotary = 0.25f; + } + + // Layer types + NSArray *layerTypes = tc[@"layer_types"]; + cfg.is_full_attn = calloc(cfg.num_layers, sizeof(bool)); + cfg.full_attn_index = malloc(cfg.num_layers * sizeof(int)); + cfg.linear_index = malloc(cfg.num_layers * sizeof(int)); + + if (layerTypes && [layerTypes count] == (NSUInteger)cfg.num_layers) { + for (int i = 0; i < cfg.num_layers; i++) { + cfg.is_full_attn[i] = [layerTypes[i] isEqualToString:@"full_attention"]; + } + } else { + int interval = tc[@"full_attention_interval"] ? 
[tc[@"full_attention_interval"] intValue] : 4;
+        for (int i = 0; i < cfg.num_layers; i++) {
+            cfg.is_full_attn[i] = ((i + 1) % interval == 0);
+        }
+        fprintf(stderr, "[config] Using full_attn_interval=%d (no explicit layer_types)\n", interval);
+    }
+
+    int full_count = 0, linear_count = 0;
+    for (int i = 0; i < cfg.num_layers; i++) {
+        if (cfg.is_full_attn[i]) {
+            cfg.full_attn_index[i] = full_count++;
+            cfg.linear_index[i] = -1;
+        } else {
+            cfg.linear_index[i] = linear_count++;
+            cfg.full_attn_index[i] = -1;
+        }
+    }
+    cfg.num_full_attn_layers = full_count;
+    cfg.num_linear_layers = linear_count;
+
+    // EOS tokens (can be int or array in config.json)
+    id eosVal = root[@"eos_token_id"];
+    if ([eosVal isKindOfClass:[NSArray class]]) {
+        NSArray *arr = (NSArray *)eosVal;
+        cfg.num_eos_tokens = (int)[arr count];
+        if (cfg.num_eos_tokens > 8) cfg.num_eos_tokens = 8;
+        for (int i = 0; i < cfg.num_eos_tokens; i++)
+            cfg.eos_token_ids[i] = [arr[i] intValue];
+    } else if (eosVal) {
+        cfg.num_eos_tokens = 1;
+        cfg.eos_token_ids[0] = [eosVal intValue];
+    }
+
+    // Think tokens from tokenizer.json added_tokens
+    NSString *tokPath = [base stringByAppendingPathComponent:@"tokenizer.json"];
+    if ([fm fileExistsAtPath:tokPath]) {
+        NSData *tokData = [NSData dataWithContentsOfFile:tokPath];
+        NSDictionary *tokRoot = [NSJSONSerialization JSONObjectWithData:tokData options:0 error:nil];
+        NSArray *addedTokens = tokRoot[@"added_tokens"];
+        if (addedTokens) {
+            for (NSDictionary *tok in addedTokens) {
+                NSString *content = tok[@"content"];
+                int tid = [tok[@"id"] intValue];
+                if ([content isEqualToString:@"<think>"]) cfg.think_start_token = tid;
+                else if ([content isEqualToString:@"</think>"]) cfg.think_end_token = tid;
+            }
+        }
+    } else {
+        fprintf(stderr, "[config] WARNING: tokenizer.json not found, think tokens disabled\n");
+    }
+
+    // Derived dimensions
+    cfg.linear_total_key = cfg.linear_num_k_heads * cfg.linear_key_dim;
+    cfg.linear_total_value = cfg.linear_num_v_heads * cfg.linear_value_dim;
+
cfg.linear_conv_dim = cfg.linear_total_key * 2 + cfg.linear_total_value; + cfg.rotary_dim = (int)(cfg.head_dim * cfg.partial_rotary); + + // Expert byte offsets + compute_expert_offsets(&cfg); + + // Summary + fprintf(stderr, "[config] %d layers (%d linear + %d full), hidden=%d, heads=%d, kv_heads=%d, head_dim=%d\n", + cfg.num_layers, cfg.num_linear_layers, cfg.num_full_attn_layers, + cfg.hidden_dim, cfg.num_attn_heads, cfg.num_kv_heads, cfg.head_dim); + fprintf(stderr, "[config] %d experts (K=%d), moe_intermediate=%d, shared=%d\n", + cfg.num_experts, cfg.num_experts_per_tok, cfg.moe_intermediate, cfg.shared_intermediate); + fprintf(stderr, "[config] %d-bit quantization, group_size=%d, expert_size=%zu bytes\n", + cfg.bits, cfg.group_size, cfg.expert_size_4bit); + fprintf(stderr, "[config] EOS tokens: ["); + for (int i = 0; i < cfg.num_eos_tokens; i++) + fprintf(stderr, "%s%d", i ? ", " : "", cfg.eos_token_ids[i]); + fprintf(stderr, "], think: %d/%d\n", cfg.think_start_token, cfg.think_end_token); +} + // ============================================================================ // Model constants // ============================================================================ -#define HIDDEN_DIM 4096 -#define NUM_LAYERS 60 -#define NUM_ATTN_HEADS 32 +#define HIDDEN_DIM 2048 +#define NUM_LAYERS 40 +#define NUM_ATTN_HEADS 16 #define NUM_KV_HEADS 2 #define HEAD_DIM 256 #define VOCAB_SIZE 248320 #define RMS_NORM_EPS 1e-6f -#define NUM_EXPERTS 512 -#define NUM_EXPERTS_PER_TOK 10 -#define MOE_INTERMEDIATE 1024 -#define SHARED_INTERMEDIATE 1024 +#define NUM_EXPERTS 256 +#define NUM_EXPERTS_PER_TOK 8 +#define MOE_INTERMEDIATE 512 +#define SHARED_INTERMEDIATE 512 #define FULL_ATTN_INTERVAL 4 #define GROUP_SIZE 64 #define BITS 4 // Linear attention (GatedDeltaNet) constants -#define LINEAR_NUM_V_HEADS 64 +#define LINEAR_NUM_V_HEADS 32 #define LINEAR_NUM_K_HEADS 16 #define LINEAR_KEY_DIM 128 // head_k_dim #define LINEAR_VALUE_DIM 128 // head_v_dim #define LINEAR_TOTAL_KEY 
(LINEAR_NUM_K_HEADS * LINEAR_KEY_DIM) // 2048 -#define LINEAR_TOTAL_VALUE (LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM) // 8192 -#define LINEAR_CONV_DIM (LINEAR_TOTAL_KEY * 2 + LINEAR_TOTAL_VALUE) // 12288 +#define LINEAR_TOTAL_VALUE (LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM) // 4096 +#define LINEAR_CONV_DIM (LINEAR_TOTAL_KEY * 2 + LINEAR_TOTAL_VALUE) // 8192 #define CONV_KERNEL_SIZE 4 // Full attention constants @@ -99,23 +393,36 @@ #define PARTIAL_ROTARY 0.25f #define ROTARY_DIM (int)(HEAD_DIM * PARTIAL_ROTARY) // 64 -// Expert packed binary layout (from existing code) -#define EXPERT_SIZE 7077888 - -// 2-bit expert layout (from repack_experts_2bit.py) -#define EXPERT_SIZE_2BIT 3932160 +// Expert packed binary layout for Qwen3.5-35B-A3B (4-bit, group_size=64) +// gate_proj/up_proj: [512, 2048] -> weight [512,256] uint32 = 524288, scales [512,32] bf16 = 32768 +// down_proj: [2048, 512] -> weight [2048,64] uint32 = 524288, scales [2048,8] bf16 = 32768 +#define EXPERT_SIZE 1769472 +#define GATE_W_OFF_4 0 +#define GATE_S_OFF_4 524288 +#define GATE_B_OFF_4 557056 +#define UP_W_OFF_4 589824 +#define UP_S_OFF_4 1114112 +#define UP_B_OFF_4 1146880 +#define DOWN_W_OFF_4 1179648 +#define DOWN_S_OFF_4 1703936 +#define DOWN_B_OFF_4 1736704 + +// 2-bit expert layout (halved weight arrays, same scales/biases) +// weight arrays: 16 vals per uint32 instead of 8 +// gate/up: [512, 128] uint32 = 262144, down: [2048, 32] uint32 = 262144 +#define EXPERT_SIZE_2BIT 983040 #define GATE_W_OFF_2 0 -#define GATE_S_OFF_2 1048576 -#define GATE_B_OFF_2 1179648 -#define UP_W_OFF_2 1310720 -#define UP_S_OFF_2 2359296 -#define UP_B_OFF_2 2490368 -#define DOWN_W_OFF_2 2621440 -#define DOWN_S_OFF_2 3670016 -#define DOWN_B_OFF_2 3801088 +#define GATE_S_OFF_2 262144 +#define GATE_B_OFF_2 294912 +#define UP_W_OFF_2 327680 +#define UP_S_OFF_2 589824 +#define UP_B_OFF_2 622592 +#define DOWN_W_OFF_2 655360 +#define DOWN_S_OFF_2 917504 +#define DOWN_B_OFF_2 950272 // KV cache maximum context length -#define 
MAX_SEQ_LEN 1048576 // 1M context — only 15 full-attn layers need KV cache, ~15GB at max +#define MAX_SEQ_LEN 262144 // 256K context — only 10 full-attn layers need KV cache #define GPU_KV_SEQ 8192 // GPU KV buffer pre-allocation (grows if exceeded, falls back to CPU attn) // Special tokens @@ -124,7 +431,7 @@ #define THINK_START_TOKEN 248068 // #define THINK_END_TOKEN 248069 // -#define MODEL_PATH_DEFAULT "/Users/danielwoods/.cache/huggingface/hub/models--mlx-community--Qwen3.5-397B-A17B-4bit/snapshots/39159bd8aa74f5c8446d2b2dc584f62bb51cb0d3" +#define MODEL_PATH_DEFAULT "/Users/alexintosh/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit" // ============================================================================ // Timing helper @@ -578,6 +885,47 @@ static void build_tensor_ht(TensorManifest *m) { int num_tokens; } Vocabulary; +// GPT-2 BPE byte decoder: convert BPE Unicode chars back to raw bytes. +// In GPT-2 BPE, bytes 0x00-0xFF are mapped to Unicode codepoints: +// printable ASCII (0x21-0x7E, 0xA1-0xAC, 0xAE-0xFF) map to themselves +// everything else maps to U+0100 + offset (e.g., space 0x20 → U+0120 'Ġ') +// This function decodes a UTF-8 BPE string back to raw bytes in-place. 
// GPT-2 BPE byte decoder: convert BPE Unicode chars back to raw bytes.
// GPT-2's bytes_to_unicode maps the "printable" bytes (0x21-0x7E, 0xA1-0xAC,
// 0xAE-0xFF) to themselves and the remaining 68 bytes to U+0100+n, where n
// counts those bytes in ascending order: space 0x20 -> U+0120 'Ġ',
// newline 0x0A -> U+010A 'Ċ', but DEL 0x7F -> U+0121 (NOT U+017F).
// Decoding therefore needs the inverse table — `cp - 0x100` is only correct
// for bytes 0x00-0x20 — and self-mapped bytes 0xA1-0xFF arrive as two-byte
// UTF-8 sequences (e.g. 'ä' = C3 A4) that must collapse back to one raw byte.
// Decodes a UTF-8 BPE string back to raw bytes in-place; returns new length.
static int bpe_decode_inplace(char *s, int len) {
    // Lazily build the inverse map: codepoint (< 0x200) -> raw byte, or -1
    // for codepoints that are not GPT-2 byte mappings.
    static int byte_decoder[0x200];
    static int decoder_ready = 0;
    if (!decoder_ready) {
        for (int cp = 0; cp < 0x200; cp++) byte_decoder[cp] = -1;
        int n = 0;
        for (int b = 0; b < 256; b++) {
            int self_mapped = (b >= 0x21 && b <= 0x7E) ||
                              (b >= 0xA1 && b <= 0xAC) ||
                              (b >= 0xAE && b <= 0xFF);
            if (self_mapped) byte_decoder[b] = b;   // printable: cp == byte
            else byte_decoder[0x100 + n++] = b;     // shifted: U+0100+n, n in byte order
        }
        decoder_ready = 1;
    }
    int out = 0;
    int i = 0;
    while (i < len) {
        unsigned char c = (unsigned char)s[i];
        if (c < 0x80) {
            // ASCII byte — always self-mapped, pass through
            s[out++] = s[i++];
        } else if ((c & 0xE0) == 0xC0 && i + 1 < len) {
            // 2-byte UTF-8: U+0080 to U+07FF — all BPE byte mappings live here
            unsigned int cp = ((c & 0x1F) << 6) | ((unsigned char)s[i+1] & 0x3F);
            if (cp < 0x200 && byte_decoder[cp] >= 0) {
                s[out++] = (char)byte_decoder[cp];  // collapse to the raw byte
            } else {
                // Not a BPE byte mapping — keep UTF-8 encoding
                s[out++] = s[i];
                s[out++] = s[i+1];
            }
            i += 2;
        } else if ((c & 0xF0) == 0xE0 && i + 2 < len) {
            // 3-byte UTF-8 — never a BPE byte mapping, pass through
            s[out++] = s[i]; s[out++] = s[i+1]; s[out++] = s[i+2];
            i += 3;
        } else if ((c & 0xF8) == 0xF0 && i + 3 < len) {
            // 4-byte UTF-8 — pass through
            s[out++] = s[i]; s[out++] = s[i+1]; s[out++] = s[i+2]; s[out++] = s[i+3];
            i += 4;
        } else {
            // Malformed/truncated sequence — pass byte through unchanged
            s[out++] = s[i++];
        }
    }
    s[out] = '\0';
    return out;
}
+ v->lengths[i] = bpe_decode_inplace(v->tokens[i], byte_len); } } @@ -953,7 +1302,7 @@ static void cpu_conv1d_step( id buf_h_mid; // [HIDDEN_DIM floats] residual+oproj result id buf_sum_sq; // [1 float] for RMS norm reduction // GPU attention buffers (for full attention layers) - #define NUM_FULL_ATTN_LAYERS 15 + #define NUM_FULL_ATTN_LAYERS 10 id buf_kv_k[NUM_FULL_ATTN_LAYERS]; // K cache per full-attn layer id buf_kv_v[NUM_FULL_ATTN_LAYERS]; // V cache per full-attn layer id buf_attn_q; // [NUM_ATTN_HEADS * HEAD_DIM floats] all query heads @@ -975,18 +1324,18 @@ static void cpu_conv1d_step( id compute_decay_beta; // g_decay and beta_gate for delta-net id gated_rms_norm; // z-gated output normalization // Persistent GPU state buffers for linear attention layers - #define NUM_LINEAR_LAYERS 45 - id buf_delta_state[NUM_LINEAR_LAYERS]; // [64*128*128] float per layer - id buf_conv_state[NUM_LINEAR_LAYERS]; // [3*12288] float per layer + #define NUM_LINEAR_LAYERS 30 + id buf_delta_state[NUM_LINEAR_LAYERS]; // [32*128*128] float per layer + id buf_conv_state[NUM_LINEAR_LAYERS]; // [3*8192] float per layer // Scratch buffers for delta-net inputs/outputs - id buf_delta_q; // [2048] float - id buf_delta_k; // [2048] float - id buf_delta_v; // [8192] float - id buf_delta_g_decay; // [64] float - id buf_delta_beta; // [64] float - id buf_delta_output; // [8192] float - id buf_conv_input; // [12288] float - id buf_conv_output; // [12288] float + id buf_delta_q; // [LINEAR_TOTAL_KEY=2048] float + id buf_delta_k; // [LINEAR_TOTAL_KEY=2048] float + id buf_delta_v; // [LINEAR_TOTAL_VALUE=4096] float + id buf_delta_g_decay; // [LINEAR_NUM_V_HEADS=32] float + id buf_delta_beta; // [LINEAR_NUM_V_HEADS=32] float + id buf_delta_output; // [LINEAR_TOTAL_VALUE=4096] float + id buf_conv_input; // [LINEAR_CONV_DIM=8192] float + id buf_conv_output; // [LINEAR_CONV_DIM=8192] float } MetalCtx; static MetalCtx *g_metal = NULL; @@ -1193,26 +1542,28 @@ static void cpu_conv1d_step( // Persistent 
GPU state buffers for delta-net (linear attention layers) if (ctx->delta_net_step) { for (int i = 0; i < NUM_LINEAR_LAYERS; i++) { - ctx->buf_delta_state[i] = [ctx->device newBufferWithLength:64*128*128*sizeof(float) + ctx->buf_delta_state[i] = [ctx->device newBufferWithLength:(size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float) options:MTLResourceStorageModeShared]; - memset([ctx->buf_delta_state[i] contents], 0, 64*128*128*sizeof(float)); - ctx->buf_conv_state[i] = [ctx->device newBufferWithLength:3*12288*sizeof(float) + memset([ctx->buf_delta_state[i] contents], 0, (size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float)); + ctx->buf_conv_state[i] = [ctx->device newBufferWithLength:(CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float) options:MTLResourceStorageModeShared]; - memset([ctx->buf_conv_state[i] contents], 0, 3*12288*sizeof(float)); + memset([ctx->buf_conv_state[i] contents], 0, (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float)); } // Scratch buffers for delta-net inputs/outputs (allocated once, reused) - ctx->buf_delta_q = [ctx->device newBufferWithLength:2048*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_k = [ctx->device newBufferWithLength:2048*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_v = [ctx->device newBufferWithLength:8192*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_g_decay = [ctx->device newBufferWithLength:64*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_beta = [ctx->device newBufferWithLength:64*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_output = [ctx->device newBufferWithLength:8192*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_conv_input = [ctx->device newBufferWithLength:12288*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_conv_output = [ctx->device newBufferWithLength:12288*sizeof(float) options:MTLResourceStorageModeShared]; + 
ctx->buf_delta_q = [ctx->device newBufferWithLength:LINEAR_TOTAL_KEY*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_k = [ctx->device newBufferWithLength:LINEAR_TOTAL_KEY*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_v = [ctx->device newBufferWithLength:LINEAR_TOTAL_VALUE*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_g_decay = [ctx->device newBufferWithLength:LINEAR_NUM_V_HEADS*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_beta = [ctx->device newBufferWithLength:LINEAR_NUM_V_HEADS*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_output = [ctx->device newBufferWithLength:LINEAR_TOTAL_VALUE*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_conv_input = [ctx->device newBufferWithLength:LINEAR_CONV_DIM*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_conv_output = [ctx->device newBufferWithLength:LINEAR_CONV_DIM*sizeof(float) options:MTLResourceStorageModeShared]; + size_t state_bytes = (size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float); + size_t conv_bytes = (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float); printf("[metal] Delta-net GPU buffers: %d layers (%.1f MB state + %.1f MB scratch)\n", NUM_LINEAR_LAYERS, - NUM_LINEAR_LAYERS * (64*128*128*4 + 3*12288*4) / 1e6, - (2048+2048+8192+64+64+8192+12288+12288) * 4 / 1e6); + NUM_LINEAR_LAYERS * (state_bytes + conv_bytes) / 1e6, + (LINEAR_TOTAL_KEY*2+LINEAR_TOTAL_VALUE*2+LINEAR_NUM_V_HEADS*2+LINEAR_CONV_DIM*2) * sizeof(float) / 1e6); } // Create shared event for CPU-GPU async pipeline @@ -1228,9 +1579,9 @@ static void reset_delta_net_state(void) { if (!g_metal || !g_metal->delta_net_step) return; for (int i = 0; i < NUM_LINEAR_LAYERS; i++) { if (g_metal->buf_delta_state[i]) - memset([g_metal->buf_delta_state[i] contents], 0, 64*128*128*sizeof(float)); + memset([g_metal->buf_delta_state[i] contents], 0, 
(size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float)); if (g_metal->buf_conv_state[i]) - memset([g_metal->buf_conv_state[i] contents], 0, 3*12288*sizeof(float)); + memset([g_metal->buf_conv_state[i] contents], 0, (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float)); } } @@ -1512,9 +1863,9 @@ static void gpu_encode_expert_forward_slot( up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2; down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2; } else { - gate_w_off = 0; gate_s_off = 2097152; gate_b_off = 2228224; - up_w_off = 2359296; up_s_off = 4456448; up_b_off = 4587520; - down_w_off = 4718592; down_s_off = 6815744; down_b_off = 6946816; + gate_w_off = GATE_W_OFF_4; gate_s_off = GATE_S_OFF_4; gate_b_off = GATE_B_OFF_4; + up_w_off = UP_W_OFF_4; up_s_off = UP_S_OFF_4; up_b_off = UP_B_OFF_4; + down_w_off = DOWN_W_OFF_4; down_s_off = DOWN_S_OFF_4; down_b_off = DOWN_B_OFF_4; } id expert_pipe = g_use_2bit ? ctx->matvec_2bit : ctx->matvec_v3; @@ -1608,9 +1959,9 @@ static void gpu_encode_expert_forward_slot_buf( up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2; down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2; } else { - gate_w_off = 0; gate_s_off = 2097152; gate_b_off = 2228224; - up_w_off = 2359296; up_s_off = 4456448; up_b_off = 4587520; - down_w_off = 4718592; down_s_off = 6815744; down_b_off = 6946816; + gate_w_off = GATE_W_OFF_4; gate_s_off = GATE_S_OFF_4; gate_b_off = GATE_B_OFF_4; + up_w_off = UP_W_OFF_4; up_s_off = UP_S_OFF_4; up_b_off = UP_B_OFF_4; + down_w_off = DOWN_W_OFF_4; down_s_off = DOWN_S_OFF_4; down_b_off = DOWN_B_OFF_4; } id expert_pipe = g_use_2bit ? 
ctx->matvec_2bit : ctx->matvec_v3; @@ -1708,9 +2059,9 @@ static void gpu_encode_experts_batched( up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2; down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2; } else { - gate_w_off = 0; gate_s_off = 2097152; gate_b_off = 2228224; - up_w_off = 2359296; up_s_off = 4456448; up_b_off = 4587520; - down_w_off = 4718592; down_s_off = 6815744; down_b_off = 6946816; + gate_w_off = GATE_W_OFF_4; gate_s_off = GATE_S_OFF_4; gate_b_off = GATE_B_OFF_4; + up_w_off = UP_W_OFF_4; up_s_off = UP_S_OFF_4; up_b_off = UP_B_OFF_4; + down_w_off = DOWN_W_OFF_4; down_s_off = DOWN_S_OFF_4; down_b_off = DOWN_B_OFF_4; } id expert_pipe = g_use_2bit ? ctx->matvec_2bit : ctx->matvec_v3; @@ -1793,15 +2144,15 @@ static void gpu_encode_expert_forward( MetalCtx *ctx, id cmdbuf ) { - NSUInteger gate_w_off = 0; - NSUInteger gate_s_off = 2097152; - NSUInteger gate_b_off = 2228224; - NSUInteger up_w_off = 2359296; - NSUInteger up_s_off = 4456448; - NSUInteger up_b_off = 4587520; - NSUInteger down_w_off = 4718592; - NSUInteger down_s_off = 6815744; - NSUInteger down_b_off = 6946816; + NSUInteger gate_w_off = GATE_W_OFF_4; + NSUInteger gate_s_off = GATE_S_OFF_4; + NSUInteger gate_b_off = GATE_B_OFF_4; + NSUInteger up_w_off = UP_W_OFF_4; + NSUInteger up_s_off = UP_S_OFF_4; + NSUInteger up_b_off = UP_B_OFF_4; + NSUInteger down_w_off = DOWN_W_OFF_4; + NSUInteger down_s_off = DOWN_S_OFF_4; + NSUInteger down_b_off = DOWN_B_OFF_4; uint32_t gate_up_out = MOE_INTERMEDIATE; uint32_t gate_up_in = HIDDEN_DIM; @@ -1917,9 +2268,9 @@ static void gpu_expert_forward( up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2; down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2; } else { - gate_w_off = 0; gate_s_off = 2097152; gate_b_off = 2228224; - up_w_off = 2359296; up_s_off = 4456448; up_b_off = 4587520; - down_w_off = 4718592; down_s_off = 6815744; down_b_off = 6946816; + gate_w_off = 
GATE_W_OFF_4; gate_s_off = GATE_S_OFF_4; gate_b_off = GATE_B_OFF_4; + up_w_off = UP_W_OFF_4; up_s_off = UP_S_OFF_4; up_b_off = UP_B_OFF_4; + down_w_off = DOWN_W_OFF_4; down_s_off = DOWN_S_OFF_4; down_b_off = DOWN_B_OFF_4; } id expert_pipe = g_use_2bit ? ctx->matvec_2bit : ctx->matvec_v3; @@ -2756,14 +3107,14 @@ static void moe_forward( } uint32_t *gw = (uint32_t *)expert_data; - uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_S_OFF_2 : 2097152)); - uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_B_OFF_2 : 2228224)); - uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? UP_W_OFF_2 : 2359296)); - uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_S_OFF_2 : 4456448)); - uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_B_OFF_2 : 4587520)); - uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? DOWN_W_OFF_2 : 4718592)); - uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_S_OFF_2 : 6815744)); - uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_B_OFF_2 : 6946816)); + uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_S_OFF_2 : GATE_S_OFF_4)); + uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_B_OFF_2 : GATE_B_OFF_4)); + uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? UP_W_OFF_2 : UP_W_OFF_4)); + uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_S_OFF_2 : UP_S_OFF_4)); + uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_B_OFF_2 : UP_B_OFF_4)); + uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? DOWN_W_OFF_2 : DOWN_W_OFF_4)); + uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_S_OFF_2 : DOWN_S_OFF_4)); + uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? 
DOWN_B_OFF_2 : DOWN_B_OFF_4)); float *gate_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float)); float *up_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float)); @@ -2937,7 +3288,7 @@ static void lm_head_forward(WeightFile *wf, const float *hidden, float *logits) // Parallel I/O infrastructure for expert pread (from proven main.m pattern) // ============================================================================ -#define NUM_IO_THREADS 4 // 4 threads for K=4 experts (one per expert) +#define NUM_IO_THREADS 8 // 8 threads for K=8 experts (one per expert) typedef struct { int fd; @@ -3643,7 +3994,7 @@ static void infer_prefetch_shutdown(void) { // ============================================================================ // Per-layer weight pointer cache — built once, eliminates 40+ snprintf+lookup -// per layer per token. With 60 layers and 15 tokens = 36,000 lookups saved. +// per layer per token. With 40 layers and 15 tokens = 24,000 lookups saved. // ============================================================================ typedef struct { @@ -3936,7 +4287,7 @@ static void discard_deferred_experts(void) { // 4. GPU-side combine in CMD3 (eliminates CPU deferred_wait + combine + norm) // ============================================================================ -// Static scratch buffers — allocated once, reused across all 60 layers per token. +// Static scratch buffers — allocated once, reused across all 40 layers per token. // Eliminates ~20 malloc/free per layer = ~1200 alloc/free per token. static float *s_normed = NULL; // [HIDDEN_DIM] static float *s_residual = NULL; // [HIDDEN_DIM] @@ -5479,14 +5830,14 @@ static void fused_layer_forward( // CPU fallback offsets — use 4-bit layout (2-bit CPU path not yet implemented) uint32_t *gw = (uint32_t *)expert_data; - uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_S_OFF_2 : 2097152)); - uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? 
GATE_B_OFF_2 : 2228224)); - uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? UP_W_OFF_2 : 2359296)); - uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_S_OFF_2 : 4456448)); - uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_B_OFF_2 : 4587520)); - uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? DOWN_W_OFF_2 : 4718592)); - uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_S_OFF_2 : 6815744)); - uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_B_OFF_2 : 6946816)); + uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_S_OFF_2 : GATE_S_OFF_4)); + uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_B_OFF_2 : GATE_B_OFF_4)); + uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? UP_W_OFF_2 : UP_W_OFF_4)); + uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_S_OFF_2 : UP_S_OFF_4)); + uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_B_OFF_2 : UP_B_OFF_4)); + uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? DOWN_W_OFF_2 : DOWN_W_OFF_4)); + uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_S_OFF_2 : DOWN_S_OFF_4)); + uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_B_OFF_2 : DOWN_B_OFF_4)); float *gate_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float)); float *up_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float)); @@ -5998,7 +6349,7 @@ static void serve_loop( static uint64_t req_counter = 0; // ---- System prompt cache: prefill system prompt once at startup ---- - // Tokenize the system prompt and run it through all 60 layers. + // Tokenize the system prompt and run it through all 40 layers. // Save the resulting KV cache + linear attention state as a snapshot. // On each request, restore the snapshot instead of re-prefilling. 
fprintf(stderr, "[serve] Pre-caching system prompt...\n"); @@ -6107,12 +6458,12 @@ static void serve_loop( if (g_metal && g_metal->delta_net_step) { for (int i = 0; i < NUM_LINEAR_LAYERS; i++) { if (g_metal->buf_delta_state[i]) { - size_t sz = 64*128*128*sizeof(float); + size_t sz = (size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float); gpu_delta_snapshots[i] = malloc(sz); memcpy(gpu_delta_snapshots[i], [g_metal->buf_delta_state[i] contents], sz); } if (g_metal->buf_conv_state[i]) { - size_t sz = 3*12288*sizeof(float); + size_t sz = (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float); gpu_conv_snapshots[i] = malloc(sz); memcpy(gpu_conv_snapshots[i], [g_metal->buf_conv_state[i] contents], sz); } @@ -6156,7 +6507,7 @@ static void serve_loop( "Access-Control-Allow-Origin: *\r\n" "Connection: close\r\n" "\r\n" - "{\"status\":\"ok\",\"model\":\"qwen3.5-397b-a17b\"}\n"; + "{\"status\":\"ok\",\"model\":\"qwen3.5-35b-a3b\"}\n"; http_write_str(client_fd, resp); free(reqbuf); close(client_fd); continue; @@ -6170,7 +6521,7 @@ static void serve_loop( "Access-Control-Allow-Origin: *\r\n" "Connection: close\r\n" "\r\n" - "{\"object\":\"list\",\"data\":[{\"id\":\"qwen3.5-397b-a17b\"," + "{\"object\":\"list\",\"data\":[{\"id\":\"qwen3.5-35b-a3b\"," "\"object\":\"model\",\"owned_by\":\"local\"}]}\n"; http_write_str(client_fd, resp); free(reqbuf); close(client_fd); @@ -6281,10 +6632,10 @@ static void serve_loop( for (int i = 0; i < NUM_LINEAR_LAYERS; i++) { if (gpu_delta_snapshots[i] && g_metal->buf_delta_state[i]) memcpy([g_metal->buf_delta_state[i] contents], - gpu_delta_snapshots[i], 64*128*128*sizeof(float)); + gpu_delta_snapshots[i], (size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float)); if (gpu_conv_snapshots[i] && g_metal->buf_conv_state[i]) memcpy([g_metal->buf_conv_state[i] contents], - gpu_conv_snapshots[i], 3*12288*sizeof(float)); + gpu_conv_snapshots[i], (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float)); } } else { 
reset_delta_net_state(); @@ -6530,7 +6881,7 @@ int main(int argc, char **argv) { const char *prompt_tokens_path = NULL; const char *prompt_text = NULL; int max_tokens = 20; - int K = 4; + int K = 8; int cache_entries = 0; // default 0: trust OS page cache (38% faster than Metal LRU) int malloc_cache_entries = 0; // 0 = disabled (override with --malloc-cache) int serve_port = 0; // 0 = disabled, >0 = HTTP serve mode @@ -6648,7 +6999,7 @@ int main(int argc, char **argv) { g_expert_cache = expert_cache_new(g_metal->device, cache_entries); } - printf("=== Qwen3.5-397B-A17B Metal Inference Engine ===\n"); + printf("=== Qwen3.5-35B-A3B Metal Inference Engine ===\n"); printf("Model: %s\n", model_path); printf("Weights: %s\n", weights_path); printf("Manifest: %s\n", manifest_path); @@ -6790,7 +7141,7 @@ int main(int argc, char **argv) { snprintf(lz4_path, sizeof(lz4_path), "%s/packed_experts_lz4/layer_%02d.bin", model_path, i); int lz4_fd = open(lz4_path, O_RDONLY); if (lz4_fd >= 0) { - // Load index header (512 entries × 16 bytes = 8KB) + // Load index header (NUM_EXPERTS entries × 16 bytes) g_lz4_index[i] = malloc(NUM_EXPERTS * sizeof(LZ4IndexEntry)); ssize_t nr = pread(lz4_fd, g_lz4_index[i], NUM_EXPERTS * sizeof(LZ4IndexEntry), 0); @@ -7033,7 +7384,7 @@ int main(int argc, char **argv) { cache_telemetry_note_token(); embed_lookup(wf, next_token, hidden); - // Run 60 layers (fused: 1+K cmd buffers per layer) + // Run 40 layers (fused: 1+K cmd buffers per layer) for (int layer = 0; layer < NUM_LAYERS; layer++) { int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); fused_layer_forward(wf, layer, hidden, From a585bd080eb5cbd2a8da3eba8a7231b426c2f73a Mon Sep 17 00:00:00 2001 From: Alessio Delmonti Date: Fri, 20 Mar 2026 22:51:27 +0000 Subject: [PATCH 4/7] feat: replace all model #defines with runtime cfg.* struct fields Remove ~54 model-specific #define constants and replace ~960 occurrences with cfg.* runtime struct fields. 
Convert 13 static/ stack arrays to dynamic allocation. Parse config.json + tokenizer.json at startup via NSJSONSerialization. Expert byte offsets computed from model dimensions and quantization params. Switching models now requires only --model flag, no recompilation. Co-Authored-By: Claude Opus 4.6 --- metal_infer/infer.m | 1675 +++++++++++++++++++++---------------------- 1 file changed, 825 insertions(+), 850 deletions(-) diff --git a/metal_infer/infer.m b/metal_infer/infer.m index 9f6c74b..ef70d9d 100644 --- a/metal_infer/infer.m +++ b/metal_infer/infer.m @@ -360,78 +360,34 @@ static void load_model_config(const char *model_dir) { } // ============================================================================ -// Model constants +// Dynamic tracking arrays (allocated after config is loaded) +// Declarations here, alloc_tracking_arrays() defined after types below. // ============================================================================ -#define HIDDEN_DIM 2048 -#define NUM_LAYERS 40 -#define NUM_ATTN_HEADS 16 -#define NUM_KV_HEADS 2 -#define HEAD_DIM 256 -#define VOCAB_SIZE 248320 -#define RMS_NORM_EPS 1e-6f -#define NUM_EXPERTS 256 -#define NUM_EXPERTS_PER_TOK 8 -#define MOE_INTERMEDIATE 512 -#define SHARED_INTERMEDIATE 512 -#define FULL_ATTN_INTERVAL 4 -#define GROUP_SIZE 64 -#define BITS 4 - -// Linear attention (GatedDeltaNet) constants -#define LINEAR_NUM_V_HEADS 32 -#define LINEAR_NUM_K_HEADS 16 -#define LINEAR_KEY_DIM 128 // head_k_dim -#define LINEAR_VALUE_DIM 128 // head_v_dim -#define LINEAR_TOTAL_KEY (LINEAR_NUM_K_HEADS * LINEAR_KEY_DIM) // 2048 -#define LINEAR_TOTAL_VALUE (LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM) // 4096 -#define LINEAR_CONV_DIM (LINEAR_TOTAL_KEY * 2 + LINEAR_TOTAL_VALUE) // 8192 -#define CONV_KERNEL_SIZE 4 - -// Full attention constants -#define ROPE_THETA 10000000.0f -#define PARTIAL_ROTARY 0.25f -#define ROTARY_DIM (int)(HEAD_DIM * PARTIAL_ROTARY) // 64 - -// Expert packed binary layout for Qwen3.5-35B-A3B (4-bit, 
group_size=64) -// gate_proj/up_proj: [512, 2048] -> weight [512,256] uint32 = 524288, scales [512,32] bf16 = 32768 -// down_proj: [2048, 512] -> weight [2048,64] uint32 = 524288, scales [2048,8] bf16 = 32768 -#define EXPERT_SIZE 1769472 -#define GATE_W_OFF_4 0 -#define GATE_S_OFF_4 524288 -#define GATE_B_OFF_4 557056 -#define UP_W_OFF_4 589824 -#define UP_S_OFF_4 1114112 -#define UP_B_OFF_4 1146880 -#define DOWN_W_OFF_4 1179648 -#define DOWN_S_OFF_4 1703936 -#define DOWN_B_OFF_4 1736704 - -// 2-bit expert layout (halved weight arrays, same scales/biases) -// weight arrays: 16 vals per uint32 instead of 8 -// gate/up: [512, 128] uint32 = 262144, down: [2048, 32] uint32 = 262144 -#define EXPERT_SIZE_2BIT 983040 -#define GATE_W_OFF_2 0 -#define GATE_S_OFF_2 262144 -#define GATE_B_OFF_2 294912 -#define UP_W_OFF_2 327680 -#define UP_S_OFF_2 589824 -#define UP_B_OFF_2 622592 -#define DOWN_W_OFF_2 655360 -#define DOWN_S_OFF_2 917504 -#define DOWN_B_OFF_2 950272 - -// KV cache maximum context length -#define MAX_SEQ_LEN 262144 // 256K context — only 10 full-attn layers need KV cache -#define GPU_KV_SEQ 8192 // GPU KV buffer pre-allocation (grows if exceeded, falls back to CPU attn) - -// Special tokens -#define EOS_TOKEN_1 248046 -#define EOS_TOKEN_2 248044 -#define THINK_START_TOKEN 248068 // -#define THINK_END_TOKEN 248069 // - -#define MODEL_PATH_DEFAULT "/Users/alexintosh/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit" +static int *g_expert_freq = NULL; +static uint8_t *g_expert_seen = NULL; +static void **g_lz4_index = NULL; // actually LZ4IndexEntry**, cast at use site +static uint8_t *g_cache_seen = NULL; +static uint64_t *g_cache_last_touch_token = NULL; +static uint64_t *g_cache_last_evict_token = NULL; +static int *g_pred_experts = NULL; +static int *g_pred_count = NULL; + +// Hardware tuning constants (not model-specific) +#define GPU_KV_SEQ 8192 + +// Helper macros for flattened 2D access +#define FREQ(l, e) g_expert_freq[(l) * 
cfg.num_experts + (e)] +#define EXPERT_SEEN_BYTE(l, e) g_expert_seen[(l) * ((cfg.num_experts + 7) / 8) + ((e) >> 3)] +#define CACHE_SEEN(l, e) g_cache_seen[(l) * cfg.num_experts + (e)] +#define CACHE_TOUCH(l, e) g_cache_last_touch_token[(l) * cfg.num_experts + (e)] +#define CACHE_EVICT(l, e) g_cache_last_evict_token[(l) * cfg.num_experts + (e)] +#define PRED_EXPERT(l, k) g_pred_experts[(l) * MAX_K + (k)] +#define PRED_COUNT(l) g_pred_count[(l)] + +// Forward declaration — defined after LayerWeightCache and LZ4IndexEntry +static void alloc_tracking_arrays(void); + // ============================================================================ // Timing helper @@ -489,7 +445,6 @@ static double now_ms(void) { uint32_t raw_size; } LZ4IndexEntry; -static LZ4IndexEntry *g_lz4_index[NUM_LAYERS]; // per-layer index (NULL if not using LZ4) static void *g_lz4_comp_bufs[8]; // pre-allocated compressed read buffers (MAX_K=8) static int g_use_lz4 = 0; // auto-detected from packed_experts_lz4/ @@ -497,23 +452,21 @@ static double now_ms(void) { // Expert frequency tracking (diagnostic: --freq flag) // ============================================================================ -static int g_expert_freq[NUM_LAYERS][NUM_EXPERTS]; // activation count per (layer, expert) static int g_freq_tracking = 0; // enabled by --freq flag static int g_use_2bit = 0; // enabled by --2bit flag: use packed_experts_2bit/ + 2-bit kernel static int g_cache_telemetry_enabled = 0; // enabled by --cache-telemetry flag static int g_think_budget = 2048; // max thinking tokens before force-emitting // Tiered I/O: cold fds (F_NOCACHE) for first reads, warm fds (page cached) for repeats -static int *g_layer_fds_cold = NULL; // [NUM_LAYERS] cold fds (set in main) -static uint8_t g_expert_seen[NUM_LAYERS][NUM_EXPERTS / 8]; // bitset: seen before? 
+static int *g_layer_fds_cold = NULL; // [cfg.num_layers] cold fds (set in main) // Async pread state defined after InferPreadTask (see below) static inline int expert_is_seen(int layer, int expert) { - return (g_expert_seen[layer][expert >> 3] >> (expert & 7)) & 1; + return (EXPERT_SEEN_BYTE(layer, expert) >> (expert & 7)) & 1; } static inline void expert_mark_seen(int layer, int expert) { - g_expert_seen[layer][expert >> 3] |= (1 << (expert & 7)); + EXPERT_SEEN_BYTE(layer, expert) |= (1 << (expert & 7)); } // Pick fd for expert read. Currently: always use warm fd (OS page cache). // Tiered I/O (cold F_NOCACHE for first reads) was tested but OS page cache @@ -525,7 +478,7 @@ static inline int expert_pick_fd(int layer, int expert, int warm_fd) { // Active expert size based on quantization mode static inline size_t active_expert_size(void) { - return g_use_2bit ? EXPERT_SIZE_2BIT : EXPERT_SIZE; + return g_use_2bit ? cfg.expert_size_2bit : cfg.expert_size_4bit; } static int g_freq_total_tokens = 0; // total tokens processed while tracking @@ -545,15 +498,12 @@ static inline size_t active_expert_size(void) { } CacheTelemetry; static CacheTelemetry g_cache_telemetry = {0}; -static uint8_t g_cache_seen[NUM_LAYERS][NUM_EXPERTS]; -static uint64_t g_cache_last_touch_token[NUM_LAYERS][NUM_EXPERTS]; -static uint64_t g_cache_last_evict_token[NUM_LAYERS][NUM_EXPERTS]; static void cache_telemetry_reset(void) { memset(&g_cache_telemetry, 0, sizeof(g_cache_telemetry)); - memset(g_cache_seen, 0, sizeof(g_cache_seen)); - memset(g_cache_last_touch_token, 0, sizeof(g_cache_last_touch_token)); - memset(g_cache_last_evict_token, 0, sizeof(g_cache_last_evict_token)); + memset(g_cache_seen, 0, cfg.num_layers * cfg.num_experts * sizeof(uint8_t)); + memset(g_cache_last_touch_token, 0, cfg.num_layers * cfg.num_experts * sizeof(uint64_t)); + memset(g_cache_last_evict_token, 0, cfg.num_layers * cfg.num_experts * sizeof(uint64_t)); } static void cache_telemetry_note_token(void) { @@ -563,27 
+513,27 @@ static void cache_telemetry_note_token(void) { static void cache_telemetry_touch(int layer_idx, int expert_idx) { if (!g_cache_telemetry_enabled) return; - if (layer_idx < 0 || layer_idx >= NUM_LAYERS || expert_idx < 0 || expert_idx >= NUM_EXPERTS) return; - if (!g_cache_seen[layer_idx][expert_idx]) { - g_cache_seen[layer_idx][expert_idx] = 1; + if (layer_idx < 0 || layer_idx >= cfg.num_layers || expert_idx < 0 || expert_idx >= cfg.num_experts) return; + if (!CACHE_SEEN(layer_idx, expert_idx)) { + CACHE_SEEN(layer_idx, expert_idx) = 1; g_cache_telemetry.unique_experts_touched++; } - g_cache_last_touch_token[layer_idx][expert_idx] = g_cache_telemetry.token_clock; + CACHE_TOUCH(layer_idx, expert_idx) = g_cache_telemetry.token_clock; } static void cache_telemetry_miss(int layer_idx, int expert_idx) { if (!g_cache_telemetry_enabled) return; - if (layer_idx < 0 || layer_idx >= NUM_LAYERS || expert_idx < 0 || expert_idx >= NUM_EXPERTS) return; - if (!g_cache_seen[layer_idx][expert_idx]) { + if (layer_idx < 0 || layer_idx >= cfg.num_layers || expert_idx < 0 || expert_idx >= cfg.num_experts) return; + if (!CACHE_SEEN(layer_idx, expert_idx)) { g_cache_telemetry.cold_misses++; - g_cache_seen[layer_idx][expert_idx] = 1; + CACHE_SEEN(layer_idx, expert_idx) = 1; g_cache_telemetry.unique_experts_touched++; } else { g_cache_telemetry.eviction_misses++; uint64_t dist = 0; - if (g_cache_last_evict_token[layer_idx][expert_idx] > 0 && - g_cache_telemetry.token_clock >= g_cache_last_evict_token[layer_idx][expert_idx]) { - dist = g_cache_telemetry.token_clock - g_cache_last_evict_token[layer_idx][expert_idx]; + if (CACHE_EVICT(layer_idx, expert_idx) > 0 && + g_cache_telemetry.token_clock >= CACHE_EVICT(layer_idx, expert_idx)) { + dist = g_cache_telemetry.token_clock - CACHE_EVICT(layer_idx, expert_idx); } if (dist <= 1) g_cache_telemetry.reuse_le_1++; else if (dist <= 4) g_cache_telemetry.reuse_le_4++; @@ -593,14 +543,14 @@ static void cache_telemetry_miss(int layer_idx, int 
expert_idx) { g_cache_telemetry.reuse_distance_sum += dist; g_cache_telemetry.reuse_distance_samples++; } - g_cache_last_touch_token[layer_idx][expert_idx] = g_cache_telemetry.token_clock; + CACHE_TOUCH(layer_idx, expert_idx) = g_cache_telemetry.token_clock; } static void cache_telemetry_evict(int layer_idx, int expert_idx) { if (!g_cache_telemetry_enabled) return; - if (layer_idx < 0 || layer_idx >= NUM_LAYERS || expert_idx < 0 || expert_idx >= NUM_EXPERTS) return; + if (layer_idx < 0 || layer_idx >= cfg.num_layers || expert_idx < 0 || expert_idx >= cfg.num_experts) return; g_cache_telemetry.evictions++; - g_cache_last_evict_token[layer_idx][expert_idx] = g_cache_telemetry.token_clock; + CACHE_EVICT(layer_idx, expert_idx) = g_cache_telemetry.token_clock; } static void cache_telemetry_print(uint64_t hits, uint64_t misses) { @@ -610,8 +560,8 @@ static void cache_telemetry_print(uint64_t hits, uint64_t misses) { fprintf(stderr, "Tokens tracked: %llu\n", g_cache_telemetry.token_clock); fprintf(stderr, "Unique experts touched: %llu / %d (%.1f%%)\n", g_cache_telemetry.unique_experts_touched, - NUM_LAYERS * NUM_EXPERTS, - 100.0 * g_cache_telemetry.unique_experts_touched / (NUM_LAYERS * NUM_EXPERTS)); + cfg.num_layers * cfg.num_experts, + 100.0 * g_cache_telemetry.unique_experts_touched / (cfg.num_layers * cfg.num_experts)); fprintf(stderr, "Miss breakdown: cold %llu (%.1f%% of misses), eviction %llu (%.1f%% of misses)\n", g_cache_telemetry.cold_misses, misses > 0 ? 
100.0 * g_cache_telemetry.cold_misses / misses : 0.0, @@ -1265,53 +1215,52 @@ static void cpu_conv1d_step( id attn_values_pipe; id sigmoid_gate_pipe; // Reusable buffers for attention matmuls - id buf_input; // input vector [HIDDEN_DIM or max projection input] + id buf_input; // input vector [cfg.hidden_dim or max projection input] id buf_output; // output vector [max projection output] id wf_buf; // the mmap'd weight file as a Metal buffer // Batched matmul output slots (preallocated, reused across dispatches) id batch_out[MAX_BATCH_SLOTS]; // Reusable buffers for expert computation (avoids per-expert alloc) // Legacy single-expert buffers (kept for gpu_expert_forward compat) - id buf_expert_data; // holds one expert's packed weights (EXPERT_SIZE bytes) - id buf_expert_input; // h_post input [HIDDEN_DIM floats] - id buf_expert_gate; // gate_proj output [MOE_INTERMEDIATE floats] - id buf_expert_up; // up_proj output [MOE_INTERMEDIATE floats] - id buf_expert_act; // SwiGLU output [MOE_INTERMEDIATE floats] - id buf_expert_out; // down_proj output [HIDDEN_DIM floats] + id buf_expert_data; // holds one expert's packed weights (cfg.expert_size_4bit bytes) + id buf_expert_input; // h_post input [cfg.hidden_dim floats] + id buf_expert_gate; // gate_proj output [cfg.moe_intermediate floats] + id buf_expert_up; // up_proj output [cfg.moe_intermediate floats] + id buf_expert_act; // SwiGLU output [cfg.moe_intermediate floats] + id buf_expert_out; // down_proj output [cfg.hidden_dim floats] // Multi-expert buffers: K independent sets so all experts can be encoded // into a SINGLE command buffer (no per-expert commit+wait). // Each expert k uses slot [k]. // Double-buffered: set A (data) for GPU compute, set B (data_B) for background pread. // Gate/up/act/out only need one set (GPU uses them after pread completes). 
#define MAX_K 8 - id buf_multi_expert_data[MAX_K]; // [EXPERT_SIZE bytes] each — buffer set A - id buf_multi_expert_data_B[MAX_K]; // [EXPERT_SIZE bytes] each — buffer set B (prefetch) - id buf_multi_expert_gate[MAX_K]; // [MOE_INTERMEDIATE floats] - id buf_multi_expert_up[MAX_K]; // [MOE_INTERMEDIATE floats] - id buf_multi_expert_act[MAX_K]; // [MOE_INTERMEDIATE floats] - id buf_multi_expert_out[MAX_K]; // [HIDDEN_DIM floats] - id buf_multi_expert_input; // [HIDDEN_DIM floats] (shared, read-only during dispatch) + id buf_multi_expert_data[MAX_K]; // [cfg.expert_size_4bit bytes] each — buffer set A + id buf_multi_expert_data_B[MAX_K]; // [cfg.expert_size_4bit bytes] each — buffer set B (prefetch) + id buf_multi_expert_gate[MAX_K]; // [cfg.moe_intermediate floats] + id buf_multi_expert_up[MAX_K]; // [cfg.moe_intermediate floats] + id buf_multi_expert_act[MAX_K]; // [cfg.moe_intermediate floats] + id buf_multi_expert_out[MAX_K]; // [cfg.hidden_dim floats] + id buf_multi_expert_input; // [cfg.hidden_dim floats] (shared, read-only during dispatch) // Shared expert buffers for fused CMD2 (shared gate/up computed in CMD1, // SwiGLU + down_proj in CMD2 alongside routed experts) - id buf_shared_gate; // [SHARED_INTERMEDIATE floats] - id buf_shared_up; // [SHARED_INTERMEDIATE floats] - id buf_shared_act; // [SHARED_INTERMEDIATE floats] (SwiGLU output) - id buf_shared_out; // [HIDDEN_DIM floats] (down_proj output) + id buf_shared_gate; // [cfg.shared_intermediate floats] + id buf_shared_up; // [cfg.shared_intermediate floats] + id buf_shared_act; // [cfg.shared_intermediate floats] (SwiGLU output) + id buf_shared_out; // [cfg.hidden_dim floats] (down_proj output) // Fused o_proj+norm+routing buffers (eliminates 1 cmd buffer per layer) - id buf_residual; // [HIDDEN_DIM floats] holds residual for GPU add - id buf_h_mid; // [HIDDEN_DIM floats] residual+oproj result + id buf_residual; // [cfg.hidden_dim floats] holds residual for GPU add + id buf_h_mid; // [cfg.hidden_dim 
floats] residual+oproj result id buf_sum_sq; // [1 float] for RMS norm reduction // GPU attention buffers (for full attention layers) - #define NUM_FULL_ATTN_LAYERS 10 - id buf_kv_k[NUM_FULL_ATTN_LAYERS]; // K cache per full-attn layer - id buf_kv_v[NUM_FULL_ATTN_LAYERS]; // V cache per full-attn layer - id buf_attn_q; // [NUM_ATTN_HEADS * HEAD_DIM floats] all query heads - id buf_attn_scores; // [NUM_ATTN_HEADS * MAX_SEQ_LEN floats] all heads' scores - id buf_attn_out; // [NUM_ATTN_HEADS * HEAD_DIM floats] full attention output - id buf_attn_gate; // [NUM_ATTN_HEADS * HEAD_DIM floats] sigmoid gate + id __strong *buf_kv_k; // K cache per full-attn layer + id __strong *buf_kv_v; // V cache per full-attn layer + id buf_attn_q; // [cfg.num_attn_heads * cfg.head_dim floats] all query heads + id buf_attn_scores; // [cfg.num_attn_heads * cfg.max_seq_len floats] all heads' scores + id buf_attn_out; // [cfg.num_attn_heads * cfg.head_dim floats] full attention output + id buf_attn_gate; // [cfg.num_attn_heads * cfg.head_dim floats] sigmoid gate // CMD3 GPU-side combine buffers (weighted_sum + residual + norm on GPU) id moe_combine_residual; // fused combine kernel - id buf_moe_hidden; // [HIDDEN_DIM floats] GPU combine output (hidden state) + id buf_moe_hidden; // [cfg.hidden_dim floats] GPU combine output (hidden state) id buf_combine_params; // [10 floats] expert weights[8] + shared_gate_score + padding id buf_cmd3_sum_sq; // [1 float] for RMS norm reduction in CMD3 // Shared event for CPU-GPU synchronization (async pipeline) @@ -1324,24 +1273,28 @@ static void cpu_conv1d_step( id compute_decay_beta; // g_decay and beta_gate for delta-net id gated_rms_norm; // z-gated output normalization // Persistent GPU state buffers for linear attention layers - #define NUM_LINEAR_LAYERS 30 - id buf_delta_state[NUM_LINEAR_LAYERS]; // [32*128*128] float per layer - id buf_conv_state[NUM_LINEAR_LAYERS]; // [3*8192] float per layer + id __strong *buf_delta_state; // [v_heads*v_dim*k_dim] 
float per layer + id __strong *buf_conv_state; // [(kernel-1)*conv_dim] float per layer // Scratch buffers for delta-net inputs/outputs - id buf_delta_q; // [LINEAR_TOTAL_KEY=2048] float - id buf_delta_k; // [LINEAR_TOTAL_KEY=2048] float - id buf_delta_v; // [LINEAR_TOTAL_VALUE=4096] float - id buf_delta_g_decay; // [LINEAR_NUM_V_HEADS=32] float - id buf_delta_beta; // [LINEAR_NUM_V_HEADS=32] float - id buf_delta_output; // [LINEAR_TOTAL_VALUE=4096] float - id buf_conv_input; // [LINEAR_CONV_DIM=8192] float - id buf_conv_output; // [LINEAR_CONV_DIM=8192] float + id buf_delta_q; // [cfg.linear_total_key=2048] float + id buf_delta_k; // [cfg.linear_total_key=2048] float + id buf_delta_v; // [cfg.linear_total_value=4096] float + id buf_delta_g_decay; // [cfg.linear_num_v_heads=32] float + id buf_delta_beta; // [cfg.linear_num_v_heads=32] float + id buf_delta_output; // [cfg.linear_total_value=4096] float + id buf_conv_input; // [cfg.linear_conv_dim=8192] float + id buf_conv_output; // [cfg.linear_conv_dim=8192] float } MetalCtx; static MetalCtx *g_metal = NULL; static MetalCtx *metal_setup(void) { MetalCtx *ctx = calloc(1, sizeof(MetalCtx)); + // Allocate dynamic buffer arrays based on config + ctx->buf_kv_k = (__strong id *)calloc(cfg.num_full_attn_layers, sizeof(id)); + ctx->buf_kv_v = (__strong id *)calloc(cfg.num_full_attn_layers, sizeof(id)); + ctx->buf_delta_state = (__strong id *)calloc(cfg.num_linear_layers, sizeof(id)); + ctx->buf_conv_state = (__strong id *)calloc(cfg.num_linear_layers, sizeof(id)); ctx->device = MTLCreateSystemDefaultDevice(); if (!ctx->device) { fprintf(stderr, "ERROR: No Metal device\n"); @@ -1424,10 +1377,10 @@ static void cpu_conv1d_step( // Allocate reusable buffers (large enough for biggest projection) // Q proj output is 16384 floats, lm_head output is 248320 floats // o_proj input is 8192, linear attn out_proj input is 8192 - size_t max_out = VOCAB_SIZE * sizeof(float); // lm_head is largest - size_t max_in = LINEAR_TOTAL_VALUE * 
sizeof(float); // 8192 floats (linear_attn out_proj) - if (max_in < (size_t)(NUM_ATTN_HEADS * HEAD_DIM) * sizeof(float)) { - max_in = (size_t)(NUM_ATTN_HEADS * HEAD_DIM) * sizeof(float); // o_proj input = 8192 + size_t max_out = cfg.vocab_size * sizeof(float); // lm_head is largest + size_t max_in = cfg.linear_total_value * sizeof(float); // 8192 floats (linear_attn out_proj) + if (max_in < (size_t)(cfg.num_attn_heads * cfg.head_dim) * sizeof(float)) { + max_in = (size_t)(cfg.num_attn_heads * cfg.head_dim) * sizeof(float); // o_proj input = 8192 } ctx->buf_input = [ctx->device newBufferWithLength:max_in options:MTLResourceStorageModeShared]; ctx->buf_output = [ctx->device newBufferWithLength:max_out options:MTLResourceStorageModeShared]; @@ -1436,9 +1389,9 @@ static void cpu_conv1d_step( // q_proj = 16384 floats, qkv_proj = 12288, z_proj = 8192, o_proj = 4096 // lm_head (248320) uses buf_output directly, not batched. { - size_t slot_size = (size_t)(NUM_ATTN_HEADS * HEAD_DIM * 2) * sizeof(float); // 16384 floats - if (slot_size < (size_t)LINEAR_CONV_DIM * sizeof(float)) - slot_size = (size_t)LINEAR_CONV_DIM * sizeof(float); // 12288 floats + size_t slot_size = (size_t)(cfg.num_attn_heads * cfg.head_dim * 2) * sizeof(float); // 16384 floats + if (slot_size < (size_t)cfg.linear_conv_dim * sizeof(float)) + slot_size = (size_t)cfg.linear_conv_dim * sizeof(float); // 12288 floats for (int i = 0; i < MAX_BATCH_SLOTS; i++) { ctx->batch_out[i] = [ctx->device newBufferWithLength:slot_size options:MTLResourceStorageModeShared]; @@ -1446,25 +1399,25 @@ static void cpu_conv1d_step( } // Expert computation buffers (reused across all experts and layers) - ctx->buf_expert_data = [ctx->device newBufferWithLength:EXPERT_SIZE + ctx->buf_expert_data = [ctx->device newBufferWithLength:cfg.expert_size_4bit options:MTLResourceStorageModeShared]; - ctx->buf_expert_input = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float) + ctx->buf_expert_input = [ctx->device 
newBufferWithLength:cfg.hidden_dim * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_expert_gate = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float) + ctx->buf_expert_gate = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_expert_up = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float) + ctx->buf_expert_up = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_expert_act = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float) + ctx->buf_expert_act = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_expert_out = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float) + ctx->buf_expert_out = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float) options:MTLResourceStorageModeShared]; // Multi-expert buffers: K independent slots (double-buffered data) // Expert data buffers use 2MB-aligned backing memory for DMA efficiency. // The pread DMA controller transfers 3.6x faster with 2MB alignment vs 16KB. 
- ctx->buf_multi_expert_input = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float) + ctx->buf_multi_expert_input = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float) options:MTLResourceStorageModeShared]; - size_t expert_alloc_size = (EXPERT_SIZE + 2*1024*1024 - 1) & ~(2*1024*1024 - 1); // round up to 2MB + size_t expert_alloc_size = (cfg.expert_size_4bit + 2*1024*1024 - 1) & ~(2*1024*1024 - 1); // round up to 2MB for (int k = 0; k < MAX_K; k++) { // 2MB-aligned allocation for optimal DMA throughput void *aligned_data = NULL, *aligned_data_b = NULL; @@ -1480,36 +1433,36 @@ static void cpu_conv1d_step( length:expert_alloc_size options:MTLResourceStorageModeShared deallocator:nil]; - ctx->buf_multi_expert_gate[k] = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float) + ctx->buf_multi_expert_gate[k] = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_multi_expert_up[k] = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float) + ctx->buf_multi_expert_up[k] = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_multi_expert_act[k] = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float) + ctx->buf_multi_expert_act[k] = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_multi_expert_out[k] = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float) + ctx->buf_multi_expert_out[k] = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float) options:MTLResourceStorageModeShared]; } // Shared expert buffers (for fused CMD2) - ctx->buf_shared_gate = [ctx->device newBufferWithLength:SHARED_INTERMEDIATE * sizeof(float) + ctx->buf_shared_gate = [ctx->device newBufferWithLength:cfg.shared_intermediate * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_shared_up = [ctx->device 
newBufferWithLength:SHARED_INTERMEDIATE * sizeof(float) + ctx->buf_shared_up = [ctx->device newBufferWithLength:cfg.shared_intermediate * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_shared_act = [ctx->device newBufferWithLength:SHARED_INTERMEDIATE * sizeof(float) + ctx->buf_shared_act = [ctx->device newBufferWithLength:cfg.shared_intermediate * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_shared_out = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float) + ctx->buf_shared_out = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float) options:MTLResourceStorageModeShared]; // Fused o_proj+norm+routing buffers - ctx->buf_residual = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float) + ctx->buf_residual = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_h_mid = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float) + ctx->buf_h_mid = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float) options:MTLResourceStorageModeShared]; ctx->buf_sum_sq = [ctx->device newBufferWithLength:sizeof(float) options:MTLResourceStorageModeShared]; // CMD3 GPU-side combine buffers - ctx->buf_moe_hidden = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float) + ctx->buf_moe_hidden = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float) options:MTLResourceStorageModeShared]; ctx->buf_combine_params = [ctx->device newBufferWithLength:10 * sizeof(float) options:MTLResourceStorageModeShared]; @@ -1518,52 +1471,52 @@ static void cpu_conv1d_step( // GPU attention buffers { - size_t kv_dim = NUM_KV_HEADS * HEAD_DIM; // 512 + size_t kv_dim = cfg.num_kv_heads * cfg.head_dim; // 512 size_t kv_cache_size = GPU_KV_SEQ * kv_dim * sizeof(float); - for (int i = 0; i < NUM_FULL_ATTN_LAYERS; i++) { + for (int i = 0; i < cfg.num_full_attn_layers; i++) { ctx->buf_kv_k[i] = [ctx->device newBufferWithLength:kv_cache_size options:MTLResourceStorageModeShared]; 
ctx->buf_kv_v[i] = [ctx->device newBufferWithLength:kv_cache_size options:MTLResourceStorageModeShared]; } - ctx->buf_attn_q = [ctx->device newBufferWithLength:NUM_ATTN_HEADS * HEAD_DIM * sizeof(float) + ctx->buf_attn_q = [ctx->device newBufferWithLength:cfg.num_attn_heads * cfg.head_dim * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_attn_scores = [ctx->device newBufferWithLength:(size_t)NUM_ATTN_HEADS * GPU_KV_SEQ * sizeof(float) + ctx->buf_attn_scores = [ctx->device newBufferWithLength:(size_t)cfg.num_attn_heads * GPU_KV_SEQ * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_attn_out = [ctx->device newBufferWithLength:NUM_ATTN_HEADS * HEAD_DIM * sizeof(float) + ctx->buf_attn_out = [ctx->device newBufferWithLength:cfg.num_attn_heads * cfg.head_dim * sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_attn_gate = [ctx->device newBufferWithLength:NUM_ATTN_HEADS * HEAD_DIM * sizeof(float) + ctx->buf_attn_gate = [ctx->device newBufferWithLength:cfg.num_attn_heads * cfg.head_dim * sizeof(float) options:MTLResourceStorageModeShared]; printf("[metal] GPU attention buffers: %d KV caches (%.1f MB each), scores buf %.1f MB\n", - NUM_FULL_ATTN_LAYERS, kv_cache_size / 1e6, - (double)(NUM_ATTN_HEADS * MAX_SEQ_LEN * sizeof(float)) / 1e6); + cfg.num_full_attn_layers, kv_cache_size / 1e6, + (double)(cfg.num_attn_heads * cfg.max_seq_len * sizeof(float)) / 1e6); } // Persistent GPU state buffers for delta-net (linear attention layers) if (ctx->delta_net_step) { - for (int i = 0; i < NUM_LINEAR_LAYERS; i++) { - ctx->buf_delta_state[i] = [ctx->device newBufferWithLength:(size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float) + for (int i = 0; i < cfg.num_linear_layers; i++) { + ctx->buf_delta_state[i] = [ctx->device newBufferWithLength:(size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float) options:MTLResourceStorageModeShared]; - memset([ctx->buf_delta_state[i] contents], 0, 
(size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float)); - ctx->buf_conv_state[i] = [ctx->device newBufferWithLength:(CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float) + memset([ctx->buf_delta_state[i] contents], 0, (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float)); + ctx->buf_conv_state[i] = [ctx->device newBufferWithLength:(cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float) options:MTLResourceStorageModeShared]; - memset([ctx->buf_conv_state[i] contents], 0, (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float)); + memset([ctx->buf_conv_state[i] contents], 0, (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float)); } // Scratch buffers for delta-net inputs/outputs (allocated once, reused) - ctx->buf_delta_q = [ctx->device newBufferWithLength:LINEAR_TOTAL_KEY*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_k = [ctx->device newBufferWithLength:LINEAR_TOTAL_KEY*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_v = [ctx->device newBufferWithLength:LINEAR_TOTAL_VALUE*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_g_decay = [ctx->device newBufferWithLength:LINEAR_NUM_V_HEADS*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_beta = [ctx->device newBufferWithLength:LINEAR_NUM_V_HEADS*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_delta_output = [ctx->device newBufferWithLength:LINEAR_TOTAL_VALUE*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_conv_input = [ctx->device newBufferWithLength:LINEAR_CONV_DIM*sizeof(float) options:MTLResourceStorageModeShared]; - ctx->buf_conv_output = [ctx->device newBufferWithLength:LINEAR_CONV_DIM*sizeof(float) options:MTLResourceStorageModeShared]; - size_t state_bytes = (size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float); - size_t conv_bytes = (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float); + 
ctx->buf_delta_q = [ctx->device newBufferWithLength:cfg.linear_total_key*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_k = [ctx->device newBufferWithLength:cfg.linear_total_key*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_v = [ctx->device newBufferWithLength:cfg.linear_total_value*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_g_decay = [ctx->device newBufferWithLength:cfg.linear_num_v_heads*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_beta = [ctx->device newBufferWithLength:cfg.linear_num_v_heads*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_delta_output = [ctx->device newBufferWithLength:cfg.linear_total_value*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_conv_input = [ctx->device newBufferWithLength:cfg.linear_conv_dim*sizeof(float) options:MTLResourceStorageModeShared]; + ctx->buf_conv_output = [ctx->device newBufferWithLength:cfg.linear_conv_dim*sizeof(float) options:MTLResourceStorageModeShared]; + size_t state_bytes = (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float); + size_t conv_bytes = (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float); printf("[metal] Delta-net GPU buffers: %d layers (%.1f MB state + %.1f MB scratch)\n", - NUM_LINEAR_LAYERS, - NUM_LINEAR_LAYERS * (state_bytes + conv_bytes) / 1e6, - (LINEAR_TOTAL_KEY*2+LINEAR_TOTAL_VALUE*2+LINEAR_NUM_V_HEADS*2+LINEAR_CONV_DIM*2) * sizeof(float) / 1e6); + cfg.num_linear_layers, + cfg.num_linear_layers * (state_bytes + conv_bytes) / 1e6, + (cfg.linear_total_key*2+cfg.linear_total_value*2+cfg.linear_num_v_heads*2+cfg.linear_conv_dim*2) * sizeof(float) / 1e6); } // Create shared event for CPU-GPU async pipeline @@ -1577,11 +1530,11 @@ static void cpu_conv1d_step( // Reset delta-net and conv GPU state buffers (call at start of new generation) static void reset_delta_net_state(void) { if (!g_metal || !g_metal->delta_net_step) 
return; - for (int i = 0; i < NUM_LINEAR_LAYERS; i++) { + for (int i = 0; i < cfg.num_linear_layers; i++) { if (g_metal->buf_delta_state[i]) - memset([g_metal->buf_delta_state[i] contents], 0, (size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float)); + memset([g_metal->buf_delta_state[i] contents], 0, (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float)); if (g_metal->buf_conv_state[i]) - memset([g_metal->buf_conv_state[i] contents], 0, (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float)); + memset([g_metal->buf_conv_state[i] contents], 0, (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float)); } } @@ -1859,21 +1812,21 @@ static void gpu_encode_expert_forward_slot( NSUInteger up_w_off, up_s_off, up_b_off; NSUInteger down_w_off, down_s_off, down_b_off; if (g_use_2bit) { - gate_w_off = GATE_W_OFF_2; gate_s_off = GATE_S_OFF_2; gate_b_off = GATE_B_OFF_2; - up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2; - down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2; + gate_w_off = cfg.gate_w_off_2; gate_s_off = cfg.gate_s_off_2; gate_b_off = cfg.gate_b_off_2; + up_w_off = cfg.up_w_off_2; up_s_off = cfg.up_s_off_2; up_b_off = cfg.up_b_off_2; + down_w_off = cfg.down_w_off_2; down_s_off = cfg.down_s_off_2; down_b_off = cfg.down_b_off_2; } else { - gate_w_off = GATE_W_OFF_4; gate_s_off = GATE_S_OFF_4; gate_b_off = GATE_B_OFF_4; - up_w_off = UP_W_OFF_4; up_s_off = UP_S_OFF_4; up_b_off = UP_B_OFF_4; - down_w_off = DOWN_W_OFF_4; down_s_off = DOWN_S_OFF_4; down_b_off = DOWN_B_OFF_4; + gate_w_off = cfg.gate_w_off_4; gate_s_off = cfg.gate_s_off_4; gate_b_off = cfg.gate_b_off_4; + up_w_off = cfg.up_w_off_4; up_s_off = cfg.up_s_off_4; up_b_off = cfg.up_b_off_4; + down_w_off = cfg.down_w_off_4; down_s_off = cfg.down_s_off_4; down_b_off = cfg.down_b_off_4; } id expert_pipe = g_use_2bit ? 
ctx->matvec_2bit : ctx->matvec_v3; - uint32_t gate_up_out = MOE_INTERMEDIATE; - uint32_t gate_up_in = HIDDEN_DIM; - uint32_t down_out = HIDDEN_DIM; - uint32_t down_in = MOE_INTERMEDIATE; - uint32_t gs = GROUP_SIZE; + uint32_t gate_up_out = cfg.moe_intermediate; + uint32_t gate_up_in = cfg.hidden_dim; + uint32_t down_out = cfg.hidden_dim; + uint32_t down_in = cfg.moe_intermediate; + uint32_t gs = cfg.group_size; // gate_proj: data[k] -> gate[k] { @@ -1955,21 +1908,21 @@ static void gpu_encode_expert_forward_slot_buf( NSUInteger up_w_off, up_s_off, up_b_off; NSUInteger down_w_off, down_s_off, down_b_off; if (g_use_2bit) { - gate_w_off = GATE_W_OFF_2; gate_s_off = GATE_S_OFF_2; gate_b_off = GATE_B_OFF_2; - up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2; - down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2; + gate_w_off = cfg.gate_w_off_2; gate_s_off = cfg.gate_s_off_2; gate_b_off = cfg.gate_b_off_2; + up_w_off = cfg.up_w_off_2; up_s_off = cfg.up_s_off_2; up_b_off = cfg.up_b_off_2; + down_w_off = cfg.down_w_off_2; down_s_off = cfg.down_s_off_2; down_b_off = cfg.down_b_off_2; } else { - gate_w_off = GATE_W_OFF_4; gate_s_off = GATE_S_OFF_4; gate_b_off = GATE_B_OFF_4; - up_w_off = UP_W_OFF_4; up_s_off = UP_S_OFF_4; up_b_off = UP_B_OFF_4; - down_w_off = DOWN_W_OFF_4; down_s_off = DOWN_S_OFF_4; down_b_off = DOWN_B_OFF_4; + gate_w_off = cfg.gate_w_off_4; gate_s_off = cfg.gate_s_off_4; gate_b_off = cfg.gate_b_off_4; + up_w_off = cfg.up_w_off_4; up_s_off = cfg.up_s_off_4; up_b_off = cfg.up_b_off_4; + down_w_off = cfg.down_w_off_4; down_s_off = cfg.down_s_off_4; down_b_off = cfg.down_b_off_4; } id expert_pipe = g_use_2bit ? 
ctx->matvec_2bit : ctx->matvec_v3; - uint32_t gate_up_out = MOE_INTERMEDIATE; - uint32_t gate_up_in = HIDDEN_DIM; - uint32_t down_out = HIDDEN_DIM; - uint32_t down_in = MOE_INTERMEDIATE; - uint32_t gs = GROUP_SIZE; + uint32_t gate_up_out = cfg.moe_intermediate; + uint32_t gate_up_in = cfg.hidden_dim; + uint32_t down_out = cfg.hidden_dim; + uint32_t down_in = cfg.moe_intermediate; + uint32_t gs = cfg.group_size; // gate_proj { @@ -2055,21 +2008,21 @@ static void gpu_encode_experts_batched( NSUInteger up_w_off, up_s_off, up_b_off; NSUInteger down_w_off, down_s_off, down_b_off; if (g_use_2bit) { - gate_w_off = GATE_W_OFF_2; gate_s_off = GATE_S_OFF_2; gate_b_off = GATE_B_OFF_2; - up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2; - down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2; + gate_w_off = cfg.gate_w_off_2; gate_s_off = cfg.gate_s_off_2; gate_b_off = cfg.gate_b_off_2; + up_w_off = cfg.up_w_off_2; up_s_off = cfg.up_s_off_2; up_b_off = cfg.up_b_off_2; + down_w_off = cfg.down_w_off_2; down_s_off = cfg.down_s_off_2; down_b_off = cfg.down_b_off_2; } else { - gate_w_off = GATE_W_OFF_4; gate_s_off = GATE_S_OFF_4; gate_b_off = GATE_B_OFF_4; - up_w_off = UP_W_OFF_4; up_s_off = UP_S_OFF_4; up_b_off = UP_B_OFF_4; - down_w_off = DOWN_W_OFF_4; down_s_off = DOWN_S_OFF_4; down_b_off = DOWN_B_OFF_4; + gate_w_off = cfg.gate_w_off_4; gate_s_off = cfg.gate_s_off_4; gate_b_off = cfg.gate_b_off_4; + up_w_off = cfg.up_w_off_4; up_s_off = cfg.up_s_off_4; up_b_off = cfg.up_b_off_4; + down_w_off = cfg.down_w_off_4; down_s_off = cfg.down_s_off_4; down_b_off = cfg.down_b_off_4; } id expert_pipe = g_use_2bit ? 
ctx->matvec_2bit : ctx->matvec_v3; - uint32_t gate_up_out = MOE_INTERMEDIATE; - uint32_t gate_up_in = HIDDEN_DIM; - uint32_t down_out = HIDDEN_DIM; - uint32_t down_in = MOE_INTERMEDIATE; - uint32_t gs = GROUP_SIZE; + uint32_t gate_up_out = cfg.moe_intermediate; + uint32_t gate_up_in = cfg.hidden_dim; + uint32_t down_out = cfg.hidden_dim; + uint32_t down_in = cfg.moe_intermediate; + uint32_t gs = cfg.group_size; // 2-bit: packed_cols = in_dim/16, threadgroups = out_dim/8 // 4-bit: packed_cols = in_dim/8, threadgroups = out_dim/8 // Threadgroup count is the same (based on out_dim), kernel handles packed_cols internally. @@ -2144,21 +2097,21 @@ static void gpu_encode_expert_forward( MetalCtx *ctx, id cmdbuf ) { - NSUInteger gate_w_off = GATE_W_OFF_4; - NSUInteger gate_s_off = GATE_S_OFF_4; - NSUInteger gate_b_off = GATE_B_OFF_4; - NSUInteger up_w_off = UP_W_OFF_4; - NSUInteger up_s_off = UP_S_OFF_4; - NSUInteger up_b_off = UP_B_OFF_4; - NSUInteger down_w_off = DOWN_W_OFF_4; - NSUInteger down_s_off = DOWN_S_OFF_4; - NSUInteger down_b_off = DOWN_B_OFF_4; - - uint32_t gate_up_out = MOE_INTERMEDIATE; - uint32_t gate_up_in = HIDDEN_DIM; - uint32_t down_out = HIDDEN_DIM; - uint32_t down_in = MOE_INTERMEDIATE; - uint32_t gs = GROUP_SIZE; + NSUInteger gate_w_off = cfg.gate_w_off_4; + NSUInteger gate_s_off = cfg.gate_s_off_4; + NSUInteger gate_b_off = cfg.gate_b_off_4; + NSUInteger up_w_off = cfg.up_w_off_4; + NSUInteger up_s_off = cfg.up_s_off_4; + NSUInteger up_b_off = cfg.up_b_off_4; + NSUInteger down_w_off = cfg.down_w_off_4; + NSUInteger down_s_off = cfg.down_s_off_4; + NSUInteger down_b_off = cfg.down_b_off_4; + + uint32_t gate_up_out = cfg.moe_intermediate; + uint32_t gate_up_in = cfg.hidden_dim; + uint32_t down_out = cfg.hidden_dim; + uint32_t down_in = cfg.moe_intermediate; + uint32_t gs = cfg.group_size; // gate_proj { @@ -2254,9 +2207,9 @@ static void fast_batch_matvec( __attribute__((unused)) static void gpu_expert_forward( MetalCtx *ctx, - const void *expert_data, 
// EXPERT_SIZE bytes (may be buf_expert_data contents) - const float *h_post, // [HIDDEN_DIM] input - float *expert_out, // [HIDDEN_DIM] output + const void *expert_data, // cfg.expert_size_4bit bytes (may be buf_expert_data contents) + const float *h_post, // [cfg.hidden_dim] input + float *expert_out, // [cfg.hidden_dim] output int expert_data_already_in_buffer ) { // Expert layout offsets — select based on quantization mode @@ -2264,13 +2217,13 @@ static void gpu_expert_forward( NSUInteger up_w_off, up_s_off, up_b_off; NSUInteger down_w_off, down_s_off, down_b_off; if (g_use_2bit) { - gate_w_off = GATE_W_OFF_2; gate_s_off = GATE_S_OFF_2; gate_b_off = GATE_B_OFF_2; - up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2; - down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2; + gate_w_off = cfg.gate_w_off_2; gate_s_off = cfg.gate_s_off_2; gate_b_off = cfg.gate_b_off_2; + up_w_off = cfg.up_w_off_2; up_s_off = cfg.up_s_off_2; up_b_off = cfg.up_b_off_2; + down_w_off = cfg.down_w_off_2; down_s_off = cfg.down_s_off_2; down_b_off = cfg.down_b_off_2; } else { - gate_w_off = GATE_W_OFF_4; gate_s_off = GATE_S_OFF_4; gate_b_off = GATE_B_OFF_4; - up_w_off = UP_W_OFF_4; up_s_off = UP_S_OFF_4; up_b_off = UP_B_OFF_4; - down_w_off = DOWN_W_OFF_4; down_s_off = DOWN_S_OFF_4; down_b_off = DOWN_B_OFF_4; + gate_w_off = cfg.gate_w_off_4; gate_s_off = cfg.gate_s_off_4; gate_b_off = cfg.gate_b_off_4; + up_w_off = cfg.up_w_off_4; up_s_off = cfg.up_s_off_4; up_b_off = cfg.up_b_off_4; + down_w_off = cfg.down_w_off_4; down_s_off = cfg.down_s_off_4; down_b_off = cfg.down_b_off_4; } id expert_pipe = g_use_2bit ? 
ctx->matvec_2bit : ctx->matvec_v3; @@ -2278,13 +2231,13 @@ static void gpu_expert_forward( if (!expert_data_already_in_buffer) { memcpy([ctx->buf_expert_data contents], expert_data, active_expert_size()); } - memcpy([ctx->buf_expert_input contents], h_post, HIDDEN_DIM * sizeof(float)); + memcpy([ctx->buf_expert_input contents], h_post, cfg.hidden_dim * sizeof(float)); - uint32_t gate_up_out = MOE_INTERMEDIATE; // 1024 - uint32_t gate_up_in = HIDDEN_DIM; // 4096 - uint32_t down_out = HIDDEN_DIM; // 4096 - uint32_t down_in = MOE_INTERMEDIATE; // 1024 - uint32_t gs = GROUP_SIZE; // 64 + uint32_t gate_up_out = cfg.moe_intermediate; // 1024 + uint32_t gate_up_in = cfg.hidden_dim; // 4096 + uint32_t down_out = cfg.hidden_dim; // 4096 + uint32_t down_in = cfg.moe_intermediate; // 1024 + uint32_t gs = cfg.group_size; // 64 // Build one command buffer with all 4 dispatches: // 1. gate_proj matvec (h_post -> gate_out) @@ -2366,7 +2319,7 @@ static void gpu_expert_forward( [cmdbuf waitUntilCompleted]; // Copy result back to CPU - memcpy(expert_out, [ctx->buf_expert_out contents], HIDDEN_DIM * sizeof(float)); + memcpy(expert_out, [ctx->buf_expert_out contents], cfg.hidden_dim * sizeof(float)); } // ============================================================================ @@ -2382,7 +2335,7 @@ static void apply_rotary_emb(float *q, float *k, int pos, int num_heads, int num for (int h = 0; h < num_heads; h++) { float *qh = q + h * head_dim; for (int i = 0; i < half; i++) { - float freq = 1.0f / powf(ROPE_THETA, (float)(2 * i) / rotary_dim); + float freq = 1.0f / powf(cfg.rope_theta, (float)(2 * i) / rotary_dim); float angle = (float)pos * freq; float cos_a = cosf(angle); float sin_a = sinf(angle); @@ -2396,7 +2349,7 @@ static void apply_rotary_emb(float *q, float *k, int pos, int num_heads, int num for (int h = 0; h < num_kv_heads; h++) { float *kh = k + h * head_dim; for (int i = 0; i < half; i++) { - float freq = 1.0f / powf(ROPE_THETA, (float)(2 * i) / rotary_dim); + float 
freq = 1.0f / powf(cfg.rope_theta, (float)(2 * i) / rotary_dim); float angle = (float)pos * freq; float cos_a = cosf(angle); float sin_a = sinf(angle); @@ -2421,8 +2374,8 @@ static void apply_rotary_emb(float *q, float *k, int pos, int num_heads, int num static KVCache *kv_cache_new(void) { KVCache *c = calloc(1, sizeof(KVCache)); - c->k_cache = calloc(MAX_SEQ_LEN * NUM_KV_HEADS * HEAD_DIM, sizeof(float)); - c->v_cache = calloc(MAX_SEQ_LEN * NUM_KV_HEADS * HEAD_DIM, sizeof(float)); + c->k_cache = calloc(cfg.max_seq_len * cfg.num_kv_heads * cfg.head_dim, sizeof(float)); + c->v_cache = calloc(cfg.max_seq_len * cfg.num_kv_heads * cfg.head_dim, sizeof(float)); c->len = 0; return c; } @@ -2446,8 +2399,8 @@ static void kv_cache_free(KVCache *c) { static LinearAttnState *linear_attn_state_new(void) { LinearAttnState *s = calloc(1, sizeof(LinearAttnState)); - s->conv_state = calloc((CONV_KERNEL_SIZE - 1) * LINEAR_CONV_DIM, sizeof(float)); - s->ssm_state = calloc(LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM * LINEAR_KEY_DIM, sizeof(float)); + s->conv_state = calloc((cfg.conv_kernel_size - 1) * cfg.linear_conv_dim, sizeof(float)); + s->ssm_state = calloc(cfg.linear_num_v_heads * cfg.linear_value_dim * cfg.linear_key_dim, sizeof(float)); return s; } @@ -2475,7 +2428,7 @@ static float vec_rms(const float *v, int n) { static void full_attention_forward( WeightFile *wf, int layer_idx, - float *hidden, // [HIDDEN_DIM] in/out + float *hidden, // [cfg.hidden_dim] in/out KVCache *kv, int pos // position in sequence ) { @@ -2483,32 +2436,32 @@ static void full_attention_forward( int do_debug = 0; // set to (fa_debug_count <= N) to enable debug char name[256]; - float *normed = malloc(HIDDEN_DIM * sizeof(float)); - float *residual = malloc(HIDDEN_DIM * sizeof(float)); - cpu_vec_copy(residual, hidden, HIDDEN_DIM); + float *normed = malloc(cfg.hidden_dim * sizeof(float)); + float *residual = malloc(cfg.hidden_dim * sizeof(float)); + cpu_vec_copy(residual, hidden, cfg.hidden_dim); if (do_debug) 
{ fprintf(stderr, "[FA-DBG] layer=%d pos=%d hidden_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", - layer_idx, pos, vec_rms(hidden, HIDDEN_DIM), + layer_idx, pos, vec_rms(hidden, cfg.hidden_dim), hidden[0], hidden[1], hidden[2], hidden[3], hidden[4]); } // ---- Input LayerNorm ---- snprintf(name, sizeof(name), "model.layers.%d.input_layernorm.weight", layer_idx); uint16_t *norm_w = get_tensor_ptr(wf, name); - cpu_rms_norm(hidden, norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS); + cpu_rms_norm(hidden, norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps); if (do_debug) { fprintf(stderr, "[FA-DBG] normed_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", - vec_rms(normed, HIDDEN_DIM), normed[0], normed[1], normed[2], normed[3], normed[4]); + vec_rms(normed, cfg.hidden_dim), normed[0], normed[1], normed[2], normed[3], normed[4]); } // ---- QKV Projection ---- // CRITICAL: Q projection outputs num_heads * head_dim * 2 = 16384 // The second half is a sigmoid gate applied after attention - int q_proj_dim = NUM_ATTN_HEADS * HEAD_DIM * 2; // 32 * 256 * 2 = 16384 - int q_dim = NUM_ATTN_HEADS * HEAD_DIM; // 32 * 256 = 8192 - int kv_dim = NUM_KV_HEADS * HEAD_DIM; // 2 * 256 = 512 + int q_proj_dim = cfg.num_attn_heads * cfg.head_dim * 2; // 32 * 256 * 2 = 16384 + int q_dim = cfg.num_attn_heads * cfg.head_dim; // 32 * 256 = 8192 + int kv_dim = cfg.num_kv_heads * cfg.head_dim; // 2 * 256 = 512 float *q_proj_out = calloc(q_proj_dim, sizeof(float)); float *k = calloc(kv_dim, sizeof(float)); @@ -2539,11 +2492,11 @@ static void full_attention_forward( // Batch Q/K/V into one command buffer (3 dispatches, 1 commit) if (qw && qs && qb && kw && ks && kb && vw && vs && vb) { BatchMatvecSpec qkv_specs[3] = { - { qw, qs, qb, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0 }, - { kw, ks, kb, k, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1 }, - { vw, vs, vb, v, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2 }, + { qw, qs, qb, q_proj_out, (uint32_t)q_proj_dim, cfg.hidden_dim, cfg.group_size, 0 }, 
+ { kw, ks, kb, k, (uint32_t)kv_dim, cfg.hidden_dim, cfg.group_size, 1 }, + { vw, vs, vb, v, (uint32_t)kv_dim, cfg.hidden_dim, cfg.group_size, 2 }, }; - fast_batch_matvec(normed, HIDDEN_DIM, qkv_specs, 3); + fast_batch_matvec(normed, cfg.hidden_dim, qkv_specs, 3); } if (do_debug) { @@ -2554,10 +2507,10 @@ static void full_attention_forward( // Split q_proj_out into queries and gate float *q = calloc(q_dim, sizeof(float)); float *q_gate = calloc(q_dim, sizeof(float)); - for (int h = 0; h < NUM_ATTN_HEADS; h++) { - float *src = q_proj_out + h * (2 * HEAD_DIM); - memcpy(q + h * HEAD_DIM, src, HEAD_DIM * sizeof(float)); - memcpy(q_gate + h * HEAD_DIM, src + HEAD_DIM, HEAD_DIM * sizeof(float)); + for (int h = 0; h < cfg.num_attn_heads; h++) { + float *src = q_proj_out + h * (2 * cfg.head_dim); + memcpy(q + h * cfg.head_dim, src, cfg.head_dim * sizeof(float)); + memcpy(q_gate + h * cfg.head_dim, src + cfg.head_dim, cfg.head_dim * sizeof(float)); } free(q_proj_out); @@ -2581,24 +2534,24 @@ static void full_attention_forward( // Apply per-head Q norm if (qnorm_w) { - for (int h = 0; h < NUM_ATTN_HEADS; h++) { - float *qh = q + h * HEAD_DIM; + for (int h = 0; h < cfg.num_attn_heads; h++) { + float *qh = q + h * cfg.head_dim; float sum_sq = 0.0f; - for (int i = 0; i < HEAD_DIM; i++) sum_sq += qh[i] * qh[i]; - float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + RMS_NORM_EPS); - for (int i = 0; i < HEAD_DIM; i++) { + for (int i = 0; i < cfg.head_dim; i++) sum_sq += qh[i] * qh[i]; + float inv_rms = 1.0f / sqrtf(sum_sq / cfg.head_dim + cfg.rms_norm_eps); + for (int i = 0; i < cfg.head_dim; i++) { qh[i] = qh[i] * inv_rms * bf16_to_f32(qnorm_w[i]); } } } // Apply per-head K norm if (knorm_w) { - for (int h = 0; h < NUM_KV_HEADS; h++) { - float *kh = k + h * HEAD_DIM; + for (int h = 0; h < cfg.num_kv_heads; h++) { + float *kh = k + h * cfg.head_dim; float sum_sq = 0.0f; - for (int i = 0; i < HEAD_DIM; i++) sum_sq += kh[i] * kh[i]; - float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + 
RMS_NORM_EPS); - for (int i = 0; i < HEAD_DIM; i++) { + for (int i = 0; i < cfg.head_dim; i++) sum_sq += kh[i] * kh[i]; + float inv_rms = 1.0f / sqrtf(sum_sq / cfg.head_dim + cfg.rms_norm_eps); + for (int i = 0; i < cfg.head_dim; i++) { kh[i] = kh[i] * inv_rms * bf16_to_f32(knorm_w[i]); } } @@ -2606,7 +2559,7 @@ static void full_attention_forward( // ---- RoPE ---- - apply_rotary_emb(q, k, pos, NUM_ATTN_HEADS, NUM_KV_HEADS, HEAD_DIM, ROTARY_DIM); + apply_rotary_emb(q, k, pos, cfg.num_attn_heads, cfg.num_kv_heads, cfg.head_dim, cfg.rotary_dim); // ---- Update KV cache ---- int cache_pos = kv->len; @@ -2615,23 +2568,23 @@ static void full_attention_forward( kv->len++; // ---- Scaled dot-product attention ---- - // GQA: NUM_ATTN_HEADS=32 heads, NUM_KV_HEADS=2 kv heads + // GQA: cfg.num_attn_heads=32 heads, cfg.num_kv_heads=2 kv heads // Each group of 16 query heads shares 1 kv head - int heads_per_kv = NUM_ATTN_HEADS / NUM_KV_HEADS; - float scale = 1.0f / sqrtf((float)HEAD_DIM); + int heads_per_kv = cfg.num_attn_heads / cfg.num_kv_heads; + float scale = 1.0f / sqrtf((float)cfg.head_dim); float *attn_out = calloc(q_dim, sizeof(float)); - for (int h = 0; h < NUM_ATTN_HEADS; h++) { + for (int h = 0; h < cfg.num_attn_heads; h++) { int kv_h = h / heads_per_kv; - float *qh = q + h * HEAD_DIM; + float *qh = q + h * cfg.head_dim; // Compute attention scores for all cached positions float *scores = malloc(kv->len * sizeof(float)); for (int p = 0; p < kv->len; p++) { - float *kp = kv->k_cache + p * kv_dim + kv_h * HEAD_DIM; + float *kp = kv->k_cache + p * kv_dim + kv_h * cfg.head_dim; float dot = 0.0f; - for (int d = 0; d < HEAD_DIM; d++) { + for (int d = 0; d < cfg.head_dim; d++) { dot += qh[d] * kp[d]; } scores[p] = dot * scale; @@ -2641,10 +2594,10 @@ static void full_attention_forward( cpu_softmax(scores, kv->len); // Weighted sum of values - float *oh = attn_out + h * HEAD_DIM; + float *oh = attn_out + h * cfg.head_dim; for (int p = 0; p < kv->len; p++) { - float *vp = 
kv->v_cache + p * kv_dim + kv_h * HEAD_DIM; - for (int d = 0; d < HEAD_DIM; d++) { + float *vp = kv->v_cache + p * kv_dim + kv_h * cfg.head_dim; + for (int d = 0; d < cfg.head_dim; d++) { oh[d] += scores[p] * vp[d]; } } @@ -2661,14 +2614,14 @@ static void full_attention_forward( } // ---- Output projection ---- - float *attn_projected = calloc(HIDDEN_DIM, sizeof(float)); + float *attn_projected = calloc(cfg.hidden_dim, sizeof(float)); snprintf(name, sizeof(name), "model.layers.%d.self_attn.o_proj.weight", layer_idx); uint32_t *ow = get_tensor_ptr(wf, name); snprintf(name, sizeof(name), "model.layers.%d.self_attn.o_proj.scales", layer_idx); uint16_t *os_ptr = get_tensor_ptr(wf, name); snprintf(name, sizeof(name), "model.layers.%d.self_attn.o_proj.biases", layer_idx); uint16_t *ob = get_tensor_ptr(wf, name); - if (ow && os_ptr && ob) fast_dequant_matvec(ow, os_ptr, ob, attn_out, attn_projected, HIDDEN_DIM, q_dim, GROUP_SIZE); + if (ow && os_ptr && ob) fast_dequant_matvec(ow, os_ptr, ob, attn_out, attn_projected, cfg.hidden_dim, q_dim, cfg.group_size); if (do_debug) { fprintf(stderr, "[FA-DBG] attn_out_rms=%.6f o_proj first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", @@ -2677,13 +2630,13 @@ static void full_attention_forward( } // ---- Residual connection ---- - for (int i = 0; i < HIDDEN_DIM; i++) { + for (int i = 0; i < cfg.hidden_dim; i++) { hidden[i] = residual[i] + attn_projected[i]; } if (do_debug) { fprintf(stderr, "[FA-DBG] AFTER layer=%d hidden_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", - layer_idx, vec_rms(hidden, HIDDEN_DIM), + layer_idx, vec_rms(hidden, cfg.hidden_dim), hidden[0], hidden[1], hidden[2], hidden[3], hidden[4]); } @@ -2729,7 +2682,7 @@ static void cpu_rms_norm_gated(const float *x, const float *z, const uint16_t *w static void linear_attention_forward( WeightFile *wf, int layer_idx, - float *hidden, // [HIDDEN_DIM] in/out + float *hidden, // [cfg.hidden_dim] in/out LinearAttnState *state ) { // If bypass is enabled, just pass through (identity) @@ 
-2744,27 +2697,27 @@ static void linear_attention_forward( if (la_debug) { fprintf(stderr, "[LA-DBG] layer=%d hidden_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", - layer_idx, vec_rms(hidden, HIDDEN_DIM), + layer_idx, vec_rms(hidden, cfg.hidden_dim), hidden[0], hidden[1], hidden[2], hidden[3], hidden[4]); } char name[256]; - float *normed = malloc(HIDDEN_DIM * sizeof(float)); - float *residual = malloc(HIDDEN_DIM * sizeof(float)); - cpu_vec_copy(residual, hidden, HIDDEN_DIM); + float *normed = malloc(cfg.hidden_dim * sizeof(float)); + float *residual = malloc(cfg.hidden_dim * sizeof(float)); + cpu_vec_copy(residual, hidden, cfg.hidden_dim); // ---- Input LayerNorm ---- snprintf(name, sizeof(name), "model.layers.%d.input_layernorm.weight", layer_idx); uint16_t *norm_w = get_tensor_ptr(wf, name); - cpu_rms_norm(hidden, norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS); + cpu_rms_norm(hidden, norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps); // ---- Batch QKV + Z + B + A projections (4 matmuls, 1 command buffer) ---- - int qkv_dim = LINEAR_CONV_DIM; // 12288 + int qkv_dim = cfg.linear_conv_dim; // 12288 float *qkv = calloc(qkv_dim, sizeof(float)); - int z_dim = LINEAR_TOTAL_VALUE; // 8192 + int z_dim = cfg.linear_total_value; // 8192 float *z = calloc(z_dim, sizeof(float)); - float *beta = calloc(LINEAR_NUM_V_HEADS, sizeof(float)); - float *alpha = calloc(LINEAR_NUM_V_HEADS, sizeof(float)); + float *beta = calloc(cfg.linear_num_v_heads, sizeof(float)); + float *alpha = calloc(cfg.linear_num_v_heads, sizeof(float)); snprintf(name, sizeof(name), "model.layers.%d.linear_attn.in_proj_qkv.weight", layer_idx); uint32_t *qkv_w = get_tensor_ptr(wf, name); @@ -2797,12 +2750,12 @@ static void linear_attention_forward( if (qkv_w && qkv_s && qkv_b && z_w && z_s && z_b && b_w && b_s && b_b && a_w && a_s && a_b) { BatchMatvecSpec la_specs[4] = { - { qkv_w, qkv_s, qkv_b, qkv, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0 }, - { z_w, z_s, z_b, z, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1 
}, - { b_w, b_s, b_b, beta, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2 }, - { a_w, a_s, a_b, alpha, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3 }, + { qkv_w, qkv_s, qkv_b, qkv, (uint32_t)qkv_dim, cfg.hidden_dim, cfg.group_size, 0 }, + { z_w, z_s, z_b, z, (uint32_t)z_dim, cfg.hidden_dim, cfg.group_size, 1 }, + { b_w, b_s, b_b, beta, (uint32_t)cfg.linear_num_v_heads, cfg.hidden_dim, cfg.group_size, 2 }, + { a_w, a_s, a_b, alpha, (uint32_t)cfg.linear_num_v_heads, cfg.hidden_dim, cfg.group_size, 3 }, }; - fast_batch_matvec(normed, HIDDEN_DIM, la_specs, 4); + fast_batch_matvec(normed, cfg.hidden_dim, la_specs, 4); } // ---- Conv1d step ---- @@ -2813,22 +2766,22 @@ static void linear_attention_forward( float *conv_out = calloc(qkv_dim, sizeof(float)); if (conv_w) { cpu_conv1d_step(state->conv_state, qkv, conv_w, conv_out, - qkv_dim, CONV_KERNEL_SIZE); + qkv_dim, cfg.conv_kernel_size); } // Update conv state: shift left, append new input memmove(state->conv_state, state->conv_state + qkv_dim, - (CONV_KERNEL_SIZE - 2) * qkv_dim * sizeof(float)); - memcpy(state->conv_state + (CONV_KERNEL_SIZE - 2) * qkv_dim, qkv, + (cfg.conv_kernel_size - 2) * qkv_dim * sizeof(float)); + memcpy(state->conv_state + (cfg.conv_kernel_size - 2) * qkv_dim, qkv, qkv_dim * sizeof(float)); // ---- Split conv_out into q, k, v ---- // q: [num_k_heads * head_k_dim] = [2048] // k: [num_k_heads * head_k_dim] = [2048] // v: [num_v_heads * head_v_dim] = [8192] - float *lin_q = conv_out; // first LINEAR_TOTAL_KEY elements - float *lin_k = conv_out + LINEAR_TOTAL_KEY; // next LINEAR_TOTAL_KEY - float *lin_v = conv_out + 2 * LINEAR_TOTAL_KEY; // rest = LINEAR_TOTAL_VALUE + float *lin_q = conv_out; // first cfg.linear_total_key elements + float *lin_k = conv_out + cfg.linear_total_key; // next cfg.linear_total_key + float *lin_v = conv_out + 2 * cfg.linear_total_key; // rest = cfg.linear_total_value // ---- RMS normalize q and k (bare, no weights) ---- // q: scale = key_dim^(-0.5), 
normalize per head then scale by key_dim^(-1.0) @@ -2836,18 +2789,18 @@ static void linear_attention_forward( // inv_scale = k.shape[-1] ** -0.5 = head_k_dim^(-0.5) = 128^(-0.5) // q = (inv_scale**2) * rms_norm(q) = (1/128) * rms_norm(q) // k = inv_scale * rms_norm(k) = (1/sqrt(128)) * rms_norm(k) - float inv_scale = 1.0f / sqrtf((float)LINEAR_KEY_DIM); + float inv_scale = 1.0f / sqrtf((float)cfg.linear_key_dim); - for (int h = 0; h < LINEAR_NUM_K_HEADS; h++) { - float *qh = lin_q + h * LINEAR_KEY_DIM; - cpu_rms_norm_bare(qh, qh, LINEAR_KEY_DIM, 1e-6f); + for (int h = 0; h < cfg.linear_num_k_heads; h++) { + float *qh = lin_q + h * cfg.linear_key_dim; + cpu_rms_norm_bare(qh, qh, cfg.linear_key_dim, 1e-6f); float q_scale = inv_scale * inv_scale; // inv_scale^2 = 1/head_k_dim - for (int d = 0; d < LINEAR_KEY_DIM; d++) qh[d] *= q_scale; + for (int d = 0; d < cfg.linear_key_dim; d++) qh[d] *= q_scale; } - for (int h = 0; h < LINEAR_NUM_K_HEADS; h++) { - float *kh = lin_k + h * LINEAR_KEY_DIM; - cpu_rms_norm_bare(kh, kh, LINEAR_KEY_DIM, 1e-6f); - for (int d = 0; d < LINEAR_KEY_DIM; d++) kh[d] *= inv_scale; + for (int h = 0; h < cfg.linear_num_k_heads; h++) { + float *kh = lin_k + h * cfg.linear_key_dim; + cpu_rms_norm_bare(kh, kh, cfg.linear_key_dim, 1e-6f); + for (int d = 0; d < cfg.linear_key_dim; d++) kh[d] *= inv_scale; } // ---- Gated delta net recurrence ---- @@ -2867,14 +2820,14 @@ static void linear_attention_forward( snprintf(name, sizeof(name), "model.layers.%d.linear_attn.dt_bias", layer_idx); uint16_t *dt_bias_bf16 = get_tensor_ptr(wf, name); - float *out_values = calloc(LINEAR_TOTAL_VALUE, sizeof(float)); // [num_v_heads * head_v_dim] + float *out_values = calloc(cfg.linear_total_value, sizeof(float)); // [num_v_heads * head_v_dim] - int k_heads_per_v = LINEAR_NUM_V_HEADS / LINEAR_NUM_K_HEADS; // 64/16 = 4 + int k_heads_per_v = cfg.linear_num_v_heads / cfg.linear_num_k_heads; // 64/16 = 4 // Precompute per-head decay (g) and beta - float 
g_decay[LINEAR_NUM_V_HEADS]; - float beta_gate[LINEAR_NUM_V_HEADS]; - for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) { + float g_decay[cfg.linear_num_v_heads]; + float beta_gate[cfg.linear_num_v_heads]; + for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) { // g = exp(-exp(A_log) * softplus(a + dt_bias)) float a_val = alpha[vh]; float dt_b = dt_bias_bf16 ? bf16_to_f32(dt_bias_bf16[vh]) : 0.0f; @@ -2886,45 +2839,45 @@ static void linear_attention_forward( beta_gate[vh] = cpu_sigmoid(beta[vh]); } - for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) { + for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) { int kh = vh / k_heads_per_v; // which k head this v head maps to float g = g_decay[vh]; float b_gate = beta_gate[vh]; // state is [head_v_dim, head_k_dim] - float *S = state->ssm_state + vh * LINEAR_VALUE_DIM * LINEAR_KEY_DIM; - float *v_h = lin_v + vh * LINEAR_VALUE_DIM; - float *k_h = lin_k + kh * LINEAR_KEY_DIM; + float *S = state->ssm_state + vh * cfg.linear_value_dim * cfg.linear_key_dim; + float *v_h = lin_v + vh * cfg.linear_value_dim; + float *k_h = lin_k + kh * cfg.linear_key_dim; // Step 1: Decay state - for (int vi = 0; vi < LINEAR_VALUE_DIM; vi++) { - for (int ki = 0; ki < LINEAR_KEY_DIM; ki++) { - S[vi * LINEAR_KEY_DIM + ki] *= g; + for (int vi = 0; vi < cfg.linear_value_dim; vi++) { + for (int ki = 0; ki < cfg.linear_key_dim; ki++) { + S[vi * cfg.linear_key_dim + ki] *= g; } } // Step 2: Compute kv_mem[vi] = sum_ki(S[vi,ki] * k[ki]) // Then delta[vi] = (v[vi] - kv_mem[vi]) * beta // Then state[vi,ki] += k[ki] * delta[vi] - for (int vi = 0; vi < LINEAR_VALUE_DIM; vi++) { + for (int vi = 0; vi < cfg.linear_value_dim; vi++) { float kv_mem = 0.0f; - for (int ki = 0; ki < LINEAR_KEY_DIM; ki++) { - kv_mem += S[vi * LINEAR_KEY_DIM + ki] * k_h[ki]; + for (int ki = 0; ki < cfg.linear_key_dim; ki++) { + kv_mem += S[vi * cfg.linear_key_dim + ki] * k_h[ki]; } float delta = (v_h[vi] - kv_mem) * b_gate; - for (int ki = 0; ki < LINEAR_KEY_DIM; ki++) { - S[vi * 
LINEAR_KEY_DIM + ki] += k_h[ki] * delta; + for (int ki = 0; ki < cfg.linear_key_dim; ki++) { + S[vi * cfg.linear_key_dim + ki] += k_h[ki] * delta; } } // Step 3: Output: y[vi] = sum_ki(S[vi,ki] * q[ki]) - float *q_h = lin_q + kh * LINEAR_KEY_DIM; - float *o_h = out_values + vh * LINEAR_VALUE_DIM; - for (int vi = 0; vi < LINEAR_VALUE_DIM; vi++) { + float *q_h = lin_q + kh * cfg.linear_key_dim; + float *o_h = out_values + vh * cfg.linear_value_dim; + for (int vi = 0; vi < cfg.linear_value_dim; vi++) { float sum = 0.0f; - for (int ki = 0; ki < LINEAR_KEY_DIM; ki++) { - sum += S[vi * LINEAR_KEY_DIM + ki] * q_h[ki]; + for (int ki = 0; ki < cfg.linear_key_dim; ki++) { + sum += S[vi * cfg.linear_key_dim + ki] * q_h[ki]; } o_h[vi] = sum; } @@ -2934,20 +2887,20 @@ static void linear_attention_forward( snprintf(name, sizeof(name), "model.layers.%d.linear_attn.norm.weight", layer_idx); uint16_t *gated_norm_w = get_tensor_ptr(wf, name); - float *gated_out = calloc(LINEAR_TOTAL_VALUE, sizeof(float)); - for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) { - float *oh = out_values + vh * LINEAR_VALUE_DIM; - float *zh = z + vh * LINEAR_VALUE_DIM; - float *gh = gated_out + vh * LINEAR_VALUE_DIM; + float *gated_out = calloc(cfg.linear_total_value, sizeof(float)); + for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) { + float *oh = out_values + vh * cfg.linear_value_dim; + float *zh = z + vh * cfg.linear_value_dim; + float *gh = gated_out + vh * cfg.linear_value_dim; if (gated_norm_w) { - cpu_rms_norm_gated(oh, zh, gated_norm_w, gh, LINEAR_VALUE_DIM, RMS_NORM_EPS); + cpu_rms_norm_gated(oh, zh, gated_norm_w, gh, cfg.linear_value_dim, cfg.rms_norm_eps); } else { - memcpy(gh, oh, LINEAR_VALUE_DIM * sizeof(float)); + memcpy(gh, oh, cfg.linear_value_dim * sizeof(float)); } } // ---- Output projection: [value_dim=8192] -> [hidden_dim=4096] ---- - float *attn_out = calloc(HIDDEN_DIM, sizeof(float)); + float *attn_out = calloc(cfg.hidden_dim, sizeof(float)); snprintf(name, sizeof(name), 
"model.layers.%d.linear_attn.out_proj.weight", layer_idx); uint32_t *out_w = get_tensor_ptr(wf, name); snprintf(name, sizeof(name), "model.layers.%d.linear_attn.out_proj.scales", layer_idx); @@ -2955,20 +2908,20 @@ static void linear_attention_forward( snprintf(name, sizeof(name), "model.layers.%d.linear_attn.out_proj.biases", layer_idx); uint16_t *out_b = get_tensor_ptr(wf, name); if (out_w && out_s && out_b) { - fast_dequant_matvec(out_w, out_s, out_b, gated_out, attn_out, HIDDEN_DIM, - LINEAR_TOTAL_VALUE, GROUP_SIZE); + fast_dequant_matvec(out_w, out_s, out_b, gated_out, attn_out, cfg.hidden_dim, + cfg.linear_total_value, cfg.group_size); } // ---- Residual ---- - for (int i = 0; i < HIDDEN_DIM; i++) { + for (int i = 0; i < cfg.hidden_dim; i++) { hidden[i] = residual[i] + attn_out[i]; } if (la_debug) { fprintf(stderr, "[LA-DBG] AFTER layer=%d out_proj_rms=%.6f gated_rms=%.6f hidden_rms=%.6f\n", - layer_idx, vec_rms(attn_out, HIDDEN_DIM), - vec_rms(gated_out, LINEAR_TOTAL_VALUE), - vec_rms(hidden, HIDDEN_DIM)); + layer_idx, vec_rms(attn_out, cfg.hidden_dim), + vec_rms(gated_out, cfg.linear_total_value), + vec_rms(hidden, cfg.hidden_dim)); } free(normed); @@ -2993,7 +2946,7 @@ static void linear_attention_forward( static void moe_forward( WeightFile *wf, int layer_idx, - float *hidden, // [HIDDEN_DIM] in/out + float *hidden, // [cfg.hidden_dim] in/out const char *model_path __attribute__((unused)), int K, // number of active experts (e.g. 
4) int packed_fd // fd for this layer's packed expert file (-1 if not available) @@ -3003,19 +2956,19 @@ static void moe_forward( int moe_dump = 0; char name[256]; - float *h_post = malloc(HIDDEN_DIM * sizeof(float)); - float *h_mid = malloc(HIDDEN_DIM * sizeof(float)); - cpu_vec_copy(h_mid, hidden, HIDDEN_DIM); + float *h_post = malloc(cfg.hidden_dim * sizeof(float)); + float *h_mid = malloc(cfg.hidden_dim * sizeof(float)); + cpu_vec_copy(h_mid, hidden, cfg.hidden_dim); // ---- Post-attention LayerNorm ---- snprintf(name, sizeof(name), "model.layers.%d.post_attention_layernorm.weight", layer_idx); uint16_t *norm_w = get_tensor_ptr(wf, name); - cpu_rms_norm(hidden, norm_w, h_post, HIDDEN_DIM, RMS_NORM_EPS); + cpu_rms_norm(hidden, norm_w, h_post, cfg.hidden_dim, cfg.rms_norm_eps); // ---- Batch routing gate + shared expert gate/up + shared_expert_gate (4 matmuls, 1 commit) ---- - float *gate_scores = calloc(NUM_EXPERTS, sizeof(float)); - float *shared_gate = calloc(SHARED_INTERMEDIATE, sizeof(float)); - float *shared_up = calloc(SHARED_INTERMEDIATE, sizeof(float)); + float *gate_scores = calloc(cfg.num_experts, sizeof(float)); + float *shared_gate = calloc(cfg.shared_intermediate, sizeof(float)); + float *shared_up = calloc(cfg.shared_intermediate, sizeof(float)); float shared_gate_score = 0.0f; snprintf(name, sizeof(name), "model.layers.%d.mlp.gate.weight", layer_idx); @@ -3050,21 +3003,21 @@ static void moe_forward( if (gate_w && gate_s && gate_b && sgw && sgs && sgb && suw && sus && sub && seg_w && seg_s && seg_b) { BatchMatvecSpec moe_specs[4] = { - { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 }, - { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 }, - { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 }, - { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 }, + { gate_w, gate_s, gate_b, gate_scores, (uint32_t)cfg.num_experts, 
cfg.hidden_dim, cfg.group_size, 0 }, + { sgw, sgs, sgb, shared_gate, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 1 }, + { suw, sus, sub, shared_up, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 2 }, + { seg_w, seg_s, seg_b, &shared_gate_score, 1, cfg.hidden_dim, cfg.group_size, 3 }, }; - fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4); + fast_batch_matvec(h_post, cfg.hidden_dim, moe_specs, 4); } // Softmax routing scores - cpu_softmax(gate_scores, NUM_EXPERTS); + cpu_softmax(gate_scores, cfg.num_experts); // Top-K expert selection int expert_indices[64]; float expert_weights[64]; - cpu_topk(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights); + cpu_topk(gate_scores, cfg.num_experts, K, expert_indices, expert_weights); cpu_normalize_weights(expert_weights, K); if (moe_dump) { @@ -3074,10 +3027,10 @@ static void moe_forward( } // ---- Routed expert computation ---- - float *moe_out = calloc(HIDDEN_DIM, sizeof(float)); + float *moe_out = calloc(cfg.hidden_dim, sizeof(float)); if (packed_fd >= 0) { - float *expert_out = malloc(HIDDEN_DIM * sizeof(float)); + float *expert_out = malloc(cfg.hidden_dim * sizeof(float)); size_t esz = active_expert_size(); for (int k = 0; k < K; k++) { @@ -3107,26 +3060,26 @@ static void moe_forward( } uint32_t *gw = (uint32_t *)expert_data; - uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_S_OFF_2 : GATE_S_OFF_4)); - uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_B_OFF_2 : GATE_B_OFF_4)); - uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? UP_W_OFF_2 : UP_W_OFF_4)); - uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_S_OFF_2 : UP_S_OFF_4)); - uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_B_OFF_2 : UP_B_OFF_4)); - uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? DOWN_W_OFF_2 : DOWN_W_OFF_4)); - uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? 
DOWN_S_OFF_2 : DOWN_S_OFF_4)); - uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_B_OFF_2 : DOWN_B_OFF_4)); - - float *gate_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float)); - float *up_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float)); - float *act_out = malloc(MOE_INTERMEDIATE * sizeof(float)); + uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.gate_s_off_2 : cfg.gate_s_off_4)); + uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.gate_b_off_2 : cfg.gate_b_off_4)); + uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? cfg.up_w_off_2 : cfg.up_w_off_4)); + uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.up_s_off_2 : cfg.up_s_off_4)); + uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.up_b_off_2 : cfg.up_b_off_4)); + uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? cfg.down_w_off_2 : cfg.down_w_off_4)); + uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.down_s_off_2 : cfg.down_s_off_4)); + uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? 
cfg.down_b_off_2 : cfg.down_b_off_4)); + + float *gate_proj_out = malloc(cfg.moe_intermediate * sizeof(float)); + float *up_proj_out = malloc(cfg.moe_intermediate * sizeof(float)); + float *act_out = malloc(cfg.moe_intermediate * sizeof(float)); cpu_dequant_matvec(gw, gs_p, gb_p, h_post, gate_proj_out, - MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE); + cfg.moe_intermediate, cfg.hidden_dim, cfg.group_size); cpu_dequant_matvec(uw, us_p, ub_p, h_post, up_proj_out, - MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE); - cpu_swiglu(gate_proj_out, up_proj_out, act_out, MOE_INTERMEDIATE); + cfg.moe_intermediate, cfg.hidden_dim, cfg.group_size); + cpu_swiglu(gate_proj_out, up_proj_out, act_out, cfg.moe_intermediate); cpu_dequant_matvec(dw, ds_p, db_p, act_out, expert_out, - HIDDEN_DIM, MOE_INTERMEDIATE, GROUP_SIZE); + cfg.hidden_dim, cfg.moe_intermediate, cfg.group_size); free(gate_proj_out); free(up_proj_out); @@ -3137,31 +3090,31 @@ static void moe_forward( // Accumulate weighted if (moe_dump) { fprintf(stderr, "[MOE-DUMP] expert[%d] out_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", - eidx, vec_rms(expert_out, HIDDEN_DIM), + eidx, vec_rms(expert_out, cfg.hidden_dim), expert_out[0], expert_out[1], expert_out[2], expert_out[3], expert_out[4]); } - cpu_vec_madd(moe_out, expert_out, expert_weights[k], HIDDEN_DIM); + cpu_vec_madd(moe_out, expert_out, expert_weights[k], cfg.hidden_dim); } free(expert_out); } // ---- Shared expert SwiGLU (gate_proj + up_proj already computed above) ---- - float *shared_out = calloc(HIDDEN_DIM, sizeof(float)); - float *shared_act = calloc(SHARED_INTERMEDIATE, sizeof(float)); - cpu_swiglu(shared_gate, shared_up, shared_act, SHARED_INTERMEDIATE); + float *shared_out = calloc(cfg.hidden_dim, sizeof(float)); + float *shared_act = calloc(cfg.shared_intermediate, sizeof(float)); + cpu_swiglu(shared_gate, shared_up, shared_act, cfg.shared_intermediate); if (moe_dump) { fprintf(stderr, "[MOE-DUMP] layer=%d h_post_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", - 
layer_idx, vec_rms(h_post, HIDDEN_DIM), h_post[0], h_post[1], h_post[2], h_post[3], h_post[4]); + layer_idx, vec_rms(h_post, cfg.hidden_dim), h_post[0], h_post[1], h_post[2], h_post[3], h_post[4]); fprintf(stderr, "[MOE-DUMP] gate_proj_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", - vec_rms(shared_gate, SHARED_INTERMEDIATE), + vec_rms(shared_gate, cfg.shared_intermediate), shared_gate[0], shared_gate[1], shared_gate[2], shared_gate[3], shared_gate[4]); fprintf(stderr, "[MOE-DUMP] up_proj_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", - vec_rms(shared_up, SHARED_INTERMEDIATE), + vec_rms(shared_up, cfg.shared_intermediate), shared_up[0], shared_up[1], shared_up[2], shared_up[3], shared_up[4]); fprintf(stderr, "[MOE-DUMP] swiglu_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n", - vec_rms(shared_act, SHARED_INTERMEDIATE), + vec_rms(shared_act, cfg.shared_intermediate), shared_act[0], shared_act[1], shared_act[2], shared_act[3], shared_act[4]); } @@ -3173,28 +3126,28 @@ static void moe_forward( snprintf(name, sizeof(name), "model.layers.%d.mlp.shared_expert.down_proj.biases", layer_idx); uint16_t *sdb = get_tensor_ptr(wf, name); if (sdw && sds && sdb) { - fast_dequant_matvec(sdw, sds, sdb, shared_act, shared_out, HIDDEN_DIM, - SHARED_INTERMEDIATE, GROUP_SIZE); + fast_dequant_matvec(sdw, sds, sdb, shared_act, shared_out, cfg.hidden_dim, + cfg.shared_intermediate, cfg.group_size); } // ---- Shared expert gate (sigmoid) -- already computed above ---- float shared_weight = cpu_sigmoid(shared_gate_score); // Scale shared expert output - for (int i = 0; i < HIDDEN_DIM; i++) { + for (int i = 0; i < cfg.hidden_dim; i++) { shared_out[i] *= shared_weight; } // ---- Combine: hidden = h_mid + moe_out + shared_out ---- - for (int i = 0; i < HIDDEN_DIM; i++) { + for (int i = 0; i < cfg.hidden_dim; i++) { hidden[i] = h_mid[i] + moe_out[i] + shared_out[i]; } if (moe_debug) { fprintf(stderr, "[MOE-DBG] layer=%d h_mid_rms=%.4f moe_rms=%.4f shared_rms=%.4f shared_gate=%.4f 
hidden_rms=%.4f\n", - layer_idx, vec_rms(h_mid, HIDDEN_DIM), vec_rms(moe_out, HIDDEN_DIM), - vec_rms(shared_out, HIDDEN_DIM), shared_weight, - vec_rms(hidden, HIDDEN_DIM)); + layer_idx, vec_rms(h_mid, cfg.hidden_dim), vec_rms(moe_out, cfg.hidden_dim), + vec_rms(shared_out, cfg.hidden_dim), shared_weight, + vec_rms(hidden, cfg.hidden_dim)); } free(h_post); @@ -3223,7 +3176,7 @@ static void embed_lookup(WeightFile *wf, int token_id, float *out) { if (!w_info || !s_info || !b_info) { fprintf(stderr, "ERROR: embedding tensors not found\n"); - memset(out, 0, HIDDEN_DIM * sizeof(float)); + memset(out, 0, cfg.hidden_dim * sizeof(float)); return; } @@ -3239,7 +3192,7 @@ static void embed_lookup(WeightFile *wf, int token_id, float *out) { const uint16_t *s_row = S + (size_t)token_id * num_groups; const uint16_t *b_row = B + (size_t)token_id * num_groups; - int group_size = HIDDEN_DIM / num_groups; // 4096/64 = 64 + int group_size = cfg.hidden_dim / num_groups; // 4096/64 = 64 int packed_per_group = group_size / 8; // 8 for (int g = 0; g < num_groups; g++) { @@ -3281,7 +3234,7 @@ static void lm_head_forward(WeightFile *wf, const float *hidden, float *logits) uint16_t *B = (uint16_t *)((char *)wf->data + b_info->offset); // Full matmul — use GPU if available (248320 output rows!) 
- fast_dequant_matvec(W, S, B, hidden, logits, VOCAB_SIZE, HIDDEN_DIM, GROUP_SIZE); + fast_dequant_matvec(W, S, B, hidden, logits, cfg.vocab_size, cfg.hidden_dim, cfg.group_size); } // ============================================================================ @@ -3557,7 +3510,7 @@ static int parallel_pread_experts_into( typedef struct { int layer_idx; int expert_idx; - id buffer; // Metal buffer holding EXPERT_SIZE bytes + id buffer; // Metal buffer holding cfg.expert_size_4bit bytes uint64_t last_used; // monotonic counter for LRU ordering } ExpertCacheEntry; @@ -3566,7 +3519,7 @@ static int parallel_pread_experts_into( int max_entries; int num_entries; int used_entries; - int entry_idx[NUM_LAYERS][NUM_EXPERTS]; + int *entry_idx; // flattened [num_layers * num_experts], -1 = not cached uint64_t access_counter; // monotonic, incremented on every access id device; // for allocating new Metal buffers // Stats @@ -3589,14 +3542,13 @@ static int parallel_pread_experts_into( // - Loads into scratch buffers (no cache pollution) // - Uses CMD1_wait idle time (no additional CPU cost) // - Only sync-preads misses (not all K experts) -static int g_pred_experts[60][MAX_K]; // previous token's expert indices per layer -static int g_pred_count[60]; // how many experts stored per layer static int g_pred_valid = 0; // 1 after first token completes (predictions available) // g_pred_enabled, g_pred_hits, g_pred_misses, g_pred_layers declared near timing (line ~163) static ExpertLRUCache *expert_cache_new(id device, int max_entries) { ExpertLRUCache *cache = calloc(1, sizeof(ExpertLRUCache)); cache->entries = calloc(max_entries, sizeof(ExpertCacheEntry)); + cache->entry_idx = malloc(cfg.num_layers * cfg.num_experts * sizeof(int)); cache->max_entries = max_entries; cache->num_entries = 0; cache->used_entries = 0; @@ -3604,9 +3556,9 @@ static int parallel_pread_experts_into( cache->device = device; cache->hits = 0; cache->misses = 0; - for (int l = 0; l < NUM_LAYERS; l++) { - for 
(int e = 0; e < NUM_EXPERTS; e++) { - cache->entry_idx[l][e] = -1; + for (int l = 0; l < cfg.num_layers; l++) { + for (int e = 0; e < cfg.num_experts; e++) { + cache->entry_idx[(l) * cfg.num_experts + (e)] = -1; } } // Pre-allocate ALL Metal buffers at startup (avoids allocation overhead at runtime) @@ -3645,7 +3597,7 @@ static void expert_cache_free(ExpertLRUCache *cache) { // Lookup: returns the cached Metal buffer if found, otherwise NULL. // On hit, updates the LRU timestamp. static id expert_cache_lookup(ExpertLRUCache *cache, int layer_idx, int expert_idx) { - int idx = cache->entry_idx[layer_idx][expert_idx]; + int idx = cache->entry_idx[(layer_idx) * cfg.num_experts + (expert_idx)]; if (idx >= 0) { cache->entries[idx].last_used = ++cache->access_counter; cache->hits++; @@ -3662,7 +3614,7 @@ static void expert_cache_free(ExpertLRUCache *cache) { static id expert_cache_insert(ExpertLRUCache *cache, int layer_idx, int expert_idx) { id buf = nil; - int existing = cache->entry_idx[layer_idx][expert_idx]; + int existing = cache->entry_idx[(layer_idx) * cfg.num_experts + (expert_idx)]; if (existing >= 0) { cache->entries[existing].last_used = ++cache->access_counter; return cache->entries[existing].buffer; @@ -3679,7 +3631,7 @@ static void expert_cache_free(ExpertLRUCache *cache) { cache->entries[target].layer_idx = layer_idx; cache->entries[target].expert_idx = expert_idx; cache->entries[target].last_used = ++cache->access_counter; - cache->entry_idx[layer_idx][expert_idx] = target; + cache->entry_idx[(layer_idx) * cfg.num_experts + (expert_idx)] = target; return buf; } @@ -3698,13 +3650,13 @@ static void expert_cache_free(ExpertLRUCache *cache) { int old_expert = cache->entries[lru_idx].expert_idx; cache_telemetry_evict(old_layer, old_expert); if (old_layer >= 0 && old_expert >= 0) { - cache->entry_idx[old_layer][old_expert] = -1; + cache->entry_idx[(old_layer) * cfg.num_experts + (old_expert)] = -1; } buf = cache->entries[lru_idx].buffer; 
cache->entries[lru_idx].layer_idx = layer_idx; cache->entries[lru_idx].expert_idx = expert_idx; cache->entries[lru_idx].last_used = ++cache->access_counter; - cache->entry_idx[layer_idx][expert_idx] = lru_idx; + cache->entry_idx[(layer_idx) * cfg.num_experts + (expert_idx)] = lru_idx; return buf; } @@ -3716,7 +3668,7 @@ static void expert_cache_free(ExpertLRUCache *cache) { // ============================================================================ typedef struct { - void **data; // [max_entries] page-aligned malloc'd EXPERT_SIZE buffers + void **data; // [max_entries] page-aligned malloc'd cfg.expert_size_4bit buffers id __strong *metal_bufs; // [max_entries] zero-copy Metal buffer wrappers int *layer_idx; // [max_entries] layer index for each entry int *expert_idx; // [max_entries] expert index for each entry @@ -3724,7 +3676,7 @@ static void expert_cache_free(ExpertLRUCache *cache) { int max_entries; int num_entries; int used_entries; - int entry_idx[NUM_LAYERS][NUM_EXPERTS]; + int *entry_idx; // flattened [num_layers * num_experts], -1 = not cached uint64_t access_counter; uint64_t hits; uint64_t misses; @@ -3739,15 +3691,16 @@ static void expert_cache_free(ExpertLRUCache *cache) { cache->layer_idx = calloc(max_entries, sizeof(int)); cache->expert_idx = calloc(max_entries, sizeof(int)); cache->last_used = calloc(max_entries, sizeof(uint64_t)); + cache->entry_idx = malloc(cfg.num_layers * cfg.num_experts * sizeof(int)); cache->max_entries = max_entries; cache->num_entries = 0; cache->used_entries = 0; cache->access_counter = 0; cache->hits = 0; cache->misses = 0; - for (int l = 0; l < NUM_LAYERS; l++) { - for (int e = 0; e < NUM_EXPERTS; e++) { - cache->entry_idx[l][e] = -1; + for (int l = 0; l < cfg.num_layers; l++) { + for (int e = 0; e < cfg.num_experts; e++) { + cache->entry_idx[(l) * cfg.num_experts + (e)] = -1; } } @@ -3791,7 +3744,7 @@ static void expert_cache_free(ExpertLRUCache *cache) { // Lookup: returns Metal buffer wrapping cached data, or nil. 
Zero-copy dispatch. static id malloc_cache_lookup(MallocExpertCache *cache, int layer, int expert) { - int idx = cache->entry_idx[layer][expert]; + int idx = cache->entry_idx[(layer) * cfg.num_experts + (expert)]; if (idx >= 0) { cache->last_used[idx] = ++cache->access_counter; cache->hits++; @@ -3806,7 +3759,7 @@ static void expert_cache_free(ExpertLRUCache *cache) { // Insert: evict LRU if needed, return entry index for pread target. // Returns the Metal buffer for this entry (caller should pread into cache->data[idx]). static id malloc_cache_insert(MallocExpertCache *cache, int layer, int expert, int *out_idx) { - int existing = cache->entry_idx[layer][expert]; + int existing = cache->entry_idx[(layer) * cfg.num_experts + (expert)]; if (existing >= 0) { cache->last_used[existing] = ++cache->access_counter; if (out_idx) *out_idx = existing; @@ -3831,14 +3784,14 @@ static void expert_cache_free(ExpertLRUCache *cache) { } cache_telemetry_evict(cache->layer_idx[target], cache->expert_idx[target]); if (cache->layer_idx[target] >= 0 && cache->expert_idx[target] >= 0) { - cache->entry_idx[cache->layer_idx[target]][cache->expert_idx[target]] = -1; + cache->entry_idx[(cache->layer_idx[target]) * cfg.num_experts + (cache->expert_idx[target])] = -1; } } cache->layer_idx[target] = layer; cache->expert_idx[target] = expert; cache->last_used[target] = ++cache->access_counter; - cache->entry_idx[layer][expert] = target; + cache->entry_idx[(layer) * cfg.num_experts + (expert)] = target; if (out_idx) *out_idx = target; return cache->metal_bufs[target]; } @@ -4028,16 +3981,33 @@ static void infer_prefetch_shutdown(void) { uint32_t *seg_w; uint16_t *seg_s, *seg_b; // shared_expert_gate } LayerWeightCache; -static LayerWeightCache layer_cache[NUM_LAYERS]; +static LayerWeightCache *layer_cache = NULL; static int layer_cache_built = 0; +// Allocate all dynamic tracking arrays (must be called after load_model_config) +static void alloc_tracking_arrays(void) { + int nl = 
cfg.num_layers; + int ne = cfg.num_experts; + int seen_bytes_per_layer = (ne + 7) / 8; + + g_expert_freq = calloc(nl * ne, sizeof(int)); + g_expert_seen = calloc(nl * seen_bytes_per_layer, sizeof(uint8_t)); + g_lz4_index = calloc(nl, sizeof(void *)); + g_cache_seen = calloc(nl * ne, sizeof(uint8_t)); + g_cache_last_touch_token = calloc(nl * ne, sizeof(uint64_t)); + g_cache_last_evict_token = calloc(nl * ne, sizeof(uint64_t)); + g_pred_experts = calloc(nl * MAX_K, sizeof(int)); + g_pred_count = calloc(nl, sizeof(int)); + layer_cache = calloc(nl, sizeof(LayerWeightCache)); +} + static void build_layer_cache(WeightFile *wf) { if (layer_cache_built) return; char name[256]; - for (int i = 0; i < NUM_LAYERS; i++) { + for (int i = 0; i < cfg.num_layers; i++) { LayerWeightCache *lc = &layer_cache[i]; - int is_full = ((i + 1) % FULL_ATTN_INTERVAL == 0); + int is_full = cfg.is_full_attn[i]; // Norms snprintf(name, sizeof(name), "model.layers.%d.input_layernorm.weight", i); @@ -4151,7 +4121,7 @@ static void build_layer_cache(WeightFile *wf) { } layer_cache_built = 1; - printf("[cache] Pre-computed weight pointers for %d layers\n", NUM_LAYERS); + printf("[cache] Pre-computed weight pointers for %d layers\n", cfg.num_layers); } // ============================================================================ @@ -4170,13 +4140,13 @@ static void build_layer_cache(WeightFile *wf) { float expert_weights[MAX_K]; // routing weights for weighted accumulation int valid[MAX_K]; // which experts loaded successfully int actual_K; // number of experts - float h_mid[HIDDEN_DIM]; // saved h_mid for final combine + float *h_mid; // [hidden_dim] saved h_mid for final combine float shared_gate_score; // saved shared expert gate score float *hidden; // pointer to hidden state (for writing final result) int layer_idx; // which layer produced this deferred state } DeferredExpertState; -static DeferredExpertState g_deferred = { .active = 0 }; +static DeferredExpertState g_deferred = { .active = 0, 
.h_mid = NULL }; // Wait for the deferred GPU expert command buffer to complete. // Split from finalize so timing can be measured independently. @@ -4197,30 +4167,30 @@ static void finalize_deferred_experts(void) { // buf_input already has the normalized input for the next layer's CMD1. // Just read back hidden (needed for the residual connection in future layers). memcpy(g_deferred.hidden, [g_metal->buf_moe_hidden contents], - HIDDEN_DIM * sizeof(float)); + cfg.hidden_dim * sizeof(float)); } else { // CPU-side combine (original path) // Read back and accumulate routed expert outputs - float moe_out[HIDDEN_DIM]; + float moe_out[cfg.hidden_dim]; memset(moe_out, 0, sizeof(moe_out)); for (int k = 0; k < g_deferred.actual_K; k++) { if (!g_deferred.valid[k]) continue; float *expert_result = (float *)[g_metal->buf_multi_expert_out[k] contents]; - cpu_vec_madd(moe_out, expert_result, g_deferred.expert_weights[k], HIDDEN_DIM); + cpu_vec_madd(moe_out, expert_result, g_deferred.expert_weights[k], cfg.hidden_dim); } // Read shared expert result - float shared_out[HIDDEN_DIM]; - memcpy(shared_out, [g_metal->buf_shared_out contents], HIDDEN_DIM * sizeof(float)); + float shared_out[cfg.hidden_dim]; + memcpy(shared_out, [g_metal->buf_shared_out contents], cfg.hidden_dim * sizeof(float)); // Apply shared expert gate float shared_weight = cpu_sigmoid(g_deferred.shared_gate_score); - for (int i = 0; i < HIDDEN_DIM; i++) { + for (int i = 0; i < cfg.hidden_dim; i++) { shared_out[i] *= shared_weight; } // Final combine: hidden = h_mid + moe_out + shared_out - for (int i = 0; i < HIDDEN_DIM; i++) { + for (int i = 0; i < cfg.hidden_dim; i++) { g_deferred.hidden[i] = g_deferred.h_mid[i] + moe_out[i] + shared_out[i]; } } @@ -4289,67 +4259,67 @@ static void discard_deferred_experts(void) { // Static scratch buffers — allocated once, reused across all 40 layers per token. // Eliminates ~20 malloc/free per layer = ~1200 alloc/free per token. 
-static float *s_normed = NULL; // [HIDDEN_DIM] -static float *s_residual = NULL; // [HIDDEN_DIM] -static float *s_attn_proj = NULL; // [HIDDEN_DIM] -static float *s_h_post = NULL; // [HIDDEN_DIM] -static float *s_h_mid = NULL; // [HIDDEN_DIM] -static float *s_gate_scores = NULL; // [NUM_EXPERTS] -static float *s_spec_gate_scores = NULL; // [NUM_EXPERTS] speculative routing scratch +static float *s_normed = NULL; // [cfg.hidden_dim] +static float *s_residual = NULL; // [cfg.hidden_dim] +static float *s_attn_proj = NULL; // [cfg.hidden_dim] +static float *s_h_post = NULL; // [cfg.hidden_dim] +static float *s_h_mid = NULL; // [cfg.hidden_dim] +static float *s_gate_scores = NULL; // [cfg.num_experts] +static float *s_spec_gate_scores = NULL; // [cfg.num_experts] speculative routing scratch static int s_spec_indices[MAX_K]; // speculative routing predicted expert indices static int s_spec_count = 0; // number of speculative predictions this layer -static float *s_shared_gate = NULL; // [SHARED_INTERMEDIATE] -static float *s_shared_up = NULL; // [SHARED_INTERMEDIATE] -static float *s_moe_out = NULL; // [HIDDEN_DIM] -static float *s_shared_out = NULL; // [HIDDEN_DIM] +static float *s_shared_gate = NULL; // [cfg.shared_intermediate] +static float *s_shared_up = NULL; // [cfg.shared_intermediate] +static float *s_moe_out = NULL; // [cfg.hidden_dim] +static float *s_shared_out = NULL; // [cfg.hidden_dim] // Full attention scratch -static float *s_q_proj_out = NULL; // [NUM_ATTN_HEADS * HEAD_DIM * 2] -static float *s_k_proj_out = NULL; // [NUM_KV_HEADS * HEAD_DIM] -static float *s_v_proj_out = NULL; // [NUM_KV_HEADS * HEAD_DIM] -static float *s_q = NULL; // [NUM_ATTN_HEADS * HEAD_DIM] -static float *s_q_gate = NULL; // [NUM_ATTN_HEADS * HEAD_DIM] -static float *s_attn_out = NULL; // [NUM_ATTN_HEADS * HEAD_DIM] +static float *s_q_proj_out = NULL; // [cfg.num_attn_heads * cfg.head_dim * 2] +static float *s_k_proj_out = NULL; // [cfg.num_kv_heads * cfg.head_dim] +static float 
*s_v_proj_out = NULL; // [cfg.num_kv_heads * cfg.head_dim] +static float *s_q = NULL; // [cfg.num_attn_heads * cfg.head_dim] +static float *s_q_gate = NULL; // [cfg.num_attn_heads * cfg.head_dim] +static float *s_attn_out = NULL; // [cfg.num_attn_heads * cfg.head_dim] // Linear attention scratch -static float *s_qkv_proj_out = NULL; // [LINEAR_CONV_DIM] -static float *s_z_proj_out = NULL; // [LINEAR_TOTAL_VALUE] -static float *s_beta_proj_out = NULL; // [LINEAR_NUM_V_HEADS] -static float *s_alpha_proj_out = NULL; // [LINEAR_NUM_V_HEADS] -static float *s_conv_out = NULL; // [LINEAR_CONV_DIM] -static float *s_out_vals = NULL; // [LINEAR_TOTAL_VALUE] -static float *s_gated_out = NULL; // [LINEAR_TOTAL_VALUE] +static float *s_qkv_proj_out = NULL; // [cfg.linear_conv_dim] +static float *s_z_proj_out = NULL; // [cfg.linear_total_value] +static float *s_beta_proj_out = NULL; // [cfg.linear_num_v_heads] +static float *s_alpha_proj_out = NULL; // [cfg.linear_num_v_heads] +static float *s_conv_out = NULL; // [cfg.linear_conv_dim] +static float *s_out_vals = NULL; // [cfg.linear_total_value] +static float *s_gated_out = NULL; // [cfg.linear_total_value] static void init_layer_scratch(void) { if (s_normed) return; // already initialized - s_normed = calloc(HIDDEN_DIM, sizeof(float)); - s_residual = calloc(HIDDEN_DIM, sizeof(float)); - s_attn_proj = calloc(HIDDEN_DIM, sizeof(float)); - s_h_post = calloc(HIDDEN_DIM, sizeof(float)); - s_h_mid = calloc(HIDDEN_DIM, sizeof(float)); - s_gate_scores = calloc(NUM_EXPERTS, sizeof(float)); - s_spec_gate_scores = calloc(NUM_EXPERTS, sizeof(float)); - s_shared_gate = calloc(SHARED_INTERMEDIATE, sizeof(float)); - s_shared_up = calloc(SHARED_INTERMEDIATE, sizeof(float)); - s_moe_out = calloc(HIDDEN_DIM, sizeof(float)); - s_shared_out = calloc(HIDDEN_DIM, sizeof(float)); - s_q_proj_out = calloc(NUM_ATTN_HEADS * HEAD_DIM * 2, sizeof(float)); - s_k_proj_out = calloc(NUM_KV_HEADS * HEAD_DIM, sizeof(float)); - s_v_proj_out = calloc(NUM_KV_HEADS * 
HEAD_DIM, sizeof(float)); - s_q = calloc(NUM_ATTN_HEADS * HEAD_DIM, sizeof(float)); - s_q_gate = calloc(NUM_ATTN_HEADS * HEAD_DIM, sizeof(float)); - s_attn_out = calloc(NUM_ATTN_HEADS * HEAD_DIM, sizeof(float)); - s_qkv_proj_out = calloc(LINEAR_CONV_DIM, sizeof(float)); - s_z_proj_out = calloc(LINEAR_TOTAL_VALUE, sizeof(float)); - s_beta_proj_out = calloc(LINEAR_NUM_V_HEADS, sizeof(float)); - s_alpha_proj_out = calloc(LINEAR_NUM_V_HEADS, sizeof(float)); - s_conv_out = calloc(LINEAR_CONV_DIM, sizeof(float)); - s_out_vals = calloc(LINEAR_TOTAL_VALUE, sizeof(float)); - s_gated_out = calloc(LINEAR_TOTAL_VALUE, sizeof(float)); + s_normed = calloc(cfg.hidden_dim, sizeof(float)); + s_residual = calloc(cfg.hidden_dim, sizeof(float)); + s_attn_proj = calloc(cfg.hidden_dim, sizeof(float)); + s_h_post = calloc(cfg.hidden_dim, sizeof(float)); + s_h_mid = calloc(cfg.hidden_dim, sizeof(float)); + s_gate_scores = calloc(cfg.num_experts, sizeof(float)); + s_spec_gate_scores = calloc(cfg.num_experts, sizeof(float)); + s_shared_gate = calloc(cfg.shared_intermediate, sizeof(float)); + s_shared_up = calloc(cfg.shared_intermediate, sizeof(float)); + s_moe_out = calloc(cfg.hidden_dim, sizeof(float)); + s_shared_out = calloc(cfg.hidden_dim, sizeof(float)); + s_q_proj_out = calloc(cfg.num_attn_heads * cfg.head_dim * 2, sizeof(float)); + s_k_proj_out = calloc(cfg.num_kv_heads * cfg.head_dim, sizeof(float)); + s_v_proj_out = calloc(cfg.num_kv_heads * cfg.head_dim, sizeof(float)); + s_q = calloc(cfg.num_attn_heads * cfg.head_dim, sizeof(float)); + s_q_gate = calloc(cfg.num_attn_heads * cfg.head_dim, sizeof(float)); + s_attn_out = calloc(cfg.num_attn_heads * cfg.head_dim, sizeof(float)); + s_qkv_proj_out = calloc(cfg.linear_conv_dim, sizeof(float)); + s_z_proj_out = calloc(cfg.linear_total_value, sizeof(float)); + s_beta_proj_out = calloc(cfg.linear_num_v_heads, sizeof(float)); + s_alpha_proj_out = calloc(cfg.linear_num_v_heads, sizeof(float)); + s_conv_out = calloc(cfg.linear_conv_dim, 
sizeof(float)); + s_out_vals = calloc(cfg.linear_total_value, sizeof(float)); + s_gated_out = calloc(cfg.linear_total_value, sizeof(float)); } static void fused_layer_forward( WeightFile *wf, int layer_idx, - float *hidden, // [HIDDEN_DIM] in/out + float *hidden, // [cfg.hidden_dim] in/out KVCache *kv, // non-NULL for full attention layers LinearAttnState *la_state, // non-NULL for linear attention layers int pos, // position for RoPE @@ -4377,8 +4347,8 @@ static void fused_layer_forward( float *qkv_out = NULL, *z_out = NULL, *beta_out = NULL, *alpha_out = NULL; if (is_full) { - int q_proj_dim = NUM_ATTN_HEADS * HEAD_DIM * 2; - int kv_dim = NUM_KV_HEADS * HEAD_DIM; + int q_proj_dim = cfg.num_attn_heads * cfg.head_dim * 2; + int kv_dim = cfg.num_kv_heads * cfg.head_dim; q_proj_out = s_q_proj_out; k_out = s_k_proj_out; @@ -4386,14 +4356,14 @@ static void fused_layer_forward( if (lc->q_w && lc->q_s && lc->q_b && lc->k_w && lc->k_s && lc->k_b && lc->v_w && lc->v_s && lc->v_b) { - attn_specs[0] = (BatchMatvecSpec){ lc->q_w, lc->q_s, lc->q_b, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0 }; - attn_specs[1] = (BatchMatvecSpec){ lc->k_w, lc->k_s, lc->k_b, k_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1 }; - attn_specs[2] = (BatchMatvecSpec){ lc->v_w, lc->v_s, lc->v_b, v_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2 }; + attn_specs[0] = (BatchMatvecSpec){ lc->q_w, lc->q_s, lc->q_b, q_proj_out, (uint32_t)q_proj_dim, cfg.hidden_dim, cfg.group_size, 0 }; + attn_specs[1] = (BatchMatvecSpec){ lc->k_w, lc->k_s, lc->k_b, k_out, (uint32_t)kv_dim, cfg.hidden_dim, cfg.group_size, 1 }; + attn_specs[2] = (BatchMatvecSpec){ lc->v_w, lc->v_s, lc->v_b, v_out, (uint32_t)kv_dim, cfg.hidden_dim, cfg.group_size, 2 }; num_attn_specs = 3; } } else { - int qkv_dim = LINEAR_CONV_DIM; - int z_dim = LINEAR_TOTAL_VALUE; + int qkv_dim = cfg.linear_conv_dim; + int z_dim = cfg.linear_total_value; qkv_out = s_qkv_proj_out; z_out = s_z_proj_out; @@ -4402,10 +4372,10 @@ static void 
fused_layer_forward( if (lc->qkv_w && lc->qkv_s && lc->qkv_b && lc->z_w && lc->z_s && lc->z_b && lc->b_w && lc->b_s && lc->b_b && lc->a_w && lc->a_s && lc->a_b) { - attn_specs[0] = (BatchMatvecSpec){ lc->qkv_w, lc->qkv_s, lc->qkv_b, qkv_out, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0 }; - attn_specs[1] = (BatchMatvecSpec){ lc->z_w, lc->z_s, lc->z_b, z_out, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1 }; - attn_specs[2] = (BatchMatvecSpec){ lc->b_w, lc->b_s, lc->b_b, beta_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2 }; - attn_specs[3] = (BatchMatvecSpec){ lc->a_w, lc->a_s, lc->a_b, alpha_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3 }; + attn_specs[0] = (BatchMatvecSpec){ lc->qkv_w, lc->qkv_s, lc->qkv_b, qkv_out, (uint32_t)qkv_dim, cfg.hidden_dim, cfg.group_size, 0 }; + attn_specs[1] = (BatchMatvecSpec){ lc->z_w, lc->z_s, lc->z_b, z_out, (uint32_t)z_dim, cfg.hidden_dim, cfg.group_size, 1 }; + attn_specs[2] = (BatchMatvecSpec){ lc->b_w, lc->b_s, lc->b_b, beta_out, (uint32_t)cfg.linear_num_v_heads, cfg.hidden_dim, cfg.group_size, 2 }; + attn_specs[3] = (BatchMatvecSpec){ lc->a_w, lc->a_s, lc->a_b, alpha_out, (uint32_t)cfg.linear_num_v_heads, cfg.hidden_dim, cfg.group_size, 3 }; num_attn_specs = 4; } } @@ -4419,7 +4389,7 @@ static void fused_layer_forward( // Pre-compute linear_layer_idx for GPU linear attention encoding in CMD1 int linear_layer_idx = -1; if (!is_full) { - linear_layer_idx = layer_idx - (layer_idx + 1) / FULL_ATTN_INTERVAL; + linear_layer_idx = cfg.linear_index[layer_idx]; } // Can we run the full linear attention pipeline on GPU in CMD1? 
int can_gpu_linear = (gpu_linear_attn_enabled && @@ -4427,7 +4397,7 @@ static void fused_layer_forward( g_metal->conv1d_step && g_metal->rms_norm_qk && g_metal->compute_decay_beta && g_metal->gated_rms_norm && g_metal->wf_buf && - linear_layer_idx >= 0 && linear_layer_idx < NUM_LINEAR_LAYERS && + linear_layer_idx >= 0 && linear_layer_idx < cfg.num_linear_layers && lc->conv1d_w && lc->A_log && lc->dt_bias && lc->gated_norm_w && !linear_attn_bypass); @@ -4448,7 +4418,7 @@ static void fused_layer_forward( // GPU linear attention: encode conv1d + normalize + decay/beta + delta-net + gated_norm into CMD1 if (can_gpu_linear && num_attn_specs == 4) { // batch_out[0]=qkv(12288), [1]=z(8192), [2]=beta(64), [3]=alpha(64) - uint32_t conv_dim = LINEAR_CONV_DIM; + uint32_t conv_dim = cfg.linear_conv_dim; NSUInteger conv_w_off = (NSUInteger)((const char *)lc->conv1d_w - (const char *)[g_metal->wf_buf contents]); // Enc L1: conv1d_step — input=batch_out[0], weights=conv1d_w, state=buf_conv_state, output=buf_conv_output @@ -4468,16 +4438,16 @@ static void fused_layer_forward( // Enc L2: rms_norm_qk — normalize q and k in conv_output in-place { - uint32_t key_dim = LINEAR_KEY_DIM; // 128 - float inv_scale = 1.0f / sqrtf((float)LINEAR_KEY_DIM); + uint32_t key_dim = cfg.linear_key_dim; // 128 + float inv_scale = 1.0f / sqrtf((float)cfg.linear_key_dim); id enc = [cmd1 computeCommandEncoder]; [enc setComputePipelineState:g_metal->rms_norm_qk]; [enc setBuffer:g_metal->buf_conv_output offset:0 atIndex:0]; // q at offset 0 - [enc setBuffer:g_metal->buf_conv_output offset:LINEAR_TOTAL_KEY * sizeof(float) atIndex:1]; // k at offset 2048 floats + [enc setBuffer:g_metal->buf_conv_output offset:cfg.linear_total_key * sizeof(float) atIndex:1]; // k at offset 2048 floats [enc setBytes:&key_dim length:4 atIndex:2]; [enc setBytes:&inv_scale length:4 atIndex:3]; - [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_K_HEADS, 1, 1) - threadsPerThreadgroup:MTLSizeMake(LINEAR_KEY_DIM, 1, 1)]; + [enc 
dispatchThreadgroups:MTLSizeMake(cfg.linear_num_k_heads, 1, 1) + threadsPerThreadgroup:MTLSizeMake(cfg.linear_key_dim, 1, 1)]; [enc endEncoding]; } @@ -4494,24 +4464,24 @@ static void fused_layer_forward( [enc setBuffer:g_metal->buf_delta_g_decay offset:0 atIndex:4]; // g_decay output [enc setBuffer:g_metal->buf_delta_beta offset:0 atIndex:5]; // beta_gate output [enc dispatchThreadgroups:MTLSizeMake(1, 1, 1) - threadsPerThreadgroup:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1)]; + threadsPerThreadgroup:MTLSizeMake(cfg.linear_num_v_heads, 1, 1)]; [enc endEncoding]; } // Enc L4: gated_delta_net_step — the main recurrence { - uint32_t khpv = LINEAR_NUM_V_HEADS / LINEAR_NUM_K_HEADS; // 4 + uint32_t khpv = cfg.linear_num_v_heads / cfg.linear_num_k_heads; // 4 id enc = [cmd1 computeCommandEncoder]; [enc setComputePipelineState:g_metal->delta_net_step]; [enc setBuffer:g_metal->buf_delta_state[linear_layer_idx] offset:0 atIndex:0]; // persistent state [enc setBuffer:g_metal->buf_conv_output offset:0 atIndex:1]; // q (first 2048 floats) - [enc setBuffer:g_metal->buf_conv_output offset:LINEAR_TOTAL_KEY * sizeof(float) atIndex:2]; // k (next 2048) - [enc setBuffer:g_metal->buf_conv_output offset:2 * LINEAR_TOTAL_KEY * sizeof(float) atIndex:3]; // v (next 8192) + [enc setBuffer:g_metal->buf_conv_output offset:cfg.linear_total_key * sizeof(float) atIndex:2]; // k (next 2048) + [enc setBuffer:g_metal->buf_conv_output offset:2 * cfg.linear_total_key * sizeof(float) atIndex:3]; // v (next 8192) [enc setBuffer:g_metal->buf_delta_g_decay offset:0 atIndex:4]; [enc setBuffer:g_metal->buf_delta_beta offset:0 atIndex:5]; [enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:6]; // output [8192] [enc setBytes:&khpv length:sizeof(khpv) atIndex:7]; - [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1) + [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [enc endEncoding]; } @@ -4519,8 +4489,8 @@ static void 
fused_layer_forward( // Enc L5: gated_rms_norm — normalize+gate delta-net output -> batch_out[6] for CMD2 o_proj { NSUInteger gnorm_w_off = (NSUInteger)((const char *)lc->gated_norm_w - (const char *)[g_metal->wf_buf contents]); - uint32_t value_dim = LINEAR_VALUE_DIM; // 128 - float eps = RMS_NORM_EPS; + uint32_t value_dim = cfg.linear_value_dim; // 128 + float eps = cfg.rms_norm_eps; id enc = [cmd1 computeCommandEncoder]; [enc setComputePipelineState:g_metal->gated_rms_norm]; [enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:0]; // values [8192] @@ -4529,8 +4499,8 @@ static void fused_layer_forward( [enc setBuffer:g_metal->batch_out[6] offset:0 atIndex:3]; // output -> batch_out[6] for CMD2 [enc setBytes:&value_dim length:4 atIndex:4]; [enc setBytes:&eps length:4 atIndex:5]; - [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1) - threadsPerThreadgroup:MTLSizeMake(LINEAR_VALUE_DIM, 1, 1)]; + [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1) + threadsPerThreadgroup:MTLSizeMake(cfg.linear_value_dim, 1, 1)]; [enc endEncoding]; } @@ -4558,14 +4528,14 @@ static void fused_layer_forward( // Predictions overlap with CPU attn + CMD2 + routing (~0.6ms head start). // Predicted experts that hit page cache (same as previous token) complete in ~0.1ms. 
if (g_pred_enabled && g_pred_generating && g_pred_valid && packed_fd >= 0 && - g_metal->buf_multi_expert_data_B[0] && g_pred_count[layer_idx] > 0) { - async_pread_start(packed_fd, g_pred_experts[layer_idx], - g_pred_count[layer_idx], + g_metal->buf_multi_expert_data_B[0] && PRED_COUNT(layer_idx) > 0) { + async_pread_start(packed_fd, &PRED_EXPERT(layer_idx, 0), + PRED_COUNT(layer_idx), g_metal->buf_multi_expert_data_B, mmap_base); pred_started = 1; } // Set up residual for CMD2 (residual = hidden before this layer's attention) - cpu_vec_copy(residual, hidden, HIDDEN_DIM); + cpu_vec_copy(residual, hidden, cfg.hidden_dim); if (g_timing_enabled) { t1 = now_ms(); g_timing.deferred_cpu += t1 - t0; } // No input_norm needed — CMD3 already computed it into buf_input. @@ -4584,20 +4554,20 @@ static void fused_layer_forward( // Input norm if (g_timing_enabled) { t0 = now_ms(); } - cpu_vec_copy(residual, hidden, HIDDEN_DIM); - cpu_rms_norm(hidden, lc->input_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS); + cpu_vec_copy(residual, hidden, cfg.hidden_dim); + cpu_rms_norm(hidden, lc->input_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps); if (g_timing_enabled) { t1 = now_ms(); g_timing.input_norm += t1 - t0; } // Submit CMD1: attention projections if (g_timing_enabled) { t0 = now_ms(); } if (g_metal && g_metal->wf_buf && num_attn_specs > 0) { - memcpy([g_metal->buf_input contents], normed, HIDDEN_DIM * sizeof(float)); + memcpy([g_metal->buf_input contents], normed, cfg.hidden_dim * sizeof(float)); cmd1 = [g_metal->queue commandBuffer]; gpu_encode_batch_matvec(g_metal, cmd1, attn_specs, num_attn_specs); // GPU linear attention: encode conv1d + normalize + decay/beta + delta-net + gated_norm into CMD1 if (can_gpu_linear && num_attn_specs == 4) { - uint32_t conv_dim = LINEAR_CONV_DIM; + uint32_t conv_dim = cfg.linear_conv_dim; NSUInteger conv_w_off = (NSUInteger)((const char *)lc->conv1d_w - (const char *)[g_metal->wf_buf contents]); // Enc L1: conv1d_step @@ -4617,16 +4587,16 @@ static 
void fused_layer_forward( // Enc L2: rms_norm_qk { - uint32_t key_dim = LINEAR_KEY_DIM; - float inv_scale = 1.0f / sqrtf((float)LINEAR_KEY_DIM); + uint32_t key_dim = cfg.linear_key_dim; + float inv_scale = 1.0f / sqrtf((float)cfg.linear_key_dim); id enc = [cmd1 computeCommandEncoder]; [enc setComputePipelineState:g_metal->rms_norm_qk]; [enc setBuffer:g_metal->buf_conv_output offset:0 atIndex:0]; - [enc setBuffer:g_metal->buf_conv_output offset:LINEAR_TOTAL_KEY * sizeof(float) atIndex:1]; + [enc setBuffer:g_metal->buf_conv_output offset:cfg.linear_total_key * sizeof(float) atIndex:1]; [enc setBytes:&key_dim length:4 atIndex:2]; [enc setBytes:&inv_scale length:4 atIndex:3]; - [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_K_HEADS, 1, 1) - threadsPerThreadgroup:MTLSizeMake(LINEAR_KEY_DIM, 1, 1)]; + [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_k_heads, 1, 1) + threadsPerThreadgroup:MTLSizeMake(cfg.linear_key_dim, 1, 1)]; [enc endEncoding]; } @@ -4643,24 +4613,24 @@ static void fused_layer_forward( [enc setBuffer:g_metal->buf_delta_g_decay offset:0 atIndex:4]; [enc setBuffer:g_metal->buf_delta_beta offset:0 atIndex:5]; [enc dispatchThreadgroups:MTLSizeMake(1, 1, 1) - threadsPerThreadgroup:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1)]; + threadsPerThreadgroup:MTLSizeMake(cfg.linear_num_v_heads, 1, 1)]; [enc endEncoding]; } // Enc L4: gated_delta_net_step { - uint32_t khpv = LINEAR_NUM_V_HEADS / LINEAR_NUM_K_HEADS; + uint32_t khpv = cfg.linear_num_v_heads / cfg.linear_num_k_heads; id enc = [cmd1 computeCommandEncoder]; [enc setComputePipelineState:g_metal->delta_net_step]; [enc setBuffer:g_metal->buf_delta_state[linear_layer_idx] offset:0 atIndex:0]; [enc setBuffer:g_metal->buf_conv_output offset:0 atIndex:1]; - [enc setBuffer:g_metal->buf_conv_output offset:LINEAR_TOTAL_KEY * sizeof(float) atIndex:2]; - [enc setBuffer:g_metal->buf_conv_output offset:2 * LINEAR_TOTAL_KEY * sizeof(float) atIndex:3]; + [enc setBuffer:g_metal->buf_conv_output offset:cfg.linear_total_key * 
sizeof(float) atIndex:2]; + [enc setBuffer:g_metal->buf_conv_output offset:2 * cfg.linear_total_key * sizeof(float) atIndex:3]; [enc setBuffer:g_metal->buf_delta_g_decay offset:0 atIndex:4]; [enc setBuffer:g_metal->buf_delta_beta offset:0 atIndex:5]; [enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:6]; [enc setBytes:&khpv length:sizeof(khpv) atIndex:7]; - [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1) + [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [enc endEncoding]; } @@ -4668,8 +4638,8 @@ static void fused_layer_forward( // Enc L5: gated_rms_norm -> batch_out[6] { NSUInteger gnorm_w_off = (NSUInteger)((const char *)lc->gated_norm_w - (const char *)[g_metal->wf_buf contents]); - uint32_t value_dim = LINEAR_VALUE_DIM; - float eps = RMS_NORM_EPS; + uint32_t value_dim = cfg.linear_value_dim; + float eps = cfg.rms_norm_eps; id enc = [cmd1 computeCommandEncoder]; [enc setComputePipelineState:g_metal->gated_rms_norm]; [enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:0]; @@ -4678,8 +4648,8 @@ static void fused_layer_forward( [enc setBuffer:g_metal->batch_out[6] offset:0 atIndex:3]; [enc setBytes:&value_dim length:4 atIndex:4]; [enc setBytes:&eps length:4 atIndex:5]; - [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1) - threadsPerThreadgroup:MTLSizeMake(LINEAR_VALUE_DIM, 1, 1)]; + [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1) + threadsPerThreadgroup:MTLSizeMake(cfg.linear_value_dim, 1, 1)]; [enc endEncoding]; } @@ -4726,17 +4696,17 @@ static void fused_layer_forward( if (spec_routing_enabled && (g_expert_cache || g_malloc_cache) && packed_fd >= 0 && lc->gate_w) { float *spec_scores = s_spec_gate_scores; - memset(spec_scores, 0, NUM_EXPERTS * sizeof(float)); + memset(spec_scores, 0, cfg.num_experts * sizeof(float)); // Gate projection matvec on pre-attention normed input (CPU, ~0.1ms for 512x4096) cpu_dequant_matvec(lc->gate_w, 
lc->gate_s, lc->gate_b, normed, spec_scores, - NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE); - cpu_softmax(spec_scores, NUM_EXPERTS); + cfg.num_experts, cfg.hidden_dim, cfg.group_size); + cpu_softmax(spec_scores, cfg.num_experts); int spec_K = (K > MAX_K) ? MAX_K : K; float spec_weights[MAX_K]; - cpu_topk(spec_scores, NUM_EXPERTS, spec_K, s_spec_indices, spec_weights); + cpu_topk(spec_scores, cfg.num_experts, spec_K, s_spec_indices, spec_weights); s_spec_count = spec_K; g_spec_route_attempts += spec_K; @@ -4801,7 +4771,7 @@ static void fused_layer_forward( if (g_timing_enabled) { t0 = now_ms(); } float *attn_projected = s_attn_proj; - memset(attn_projected, 0, HIDDEN_DIM * sizeof(float)); + memset(attn_projected, 0, cfg.hidden_dim * sizeof(float)); // Pre-lookup o_proj / out_proj weights (used after attention compute) // These are looked up NOW to avoid repeated snprintf later. @@ -4811,10 +4781,10 @@ static void fused_layer_forward( if (is_full) { oproj_w = lc->o_w; oproj_s = lc->o_s; oproj_b = lc->o_b; - oproj_in_dim = NUM_ATTN_HEADS * HEAD_DIM; + oproj_in_dim = cfg.num_attn_heads * cfg.head_dim; } else if (!linear_attn_bypass) { oproj_w = lc->out_proj_w; oproj_s = lc->out_proj_s; oproj_b = lc->out_proj_b; - oproj_in_dim = LINEAR_TOTAL_VALUE; + oproj_in_dim = cfg.linear_total_value; } // All MoE weight pointers from cache (zero snprintf overhead) @@ -4829,51 +4799,51 @@ static void fused_layer_forward( if (is_full) { // ---- Full attention CPU compute ---- - int q_proj_dim = NUM_ATTN_HEADS * HEAD_DIM * 2; - int q_dim = NUM_ATTN_HEADS * HEAD_DIM; - int kv_dim = NUM_KV_HEADS * HEAD_DIM; + int q_proj_dim = cfg.num_attn_heads * cfg.head_dim * 2; + int q_dim = cfg.num_attn_heads * cfg.head_dim; + int kv_dim = cfg.num_kv_heads * cfg.head_dim; (void)q_proj_dim; float *q = s_q; float *q_gate = s_q_gate; - for (int h = 0; h < NUM_ATTN_HEADS; h++) { - float *src = q_proj_out + h * (2 * HEAD_DIM); - memcpy(q + h * HEAD_DIM, src, HEAD_DIM * sizeof(float)); - memcpy(q_gate + h * 
HEAD_DIM, src + HEAD_DIM, HEAD_DIM * sizeof(float)); + for (int h = 0; h < cfg.num_attn_heads; h++) { + float *src = q_proj_out + h * (2 * cfg.head_dim); + memcpy(q + h * cfg.head_dim, src, cfg.head_dim * sizeof(float)); + memcpy(q_gate + h * cfg.head_dim, src + cfg.head_dim, cfg.head_dim * sizeof(float)); } // Q/K RMSNorm uint16_t *qnorm_w = lc->q_norm_w; uint16_t *knorm_w = lc->k_norm_w; if (qnorm_w) { - for (int h = 0; h < NUM_ATTN_HEADS; h++) { - float *qh = q + h * HEAD_DIM; + for (int h = 0; h < cfg.num_attn_heads; h++) { + float *qh = q + h * cfg.head_dim; float sum_sq = 0.0f; - for (int i = 0; i < HEAD_DIM; i++) sum_sq += qh[i] * qh[i]; - float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + RMS_NORM_EPS); - for (int i = 0; i < HEAD_DIM; i++) qh[i] = qh[i] * inv_rms * bf16_to_f32(qnorm_w[i]); + for (int i = 0; i < cfg.head_dim; i++) sum_sq += qh[i] * qh[i]; + float inv_rms = 1.0f / sqrtf(sum_sq / cfg.head_dim + cfg.rms_norm_eps); + for (int i = 0; i < cfg.head_dim; i++) qh[i] = qh[i] * inv_rms * bf16_to_f32(qnorm_w[i]); } } if (knorm_w) { - for (int h = 0; h < NUM_KV_HEADS; h++) { - float *kh = k_out + h * HEAD_DIM; + for (int h = 0; h < cfg.num_kv_heads; h++) { + float *kh = k_out + h * cfg.head_dim; float sum_sq = 0.0f; - for (int i = 0; i < HEAD_DIM; i++) sum_sq += kh[i] * kh[i]; - float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + RMS_NORM_EPS); - for (int i = 0; i < HEAD_DIM; i++) kh[i] = kh[i] * inv_rms * bf16_to_f32(knorm_w[i]); + for (int i = 0; i < cfg.head_dim; i++) sum_sq += kh[i] * kh[i]; + float inv_rms = 1.0f / sqrtf(sum_sq / cfg.head_dim + cfg.rms_norm_eps); + for (int i = 0; i < cfg.head_dim; i++) kh[i] = kh[i] * inv_rms * bf16_to_f32(knorm_w[i]); } } // RoPE - apply_rotary_emb(q, k_out, pos, NUM_ATTN_HEADS, NUM_KV_HEADS, HEAD_DIM, ROTARY_DIM); + apply_rotary_emb(q, k_out, pos, cfg.num_attn_heads, cfg.num_kv_heads, cfg.head_dim, cfg.rotary_dim); // Update KV cache (CPU + GPU mirror) int cache_pos = kv->len; memcpy(kv->k_cache + cache_pos * kv_dim, 
k_out, kv_dim * sizeof(float)); memcpy(kv->v_cache + cache_pos * kv_dim, v_out, kv_dim * sizeof(float)); - int fa_idx = (layer_idx + 1) / FULL_ATTN_INTERVAL - 1; - if (g_metal && g_metal->attn_scores_pipe && fa_idx >= 0 && fa_idx < NUM_FULL_ATTN_LAYERS) { + int fa_idx = cfg.full_attn_index[layer_idx]; + if (g_metal && g_metal->attn_scores_pipe && fa_idx >= 0 && fa_idx < cfg.num_full_attn_layers) { memcpy((float *)[g_metal->buf_kv_k[fa_idx] contents] + cache_pos * kv_dim, k_out, kv_dim * sizeof(float)); memcpy((float *)[g_metal->buf_kv_v[fa_idx] contents] + cache_pos * kv_dim, @@ -4882,15 +4852,15 @@ static void fused_layer_forward( kv->len++; // Scaled dot-product attention (GQA) — GPU or CPU - int heads_per_kv = NUM_ATTN_HEADS / NUM_KV_HEADS; - float scale = 1.0f / sqrtf((float)HEAD_DIM); + int heads_per_kv = cfg.num_attn_heads / cfg.num_kv_heads; + float scale = 1.0f / sqrtf((float)cfg.head_dim); float *attn_out = s_attn_out; memset(attn_out, 0, q_dim * sizeof(float)); // GPU attention: defer dispatches to CMD2 (fused into single cmd buffer). // Only enabled when seq_len >= 32 (below that, CPU is faster). 
int gpu_attn_ready = (g_metal && g_metal->attn_scores_pipe && - fa_idx >= 0 && fa_idx < NUM_FULL_ATTN_LAYERS && + fa_idx >= 0 && fa_idx < cfg.num_full_attn_layers && kv->len >= 32 && kv->len < GPU_KV_SEQ); if (gpu_attn_ready) { @@ -4900,21 +4870,21 @@ static void fused_layer_forward( // attn_out_for_oproj will be set to NULL below — CMD2 reads buf_attn_out } else { // CPU fallback - for (int h = 0; h < NUM_ATTN_HEADS; h++) { + for (int h = 0; h < cfg.num_attn_heads; h++) { int kv_h = h / heads_per_kv; - float *qh = q + h * HEAD_DIM; + float *qh = q + h * cfg.head_dim; float *scores = malloc(kv->len * sizeof(float)); for (int p = 0; p < kv->len; p++) { - float *kp = kv->k_cache + p * kv_dim + kv_h * HEAD_DIM; + float *kp = kv->k_cache + p * kv_dim + kv_h * cfg.head_dim; float dot = 0.0f; - for (int d = 0; d < HEAD_DIM; d++) dot += qh[d] * kp[d]; + for (int d = 0; d < cfg.head_dim; d++) dot += qh[d] * kp[d]; scores[p] = dot * scale; } cpu_softmax(scores, kv->len); - float *oh = attn_out + h * HEAD_DIM; + float *oh = attn_out + h * cfg.head_dim; for (int p = 0; p < kv->len; p++) { - float *vp = kv->v_cache + p * kv_dim + kv_h * HEAD_DIM; - for (int d = 0; d < HEAD_DIM; d++) oh[d] += scores[p] * vp[d]; + float *vp = kv->v_cache + p * kv_dim + kv_h * cfg.head_dim; + for (int d = 0; d < cfg.head_dim; d++) oh[d] += scores[p] * vp[d]; } free(scores); } @@ -4939,7 +4909,7 @@ static void fused_layer_forward( } else { // ---- Linear attention CPU compute ---- if (!linear_attn_bypass) { - int qkv_dim = LINEAR_CONV_DIM; + int qkv_dim = cfg.linear_conv_dim; // Conv1d step uint16_t *conv_w = lc->conv1d_w; @@ -4947,31 +4917,31 @@ static void fused_layer_forward( memset(conv_out, 0, qkv_dim * sizeof(float)); if (conv_w) { cpu_conv1d_step(la_state->conv_state, qkv_out, conv_w, conv_out, - qkv_dim, CONV_KERNEL_SIZE); + qkv_dim, cfg.conv_kernel_size); } // Update conv state memmove(la_state->conv_state, la_state->conv_state + qkv_dim, - (CONV_KERNEL_SIZE - 2) * qkv_dim * 
sizeof(float)); - memcpy(la_state->conv_state + (CONV_KERNEL_SIZE - 2) * qkv_dim, qkv_out, + (cfg.conv_kernel_size - 2) * qkv_dim * sizeof(float)); + memcpy(la_state->conv_state + (cfg.conv_kernel_size - 2) * qkv_dim, qkv_out, qkv_dim * sizeof(float)); // Split into q, k, v float *lin_q = conv_out; - float *lin_k = conv_out + LINEAR_TOTAL_KEY; - float *lin_v = conv_out + 2 * LINEAR_TOTAL_KEY; + float *lin_k = conv_out + cfg.linear_total_key; + float *lin_v = conv_out + 2 * cfg.linear_total_key; // RMS normalize q and k - float inv_scale = 1.0f / sqrtf((float)LINEAR_KEY_DIM); - for (int h = 0; h < LINEAR_NUM_K_HEADS; h++) { - float *qh = lin_q + h * LINEAR_KEY_DIM; - cpu_rms_norm_bare(qh, qh, LINEAR_KEY_DIM, 1e-6f); + float inv_scale = 1.0f / sqrtf((float)cfg.linear_key_dim); + for (int h = 0; h < cfg.linear_num_k_heads; h++) { + float *qh = lin_q + h * cfg.linear_key_dim; + cpu_rms_norm_bare(qh, qh, cfg.linear_key_dim, 1e-6f); float q_scale = inv_scale * inv_scale; - for (int d = 0; d < LINEAR_KEY_DIM; d++) qh[d] *= q_scale; + for (int d = 0; d < cfg.linear_key_dim; d++) qh[d] *= q_scale; } - for (int h = 0; h < LINEAR_NUM_K_HEADS; h++) { - float *kh = lin_k + h * LINEAR_KEY_DIM; - cpu_rms_norm_bare(kh, kh, LINEAR_KEY_DIM, 1e-6f); - for (int d = 0; d < LINEAR_KEY_DIM; d++) kh[d] *= inv_scale; + for (int h = 0; h < cfg.linear_num_k_heads; h++) { + float *kh = lin_k + h * cfg.linear_key_dim; + cpu_rms_norm_bare(kh, kh, cfg.linear_key_dim, 1e-6f); + for (int d = 0; d < cfg.linear_key_dim; d++) kh[d] *= inv_scale; } // Gated delta net recurrence @@ -4979,12 +4949,12 @@ static void fused_layer_forward( uint16_t *dt_bias_bf16 = lc->dt_bias; float *out_values = s_out_vals; - memset(out_values, 0, LINEAR_TOTAL_VALUE * sizeof(float)); - int k_heads_per_v = LINEAR_NUM_V_HEADS / LINEAR_NUM_K_HEADS; + memset(out_values, 0, cfg.linear_total_value * sizeof(float)); + int k_heads_per_v = cfg.linear_num_v_heads / cfg.linear_num_k_heads; - float g_decay[LINEAR_NUM_V_HEADS]; - float 
beta_gate_arr[LINEAR_NUM_V_HEADS]; - for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) { + float g_decay[cfg.linear_num_v_heads]; + float beta_gate_arr[cfg.linear_num_v_heads]; + for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) { float a_val = alpha_out[vh]; float dt_b = dt_bias_bf16 ? bf16_to_f32(dt_bias_bf16[vh]) : 0.0f; float A_val = A_log ? expf(A_log[vh]) : 1.0f; @@ -4996,18 +4966,18 @@ static void fused_layer_forward( // Compute linear_layer_idx: count of non-full-attention layers before this one. // Full attention at (layer_idx+1) % 4 == 0, i.e. layers 3,7,11,... // linear_layer_idx = layer_idx - number_of_full_layers_at_or_before - // = layer_idx - (layer_idx + 1) / FULL_ATTN_INTERVAL - int linear_layer_idx = layer_idx - (layer_idx + 1) / FULL_ATTN_INTERVAL; + // = cfg.linear_index[layer_idx] + int linear_layer_idx = cfg.linear_index[layer_idx]; // GPU delta-net path (falls back to CPU if pipeline unavailable) if (g_metal && g_metal->delta_net_step && - linear_layer_idx >= 0 && linear_layer_idx < NUM_LINEAR_LAYERS) { + linear_layer_idx >= 0 && linear_layer_idx < cfg.num_linear_layers) { // Upload CPU-computed data to GPU scratch buffers - memcpy([g_metal->buf_delta_q contents], lin_q, LINEAR_TOTAL_KEY * sizeof(float)); - memcpy([g_metal->buf_delta_k contents], lin_k, LINEAR_TOTAL_KEY * sizeof(float)); - memcpy([g_metal->buf_delta_v contents], lin_v, LINEAR_TOTAL_VALUE * sizeof(float)); - memcpy([g_metal->buf_delta_g_decay contents], g_decay, LINEAR_NUM_V_HEADS * sizeof(float)); - memcpy([g_metal->buf_delta_beta contents], beta_gate_arr, LINEAR_NUM_V_HEADS * sizeof(float)); + memcpy([g_metal->buf_delta_q contents], lin_q, cfg.linear_total_key * sizeof(float)); + memcpy([g_metal->buf_delta_k contents], lin_k, cfg.linear_total_key * sizeof(float)); + memcpy([g_metal->buf_delta_v contents], lin_v, cfg.linear_total_value * sizeof(float)); + memcpy([g_metal->buf_delta_g_decay contents], g_decay, cfg.linear_num_v_heads * sizeof(float)); + 
memcpy([g_metal->buf_delta_beta contents], beta_gate_arr, cfg.linear_num_v_heads * sizeof(float)); id cmd_dn = [g_metal->queue commandBuffer]; id enc = [cmd_dn computeCommandEncoder]; @@ -5021,53 +4991,53 @@ static void fused_layer_forward( [enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:6]; uint32_t khpv = (uint32_t)k_heads_per_v; [enc setBytes:&khpv length:sizeof(khpv) atIndex:7]; - [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1) + [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [enc endEncoding]; [cmd_dn commit]; [cmd_dn waitUntilCompleted]; // Read back GPU result - memcpy(out_values, [g_metal->buf_delta_output contents], LINEAR_TOTAL_VALUE * sizeof(float)); + memcpy(out_values, [g_metal->buf_delta_output contents], cfg.linear_total_value * sizeof(float)); } else { // CPU delta-net with Accelerate BLAS - for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) { + for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) { int kh = vh / k_heads_per_v; float g = g_decay[vh]; float b_gate = beta_gate_arr[vh]; - float *S = la_state->ssm_state + vh * LINEAR_VALUE_DIM * LINEAR_KEY_DIM; - float *v_h = lin_v + vh * LINEAR_VALUE_DIM; - float *k_h = lin_k + kh * LINEAR_KEY_DIM; + float *S = la_state->ssm_state + vh * cfg.linear_value_dim * cfg.linear_key_dim; + float *v_h = lin_v + vh * cfg.linear_value_dim; + float *k_h = lin_k + kh * cfg.linear_key_dim; // Step 1: Decay S *= g (BLAS sscal on entire state matrix) - cblas_sscal(LINEAR_VALUE_DIM * LINEAR_KEY_DIM, g, S, 1); + cblas_sscal(cfg.linear_value_dim * cfg.linear_key_dim, g, S, 1); // Step 2: kv_mem = S @ k (each row dot k) // S is [VALUE_DIM x KEY_DIM] row-major, k is [KEY_DIM] // kv_mem[vi] = sum_ki(S[vi,ki] * k[ki]) = matrix-vector: S @ k - float kv_mem_vec[LINEAR_VALUE_DIM]; + float kv_mem_vec[cfg.linear_value_dim]; cblas_sgemv(CblasRowMajor, CblasNoTrans, - LINEAR_VALUE_DIM, LINEAR_KEY_DIM, - 1.0f, S, LINEAR_KEY_DIM, k_h, 1, + 
cfg.linear_value_dim, cfg.linear_key_dim, + 1.0f, S, cfg.linear_key_dim, k_h, 1, 0.0f, kv_mem_vec, 1); // Step 3: delta = (v - kv_mem) * beta, then rank-1 update S += k * delta^T // delta[vi] = (v[vi] - kv_mem[vi]) * beta - float delta_vec[LINEAR_VALUE_DIM]; - for (int vi = 0; vi < LINEAR_VALUE_DIM; vi++) { + float delta_vec[cfg.linear_value_dim]; + for (int vi = 0; vi < cfg.linear_value_dim; vi++) { delta_vec[vi] = (v_h[vi] - kv_mem_vec[vi]) * b_gate; } // S += delta @ k^T (rank-1 update: sger) // S[vi,ki] += delta[vi] * k[ki] - cblas_sger(CblasRowMajor, LINEAR_VALUE_DIM, LINEAR_KEY_DIM, - 1.0f, delta_vec, 1, k_h, 1, S, LINEAR_KEY_DIM); + cblas_sger(CblasRowMajor, cfg.linear_value_dim, cfg.linear_key_dim, + 1.0f, delta_vec, 1, k_h, 1, S, cfg.linear_key_dim); // Step 4: output = S @ q (matrix-vector multiply) - float *q_h = lin_q + kh * LINEAR_KEY_DIM; - float *o_h = out_values + vh * LINEAR_VALUE_DIM; + float *q_h = lin_q + kh * cfg.linear_key_dim; + float *o_h = out_values + vh * cfg.linear_value_dim; cblas_sgemv(CblasRowMajor, CblasNoTrans, - LINEAR_VALUE_DIM, LINEAR_KEY_DIM, - 1.0f, S, LINEAR_KEY_DIM, q_h, 1, + cfg.linear_value_dim, cfg.linear_key_dim, + 1.0f, S, cfg.linear_key_dim, q_h, 1, 0.0f, o_h, 1); } } @@ -5075,15 +5045,15 @@ static void fused_layer_forward( // RMSNormGated uint16_t *gated_norm_w = lc->gated_norm_w; float *gated_out = s_gated_out; - memset(gated_out, 0, LINEAR_TOTAL_VALUE * sizeof(float)); - for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) { - float *oh = out_values + vh * LINEAR_VALUE_DIM; - float *zh = z_out + vh * LINEAR_VALUE_DIM; - float *gh = gated_out + vh * LINEAR_VALUE_DIM; + memset(gated_out, 0, cfg.linear_total_value * sizeof(float)); + for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) { + float *oh = out_values + vh * cfg.linear_value_dim; + float *zh = z_out + vh * cfg.linear_value_dim; + float *gh = gated_out + vh * cfg.linear_value_dim; if (gated_norm_w) { - cpu_rms_norm_gated(oh, zh, gated_norm_w, gh, LINEAR_VALUE_DIM, 
RMS_NORM_EPS); + cpu_rms_norm_gated(oh, zh, gated_norm_w, gh, cfg.linear_value_dim, cfg.rms_norm_eps); } else { - memcpy(gh, oh, LINEAR_VALUE_DIM * sizeof(float)); + memcpy(gh, oh, cfg.linear_value_dim * sizeof(float)); } } @@ -5117,11 +5087,11 @@ static void fused_layer_forward( float *h_post = s_h_post; float *h_mid = s_h_mid; float *gate_scores = s_gate_scores; - memset(gate_scores, 0, NUM_EXPERTS * sizeof(float)); + memset(gate_scores, 0, cfg.num_experts * sizeof(float)); float *shared_gate = s_shared_gate; - memset(shared_gate, 0, SHARED_INTERMEDIATE * sizeof(float)); + memset(shared_gate, 0, cfg.shared_intermediate * sizeof(float)); float *shared_up = s_shared_up; - memset(shared_up, 0, SHARED_INTERMEDIATE * sizeof(float)); + memset(shared_up, 0, cfg.shared_intermediate * sizeof(float)); float shared_gate_score = 0.0f; int have_moe_weights = (gate_w && gate_s && gate_b && sgw && sgs && sgb && @@ -5160,7 +5130,7 @@ static void fused_layer_forward( } // gpu_linear_attn: batch_out[6] already has the result from CMD1 gated_rms_norm // Copy residual into GPU buffer for residual_add kernel - memcpy([g_metal->buf_residual contents], residual, HIDDEN_DIM * sizeof(float)); + memcpy([g_metal->buf_residual contents], residual, cfg.hidden_dim * sizeof(float)); attn_out_for_oproj = NULL; @@ -5168,11 +5138,11 @@ static void fused_layer_forward( // ---- GPU attention dispatches (only for full-attn layers with GPU path) ---- if (gpu_attn_fuse) { - int fa_idx = (layer_idx + 1) / FULL_ATTN_INTERVAL - 1; - int kv_dim = NUM_KV_HEADS * HEAD_DIM; - int heads_per_kv = NUM_ATTN_HEADS / NUM_KV_HEADS; - float scale = 1.0f / sqrtf((float)HEAD_DIM); - uint32_t hd = HEAD_DIM; + int fa_idx = cfg.full_attn_index[layer_idx]; + int kv_dim = cfg.num_kv_heads * cfg.head_dim; + int heads_per_kv = cfg.num_attn_heads / cfg.num_kv_heads; + float scale = 1.0f / sqrtf((float)cfg.head_dim); + uint32_t hd = cfg.head_dim; uint32_t kvd = (uint32_t)kv_dim; uint32_t sl = (uint32_t)kv->len; uint32_t 
seq_stride = GPU_KV_SEQ; @@ -5192,7 +5162,7 @@ static void fused_layer_forward( [enc setBytes:&scale length:4 atIndex:7]; [enc setBytes:&hpkv length:4 atIndex:8]; [enc setBytes:&sl length:4 atIndex:9]; - uint32_t total_tgs = sl * NUM_ATTN_HEADS; + uint32_t total_tgs = sl * cfg.num_attn_heads; [enc dispatchThreadgroups:MTLSizeMake(total_tgs, 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; [enc endEncoding]; @@ -5204,7 +5174,7 @@ static void fused_layer_forward( [enc setBuffer:g_metal->buf_attn_scores offset:0 atIndex:0]; [enc setBytes:&sl length:4 atIndex:1]; [enc setBytes:&seq_stride length:4 atIndex:2]; - [enc dispatchThreadgroups:MTLSizeMake(NUM_ATTN_HEADS, 1, 1) + [enc dispatchThreadgroups:MTLSizeMake(cfg.num_attn_heads, 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; [enc endEncoding]; } @@ -5220,7 +5190,7 @@ static void fused_layer_forward( [enc setBytes:&sl length:4 atIndex:5]; [enc setBytes:&seq_stride length:4 atIndex:6]; [enc setBytes:&hpkv length:4 atIndex:7]; - uint32_t total_threads = HEAD_DIM * NUM_ATTN_HEADS; + uint32_t total_threads = cfg.head_dim * cfg.num_attn_heads; uint32_t tgs = (total_threads + 255) / 256; [enc dispatchThreadgroups:MTLSizeMake(tgs, 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; @@ -5228,7 +5198,7 @@ static void fused_layer_forward( } // Enc A4: sigmoid_gate { - uint32_t qdim = NUM_ATTN_HEADS * HEAD_DIM; + uint32_t qdim = cfg.num_attn_heads * cfg.head_dim; id enc = [cmd_fused computeCommandEncoder]; [enc setComputePipelineState:g_metal->sigmoid_gate_pipe]; [enc setBuffer:g_metal->buf_attn_out offset:0 atIndex:0]; @@ -5252,9 +5222,9 @@ static void fused_layer_forward( id oproj_input = gpu_attn_fuse ? 
g_metal->buf_attn_out : g_metal->batch_out[6]; id enc = [cmd_fused computeCommandEncoder]; - uint32_t o_out_dim = HIDDEN_DIM; + uint32_t o_out_dim = cfg.hidden_dim; uint32_t o_in_dim = (uint32_t)oproj_in_dim; - uint32_t o_gs = GROUP_SIZE; + uint32_t o_gs = cfg.group_size; [enc setComputePipelineState:g_metal->matvec_fast]; [enc setBuffer:g_metal->wf_buf offset:w_off atIndex:0]; [enc setBuffer:g_metal->wf_buf offset:s_off atIndex:1]; @@ -5272,7 +5242,7 @@ static void fused_layer_forward( // ---- Enc 2: residual_add (buf_output + buf_residual -> buf_h_mid) ---- { id enc = [cmd_fused computeCommandEncoder]; - uint32_t dim = HIDDEN_DIM; + uint32_t dim = cfg.hidden_dim; [enc setComputePipelineState:g_metal->residual_add]; [enc setBuffer:g_metal->buf_residual offset:0 atIndex:0]; // a = residual [enc setBuffer:g_metal->buf_output offset:0 atIndex:1]; // b = o_proj result @@ -5287,7 +5257,7 @@ static void fused_layer_forward( // ---- Enc 3: rms_norm_sum_sq (buf_h_mid -> buf_sum_sq) ---- { id enc = [cmd_fused computeCommandEncoder]; - uint32_t dim = HIDDEN_DIM; + uint32_t dim = cfg.hidden_dim; [enc setComputePipelineState:g_metal->rms_norm_sum]; [enc setBuffer:g_metal->buf_h_mid offset:0 atIndex:0]; [enc setBuffer:g_metal->buf_sum_sq offset:0 atIndex:1]; @@ -5302,8 +5272,8 @@ static void fused_layer_forward( NSUInteger norm_off = (NSUInteger)((const char *)lc->post_attn_norm_w - (const char *)[g_metal->wf_buf contents]); id enc = [cmd_fused computeCommandEncoder]; - uint32_t dim = HIDDEN_DIM; - float eps = RMS_NORM_EPS; + uint32_t dim = cfg.hidden_dim; + float eps = cfg.rms_norm_eps; [enc setComputePipelineState:g_metal->rms_norm_apply_bf16]; [enc setBuffer:g_metal->buf_h_mid offset:0 atIndex:0]; // x [enc setBuffer:g_metal->wf_buf offset:norm_off atIndex:1]; // weight (bf16) @@ -5319,10 +5289,10 @@ static void fused_layer_forward( // ---- Enc 5-8: routing + shared expert projections (read buf_input) ---- BatchMatvecSpec moe_specs[4] = { - { gate_w, gate_s, gate_b, 
gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 }, - { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 }, - { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 }, - { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 }, + { gate_w, gate_s, gate_b, gate_scores, (uint32_t)cfg.num_experts, cfg.hidden_dim, cfg.group_size, 0 }, + { sgw, sgs, sgb, shared_gate, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 1 }, + { suw, sus, sub, shared_up, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 2 }, + { seg_w, seg_s, seg_b, &shared_gate_score, 1, cfg.hidden_dim, cfg.group_size, 3 }, }; // buf_input already contains h_post from Enc 4 output -- no memcpy needed gpu_encode_batch_matvec(g_metal, cmd_fused, moe_specs, 4); @@ -5337,11 +5307,11 @@ static void fused_layer_forward( // Read back results gpu_flush_batch_results(g_metal, moe_specs, 4); // Read h_mid from GPU buffer (needed for final combine) - memcpy(h_mid, [g_metal->buf_h_mid contents], HIDDEN_DIM * sizeof(float)); + memcpy(h_mid, [g_metal->buf_h_mid contents], cfg.hidden_dim * sizeof(float)); // Read h_post from buf_input (needed for expert input) - memcpy(h_post, [g_metal->buf_input contents], HIDDEN_DIM * sizeof(float)); + memcpy(h_post, [g_metal->buf_input contents], cfg.hidden_dim * sizeof(float)); // Update hidden state to h_mid (= residual + o_proj) - memcpy(hidden, h_mid, HIDDEN_DIM * sizeof(float)); + memcpy(hidden, h_mid, cfg.hidden_dim * sizeof(float)); if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd2_wait += t1 - t0; } } else { @@ -5349,45 +5319,45 @@ static void fused_layer_forward( // O projection if (attn_out_for_oproj && oproj_w && oproj_s && oproj_b) { fast_dequant_matvec(oproj_w, oproj_s, oproj_b, attn_out_for_oproj, - attn_projected, HIDDEN_DIM, oproj_in_dim, GROUP_SIZE); + attn_projected, cfg.hidden_dim, oproj_in_dim, cfg.group_size); } // attn_out_for_oproj 
is static — no free needed attn_out_for_oproj = NULL; // Residual connection - for (int i = 0; i < HIDDEN_DIM; i++) { + for (int i = 0; i < cfg.hidden_dim; i++) { hidden[i] = residual[i] + attn_projected[i]; } // attn_projected, normed, residual are static — no free needed - cpu_vec_copy(h_mid, hidden, HIDDEN_DIM); + cpu_vec_copy(h_mid, hidden, cfg.hidden_dim); // Post-attention norm - cpu_rms_norm(hidden, lc->post_attn_norm_w, h_post, HIDDEN_DIM, RMS_NORM_EPS); + cpu_rms_norm(hidden, lc->post_attn_norm_w, h_post, cfg.hidden_dim, cfg.rms_norm_eps); // Routing + shared expert batch if (have_moe_weights) { BatchMatvecSpec moe_specs[4] = { - { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 }, - { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 }, - { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 }, - { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 }, + { gate_w, gate_s, gate_b, gate_scores, (uint32_t)cfg.num_experts, cfg.hidden_dim, cfg.group_size, 0 }, + { sgw, sgs, sgb, shared_gate, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 1 }, + { suw, sus, sub, shared_up, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 2 }, + { seg_w, seg_s, seg_b, &shared_gate_score, 1, cfg.hidden_dim, cfg.group_size, 3 }, }; - fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4); + fast_batch_matvec(h_post, cfg.hidden_dim, moe_specs, 4); } if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd2_encode += t1 - t0; } } // ---- Softmax + top-K (CPU) ---- if (g_timing_enabled) { t0 = now_ms(); } - cpu_softmax(gate_scores, NUM_EXPERTS); + cpu_softmax(gate_scores, cfg.num_experts); int expert_indices[64]; float expert_weights[64]; - cpu_topk(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights); + cpu_topk(gate_scores, cfg.num_experts, K, expert_indices, expert_weights); cpu_normalize_weights(expert_weights, K); if 
(g_freq_tracking) { for (int k = 0; k < K; k++) { - g_expert_freq[layer_idx][expert_indices[k]]++; + FREQ(layer_idx, expert_indices[k])++; } if (layer_idx == 0) g_freq_total_tokens++; } @@ -5413,7 +5383,7 @@ static void fused_layer_forward( int32_t ki = (K > MAX_K) ? MAX_K : K; fwrite(&li, sizeof(int32_t), 1, g_routing_log); fwrite(&ki, sizeof(int32_t), 1, g_routing_log); - fwrite(hidden, sizeof(float), HIDDEN_DIM, g_routing_log); + fwrite(hidden, sizeof(float), cfg.hidden_dim, g_routing_log); fwrite(expert_indices, sizeof(int32_t), ki, g_routing_log); g_routing_log_samples++; } @@ -5421,9 +5391,9 @@ static void fused_layer_forward( // ---- Parallel pread + GPU experts ---- if (g_timing_enabled) { t0 = now_ms(); } float *moe_out = s_moe_out; - memset(moe_out, 0, HIDDEN_DIM * sizeof(float)); + memset(moe_out, 0, cfg.hidden_dim * sizeof(float)); float *shared_out = s_shared_out; - memset(shared_out, 0, HIDDEN_DIM * sizeof(float)); + memset(shared_out, 0, cfg.hidden_dim * sizeof(float)); int actual_K = (K > MAX_K) ? MAX_K : K; @@ -5558,8 +5528,8 @@ static void fused_layer_forward( for (int k = 0; k < actual_K; k++) { int found = 0; - for (int p = 0; p < g_pred_count[layer_idx]; p++) { - if (expert_indices[k] == g_pred_experts[layer_idx][p] && + for (int p = 0; p < PRED_COUNT(layer_idx); p++) { + if (expert_indices[k] == PRED_EXPERT(layer_idx, p) && g_async_pread.valid[p]) { // Hit! 
This expert was pre-loaded into buf_B[p] expert_bufs[k] = g_metal->buf_multi_expert_data_B[p]; @@ -5627,11 +5597,11 @@ static void fused_layer_forward( } // Shared expert prep (doesn't need expert data — can overlap with async pread) - memcpy([g_metal->buf_multi_expert_input contents], h_post, HIDDEN_DIM * sizeof(float)); + memcpy([g_metal->buf_multi_expert_input contents], h_post, cfg.hidden_dim * sizeof(float)); memcpy([g_metal->buf_shared_gate contents], shared_gate, - SHARED_INTERMEDIATE * sizeof(float)); + cfg.shared_intermediate * sizeof(float)); memcpy([g_metal->buf_shared_up contents], shared_up, - SHARED_INTERMEDIATE * sizeof(float)); + cfg.shared_intermediate * sizeof(float)); // Wait for non-prediction async pread to complete if (!pred_started && g_async_pread.active) { @@ -5647,10 +5617,10 @@ static void fused_layer_forward( // MUST happen AFTER the prediction hit check above (which reads g_pred_experts). if (g_pred_enabled && g_pred_generating) { for (int k = 0; k < actual_K; k++) { - g_pred_experts[layer_idx][k] = expert_indices[k]; + PRED_EXPERT(layer_idx, k) = expert_indices[k]; } - g_pred_count[layer_idx] = actual_K; - if (layer_idx == NUM_LAYERS - 1) { + PRED_COUNT(layer_idx) = actual_K; + if (layer_idx == cfg.num_layers - 1) { g_pred_valid = 1; } } @@ -5674,7 +5644,7 @@ static void fused_layer_forward( [enc setBuffer:g_metal->buf_shared_gate offset:0 atIndex:0]; [enc setBuffer:g_metal->buf_shared_up offset:0 atIndex:1]; [enc setBuffer:g_metal->buf_shared_act offset:0 atIndex:2]; - uint32_t dim = SHARED_INTERMEDIATE; + uint32_t dim = cfg.shared_intermediate; [enc setBytes:&dim length:4 atIndex:3]; uint32_t swiglu_tgs = (dim + 255) / 256; [enc dispatchThreadgroups:MTLSizeMake(swiglu_tgs, 1, 1) @@ -5687,7 +5657,7 @@ static void fused_layer_forward( gpu_encode_dequant_matvec_with_io_bufs( g_metal, cmd_experts, sdw, sds, sdb, g_metal->buf_shared_act, g_metal->buf_shared_out, - HIDDEN_DIM, SHARED_INTERMEDIATE, GROUP_SIZE); + cfg.hidden_dim, 
cfg.shared_intermediate, cfg.group_size); } // Step 4: GPU-side combine + residual + norm (if not last layer) @@ -5706,7 +5676,7 @@ static void fused_layer_forward( g_metal->rms_norm_sum && g_metal->rms_norm_apply_bf16 && g_metal->wf_buf && - layer_idx < NUM_LAYERS - 1 && + layer_idx < cfg.num_layers - 1 && layer_cache[layer_idx + 1].input_norm_w != NULL); if (gpu_combine) { @@ -5736,7 +5706,7 @@ static void fused_layer_forward( [enc setBuffer:g_metal->buf_multi_expert_out[k] offset:0 atIndex:(3 + k)]; } [enc setBuffer:g_metal->buf_combine_params offset:0 atIndex:11]; // params - uint32_t dim = HIDDEN_DIM; + uint32_t dim = cfg.hidden_dim; uint32_t k_val = (uint32_t)actual_K; [enc setBytes:&dim length:4 atIndex:12]; [enc setBytes:&k_val length:4 atIndex:13]; @@ -5749,7 +5719,7 @@ static void fused_layer_forward( // Enc C2: rms_norm_sum_sq (buf_moe_hidden -> buf_cmd3_sum_sq) { id enc = [cmd_experts computeCommandEncoder]; - uint32_t dim = HIDDEN_DIM; + uint32_t dim = cfg.hidden_dim; [enc setComputePipelineState:g_metal->rms_norm_sum]; [enc setBuffer:g_metal->buf_moe_hidden offset:0 atIndex:0]; [enc setBuffer:g_metal->buf_cmd3_sum_sq offset:0 atIndex:1]; @@ -5765,8 +5735,8 @@ static void fused_layer_forward( NSUInteger norm_off = (NSUInteger)((const char *)next_norm_w - (const char *)[g_metal->wf_buf contents]); id enc = [cmd_experts computeCommandEncoder]; - uint32_t dim = HIDDEN_DIM; - float eps = RMS_NORM_EPS; + uint32_t dim = cfg.hidden_dim; + float eps = cfg.rms_norm_eps; [enc setComputePipelineState:g_metal->rms_norm_apply_bf16]; [enc setBuffer:g_metal->buf_moe_hidden offset:0 atIndex:0]; // x [enc setBuffer:g_metal->wf_buf offset:norm_off atIndex:1]; // weight (bf16) @@ -5800,7 +5770,7 @@ static void fused_layer_forward( g_deferred.layer_idx = layer_idx; if (!gpu_combine) { // Only need to save h_mid for CPU-side combine path - memcpy(g_deferred.h_mid, h_mid, HIDDEN_DIM * sizeof(float)); + memcpy(g_deferred.h_mid, h_mid, cfg.hidden_dim * sizeof(float)); } for 
(int k = 0; k < actual_K; k++) { g_deferred.expert_weights[k] = expert_weights[k]; @@ -5815,7 +5785,7 @@ static void fused_layer_forward( } else if (packed_fd >= 0) { // CPU fallback for experts size_t esz = active_expert_size(); - float *expert_out_cpu = malloc(HIDDEN_DIM * sizeof(float)); + float *expert_out_cpu = malloc(cfg.hidden_dim * sizeof(float)); for (int k = 0; k < K; k++) { int eidx = expert_indices[k]; off_t expert_offset = (off_t)eidx * esz; @@ -5830,63 +5800,63 @@ static void fused_layer_forward( // CPU fallback offsets — use 4-bit layout (2-bit CPU path not yet implemented) uint32_t *gw = (uint32_t *)expert_data; - uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_S_OFF_2 : GATE_S_OFF_4)); - uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_B_OFF_2 : GATE_B_OFF_4)); - uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? UP_W_OFF_2 : UP_W_OFF_4)); - uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_S_OFF_2 : UP_S_OFF_4)); - uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_B_OFF_2 : UP_B_OFF_4)); - uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? DOWN_W_OFF_2 : DOWN_W_OFF_4)); - uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_S_OFF_2 : DOWN_S_OFF_4)); - uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_B_OFF_2 : DOWN_B_OFF_4)); - - float *gate_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float)); - float *up_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float)); - float *act_out = malloc(MOE_INTERMEDIATE * sizeof(float)); + uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.gate_s_off_2 : cfg.gate_s_off_4)); + uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.gate_b_off_2 : cfg.gate_b_off_4)); + uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? cfg.up_w_off_2 : cfg.up_w_off_4)); + uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? 
cfg.up_s_off_2 : cfg.up_s_off_4)); + uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.up_b_off_2 : cfg.up_b_off_4)); + uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? cfg.down_w_off_2 : cfg.down_w_off_4)); + uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.down_s_off_2 : cfg.down_s_off_4)); + uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.down_b_off_2 : cfg.down_b_off_4)); + + float *gate_proj_out = malloc(cfg.moe_intermediate * sizeof(float)); + float *up_proj_out = malloc(cfg.moe_intermediate * sizeof(float)); + float *act_out = malloc(cfg.moe_intermediate * sizeof(float)); cpu_dequant_matvec(gw, gs_p, gb_p, h_post, gate_proj_out, - MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE); + cfg.moe_intermediate, cfg.hidden_dim, cfg.group_size); cpu_dequant_matvec(uw, us_p, ub_p, h_post, up_proj_out, - MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE); - cpu_swiglu(gate_proj_out, up_proj_out, act_out, MOE_INTERMEDIATE); + cfg.moe_intermediate, cfg.hidden_dim, cfg.group_size); + cpu_swiglu(gate_proj_out, up_proj_out, act_out, cfg.moe_intermediate); cpu_dequant_matvec(dw, ds_p, db_p, act_out, expert_out_cpu, - HIDDEN_DIM, MOE_INTERMEDIATE, GROUP_SIZE); + cfg.hidden_dim, cfg.moe_intermediate, cfg.group_size); free(gate_proj_out); free(up_proj_out); free(act_out); free(expert_data); - cpu_vec_madd(moe_out, expert_out_cpu, expert_weights[k], HIDDEN_DIM); + cpu_vec_madd(moe_out, expert_out_cpu, expert_weights[k], cfg.hidden_dim); } free(expert_out_cpu); // CPU shared expert - float *shared_act = calloc(SHARED_INTERMEDIATE, sizeof(float)); - cpu_swiglu(shared_gate, shared_up, shared_act, SHARED_INTERMEDIATE); + float *shared_act = calloc(cfg.shared_intermediate, sizeof(float)); + cpu_swiglu(shared_gate, shared_up, shared_act, cfg.shared_intermediate); if (sdw && sds && sdb) { cpu_dequant_matvec(sdw, sds, sdb, shared_act, shared_out, - HIDDEN_DIM, SHARED_INTERMEDIATE, GROUP_SIZE); + cfg.hidden_dim, 
cfg.shared_intermediate, cfg.group_size); } free(shared_act); } else { // No experts available -- still need shared expert - float *shared_act = calloc(SHARED_INTERMEDIATE, sizeof(float)); - cpu_swiglu(shared_gate, shared_up, shared_act, SHARED_INTERMEDIATE); + float *shared_act = calloc(cfg.shared_intermediate, sizeof(float)); + cpu_swiglu(shared_gate, shared_up, shared_act, cfg.shared_intermediate); if (sdw && sds && sdb) { fast_dequant_matvec(sdw, sds, sdb, shared_act, shared_out, - HIDDEN_DIM, SHARED_INTERMEDIATE, GROUP_SIZE); + cfg.hidden_dim, cfg.shared_intermediate, cfg.group_size); } free(shared_act); } // ---- Shared expert gate ---- float shared_weight = cpu_sigmoid(shared_gate_score); - for (int i = 0; i < HIDDEN_DIM; i++) { + for (int i = 0; i < cfg.hidden_dim; i++) { shared_out[i] *= shared_weight; } // ---- Final combine: hidden = h_mid + moe_out + shared_out ---- - for (int i = 0; i < HIDDEN_DIM; i++) { + for (int i = 0; i < cfg.hidden_dim; i++) { hidden[i] = h_mid[i] + moe_out[i] + shared_out[i]; } @@ -5925,14 +5895,14 @@ static void freq_print_analysis(int K) { // Per-layer analysis int experts_for_80_total = 0; // sum across layers for overall estimate - for (int l = 0; l < NUM_LAYERS; l++) { + for (int l = 0; l < cfg.num_layers; l++) { // Count unique experts and sort frequencies descending - int sorted[NUM_EXPERTS]; - memcpy(sorted, g_expert_freq[l], NUM_EXPERTS * sizeof(int)); - qsort(sorted, NUM_EXPERTS, sizeof(int), freq_cmp_desc); + int sorted[cfg.num_experts]; + memcpy(sorted, &FREQ(l, 0), cfg.num_experts * sizeof(int)); + qsort(sorted, cfg.num_experts, sizeof(int), freq_cmp_desc); int unique = 0; - for (int e = 0; e < NUM_EXPERTS; e++) { + for (int e = 0; e < cfg.num_experts; e++) { if (sorted[e] > 0) unique++; } @@ -5940,7 +5910,7 @@ static void freq_print_analysis(int K) { int cum = 0; int top10_cov = 0, top30_cov = 0, top60_cov = 0; int n_for_50 = 0, n_for_80 = 0, n_for_90 = 0; - for (int e = 0; e < NUM_EXPERTS; e++) { + for (int e = 0; 
e < cfg.num_experts; e++) { cum += sorted[e]; if (e == 9) top10_cov = cum; if (e == 29) top30_cov = cum; @@ -5966,10 +5936,10 @@ static void freq_print_analysis(int K) { } // Overall summary: average experts needed for 80% across all layers - double avg_experts_80 = (double)experts_for_80_total / NUM_LAYERS; + double avg_experts_80 = (double)experts_for_80_total / cfg.num_layers; // Expert size in GB: each expert is active_expert_size() bytes double expert_gb = (double)active_expert_size() / (1024.0 * 1024.0 * 1024.0); - double total_pin_gb = avg_experts_80 * NUM_LAYERS * expert_gb; + double total_pin_gb = avg_experts_80 * cfg.num_layers * expert_gb; fprintf(stderr, "\n--- Overall Summary ---\n"); fprintf(stderr, "To achieve 80%% hit rate across all layers, need %d experts pinned " @@ -5977,7 +5947,7 @@ static void freq_print_analysis(int K) { experts_for_80_total, avg_experts_80, total_pin_gb); fprintf(stderr, "Expert size: %zu bytes (%.3f MB), %d layers x %d experts = %d total\n", active_expert_size(), (double)active_expert_size() / (1024.0 * 1024.0), - NUM_LAYERS, NUM_EXPERTS, NUM_LAYERS * NUM_EXPERTS); + cfg.num_layers, cfg.num_experts, cfg.num_layers * cfg.num_experts); } #ifndef CHAT_MODE @@ -6297,17 +6267,17 @@ static void sse_send_done(int fd, const char *request_id) { static void sync_cpu_to_gpu_delta_state_serve(void **layer_states) { if (!g_metal || !g_metal->delta_net_step || !layer_states) return; int li = 0; - for (int i = 0; i < NUM_LAYERS; i++) { - if ((i + 1) % FULL_ATTN_INTERVAL == 0) continue; + for (int i = 0; i < cfg.num_layers; i++) { + if (cfg.is_full_attn[i]) continue; if (!layer_states[i]) { li++; continue; } LinearAttnState *la = (LinearAttnState *)layer_states[i]; - if (li < NUM_LINEAR_LAYERS) { + if (li < cfg.num_linear_layers) { if (g_metal->buf_delta_state[li] && la->ssm_state) memcpy([g_metal->buf_delta_state[li] contents], la->ssm_state, - LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM * LINEAR_KEY_DIM * sizeof(float)); + 
cfg.linear_num_v_heads * cfg.linear_value_dim * cfg.linear_key_dim * sizeof(float)); if (g_metal->buf_conv_state[li] && la->conv_state) memcpy([g_metal->buf_conv_state[li] contents], la->conv_state, - (CONV_KERNEL_SIZE - 1) * LINEAR_CONV_DIM * sizeof(float)); + (cfg.conv_kernel_size - 1) * cfg.linear_conv_dim * sizeof(float)); } li++; } @@ -6359,22 +6329,22 @@ static void serve_loop( // Pre-embed all system prompt tokens float *sys_embed_batch = NULL; if (sys_pt->count > 1) { - sys_embed_batch = malloc((size_t)sys_pt->count * HIDDEN_DIM * sizeof(float)); + sys_embed_batch = malloc((size_t)sys_pt->count * cfg.hidden_dim * sizeof(float)); for (int i = 0; i < sys_pt->count; i++) { - embed_lookup(wf, sys_pt->ids[i], sys_embed_batch + (size_t)i * HIDDEN_DIM); + embed_lookup(wf, sys_pt->ids[i], sys_embed_batch + (size_t)i * cfg.hidden_dim); } } // Intermediate system prompt tokens: discard last-layer expert output for (int i = 0; i < sys_pt->count - 1; i++) { cache_telemetry_note_token(); if (sys_embed_batch) { - memcpy(hidden, sys_embed_batch + (size_t)i * HIDDEN_DIM, - HIDDEN_DIM * sizeof(float)); + memcpy(hidden, sys_embed_batch + (size_t)i * cfg.hidden_dim, + cfg.hidden_dim * sizeof(float)); } else { embed_lookup(wf, sys_pt->ids[i], hidden); } - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); + for (int layer = 0; layer < cfg.num_layers; layer++) { + int is_full = cfg.is_full_attn[layer]; fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? 
NULL : layer_states[layer], @@ -6389,13 +6359,13 @@ static void serve_loop( { cache_telemetry_note_token(); if (sys_embed_batch) { - memcpy(hidden, sys_embed_batch + (size_t)(sys_pt->count - 1) * HIDDEN_DIM, - HIDDEN_DIM * sizeof(float)); + memcpy(hidden, sys_embed_batch + (size_t)(sys_pt->count - 1) * cfg.hidden_dim, + cfg.hidden_dim * sizeof(float)); } else { embed_lookup(wf, sys_pt->ids[0], hidden); } - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); + for (int layer = 0; layer < cfg.num_layers; layer++) { + int is_full = cfg.is_full_attn[layer]; fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? NULL : layer_states[layer], @@ -6420,20 +6390,20 @@ static void serve_loop( float *v_snapshot; int len; } KVSnapshot; - KVSnapshot kv_snapshots[NUM_LAYERS]; + KVSnapshot kv_snapshots[cfg.num_layers]; memset(kv_snapshots, 0, sizeof(kv_snapshots)); // Linear attention snapshots - float *la_conv_snapshots[NUM_LAYERS]; - float *la_ssm_snapshots[NUM_LAYERS]; + float *la_conv_snapshots[cfg.num_layers]; + float *la_ssm_snapshots[cfg.num_layers]; memset(la_conv_snapshots, 0, sizeof(la_conv_snapshots)); memset(la_ssm_snapshots, 0, sizeof(la_ssm_snapshots)); - size_t kv_dim = NUM_KV_HEADS * HEAD_DIM; - size_t conv_state_size = (CONV_KERNEL_SIZE - 1) * LINEAR_CONV_DIM * sizeof(float); - size_t ssm_state_size = LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM * LINEAR_KEY_DIM * sizeof(float); + size_t kv_dim = cfg.num_kv_heads * cfg.head_dim; + size_t conv_state_size = (cfg.conv_kernel_size - 1) * cfg.linear_conv_dim * sizeof(float); + size_t ssm_state_size = cfg.linear_num_v_heads * cfg.linear_value_dim * cfg.linear_key_dim * sizeof(float); - for (int i = 0; i < NUM_LAYERS; i++) { + for (int i = 0; i < cfg.num_layers; i++) { if (kv_caches[i]) { size_t sz = sys_pos * kv_dim * sizeof(float); kv_snapshots[i].k_snapshot = malloc(sz); @@ -6451,19 +6421,19 @@ static void serve_loop( } } // Also snapshot 
GPU delta-net state - void *gpu_delta_snapshots[NUM_LINEAR_LAYERS]; - void *gpu_conv_snapshots[NUM_LINEAR_LAYERS]; - memset(gpu_delta_snapshots, 0, sizeof(gpu_delta_snapshots)); - memset(gpu_conv_snapshots, 0, sizeof(gpu_conv_snapshots)); + void **gpu_delta_snapshots = calloc(cfg.num_linear_layers, sizeof(void *)); + void **gpu_conv_snapshots = calloc(cfg.num_linear_layers, sizeof(void *)); + // calloc zero-initializes both snapshot arrays, so the + // memset pair used by the old fixed-size arrays is unnecessary if (g_metal && g_metal->delta_net_step) { - for (int i = 0; i < NUM_LINEAR_LAYERS; i++) { + for (int i = 0; i < cfg.num_linear_layers; i++) { if (g_metal->buf_delta_state[i]) { - size_t sz = (size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float); + size_t sz = (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float); gpu_delta_snapshots[i] = malloc(sz); memcpy(gpu_delta_snapshots[i], [g_metal->buf_delta_state[i] contents], sz); } if (g_metal->buf_conv_state[i]) { - size_t sz = (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float); + size_t sz = (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float); gpu_conv_snapshots[i] = malloc(sz); memcpy(gpu_conv_snapshots[i], [g_metal->buf_conv_state[i] contents], sz); } @@ -6598,7 +6568,7 @@ static void serve_loop( // ---- Restore state from system prompt snapshot ---- // Instead of resetting to zero, restore to the cached system prompt state. // This skips re-prefilling the system prompt tokens (~20 tokens, ~6s saved). 
- for (int i = 0; i < NUM_LAYERS; i++) { + for (int i = 0; i < cfg.num_layers; i++) { if (kv_caches[i] && kv_snapshots[i].k_snapshot) { size_t sz = sys_prompt_len * kv_dim * sizeof(float); memcpy(kv_caches[i]->k_cache, kv_snapshots[i].k_snapshot, sz); @@ -6606,8 +6576,8 @@ static void serve_loop( kv_caches[i]->len = kv_snapshots[i].len; // Also restore GPU KV mirror if (g_metal) { - int fa_idx = (i + 1) / FULL_ATTN_INTERVAL - 1; - if (fa_idx >= 0 && fa_idx < NUM_FULL_ATTN_LAYERS) { + int fa_idx = cfg.full_attn_index[i]; + if (fa_idx >= 0 && fa_idx < cfg.num_full_attn_layers) { memcpy([g_metal->buf_kv_k[fa_idx] contents], kv_snapshots[i].k_snapshot, sz); memcpy([g_metal->buf_kv_v[fa_idx] contents], @@ -6629,13 +6599,13 @@ static void serve_loop( } // Restore GPU delta-net state if (g_metal && g_metal->delta_net_step) { - for (int i = 0; i < NUM_LINEAR_LAYERS; i++) { + for (int i = 0; i < cfg.num_linear_layers; i++) { if (gpu_delta_snapshots[i] && g_metal->buf_delta_state[i]) memcpy([g_metal->buf_delta_state[i] contents], - gpu_delta_snapshots[i], (size_t)LINEAR_NUM_V_HEADS*LINEAR_VALUE_DIM*LINEAR_KEY_DIM*sizeof(float)); + gpu_delta_snapshots[i], (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float)); if (gpu_conv_snapshots[i] && g_metal->buf_conv_state[i]) memcpy([g_metal->buf_conv_state[i] contents], - gpu_conv_snapshots[i], (CONV_KERNEL_SIZE-1)*(size_t)LINEAR_CONV_DIM*sizeof(float)); + gpu_conv_snapshots[i], (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float)); } } else { reset_delta_net_state(); @@ -6659,22 +6629,22 @@ static void serve_loop( // Pre-embed all request tokens float *serve_embed_batch = NULL; if (pt->count > 1) { - serve_embed_batch = malloc((size_t)pt->count * HIDDEN_DIM * sizeof(float)); + serve_embed_batch = malloc((size_t)pt->count * cfg.hidden_dim * sizeof(float)); for (int i = 0; i < pt->count; i++) { - embed_lookup(wf, pt->ids[i], serve_embed_batch + (size_t)i * HIDDEN_DIM); + embed_lookup(wf, 
pt->ids[i], serve_embed_batch + (size_t)i * cfg.hidden_dim); } } // Intermediate prefill tokens: discard last-layer expert output for (int i = 0; i < pt->count - 1; i++) { cache_telemetry_note_token(); if (serve_embed_batch) { - memcpy(hidden, serve_embed_batch + (size_t)i * HIDDEN_DIM, - HIDDEN_DIM * sizeof(float)); + memcpy(hidden, serve_embed_batch + (size_t)i * cfg.hidden_dim, + cfg.hidden_dim * sizeof(float)); } else { embed_lookup(wf, pt->ids[i], hidden); } - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); + for (int layer = 0; layer < cfg.num_layers; layer++) { + int is_full = cfg.is_full_attn[layer]; fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? NULL : layer_states[layer], @@ -6689,13 +6659,13 @@ static void serve_loop( { cache_telemetry_note_token(); if (serve_embed_batch) { - memcpy(hidden, serve_embed_batch + (size_t)(pt->count - 1) * HIDDEN_DIM, - HIDDEN_DIM * sizeof(float)); + memcpy(hidden, serve_embed_batch + (size_t)(pt->count - 1) * cfg.hidden_dim, + cfg.hidden_dim * sizeof(float)); } else { embed_lookup(wf, pt->ids[0], hidden); } - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); + for (int layer = 0; layer < cfg.num_layers; layer++) { + int is_full = cfg.is_full_attn[layer]; fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? 
NULL : layer_states[layer], @@ -6713,13 +6683,13 @@ static void serve_loop( // ---- Final norm + LM head for first token ---- if (final_norm_w) { - float *normed = malloc(HIDDEN_DIM * sizeof(float)); - cpu_rms_norm(hidden, final_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS); - memcpy(hidden, normed, HIDDEN_DIM * sizeof(float)); + float *normed = malloc(cfg.hidden_dim * sizeof(float)); + cpu_rms_norm(hidden, final_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps); + memcpy(hidden, normed, cfg.hidden_dim * sizeof(float)); free(normed); } lm_head_forward(wf, hidden, logits); - int next_token = cpu_argmax(logits, VOCAB_SIZE); + int next_token = cpu_argmax(logits, cfg.vocab_size); // ---- Auto-regressive generation with SSE streaming ---- if (g_pred_enabled) { @@ -6735,12 +6705,12 @@ static void serve_loop( int gen_resp_len = 0; for (int gen = 0; gen < max_gen; gen++) { - if (next_token == EOS_TOKEN_1 || next_token == EOS_TOKEN_2) { + if (next_token == cfg.eos_token_ids[0] || next_token == cfg.eos_token_ids[1]) { // Feed EOS through the model so session state includes it cache_telemetry_note_token(); embed_lookup(wf, next_token, hidden); - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); + for (int layer = 0; layer < cfg.num_layers; layer++) { + int is_full = cfg.is_full_attn[layer]; fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? 
NULL : layer_states[layer], @@ -6754,12 +6724,12 @@ static void serve_loop( } // Think budget enforcement - if (next_token == THINK_START_TOKEN) in_think = 1; - if (next_token == THINK_END_TOKEN) in_think = 0; + if (next_token == cfg.think_start_token) in_think = 1; + if (next_token == cfg.think_end_token) in_think = 0; if (in_think) { think_tokens++; if (g_think_budget > 0 && think_tokens >= g_think_budget) { - next_token = THINK_END_TOKEN; // force end thinking + next_token = cfg.think_end_token; // force end thinking in_think = 0; } } @@ -6781,8 +6751,8 @@ static void serve_loop( // Generate next cache_telemetry_note_token(); embed_lookup(wf, next_token, hidden); - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); + for (int layer = 0; layer < cfg.num_layers; layer++) { + int is_full = cfg.is_full_attn[layer]; fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? NULL : layer_states[layer], @@ -6794,13 +6764,13 @@ static void serve_loop( pos++; if (final_norm_w) { - float *normed = malloc(HIDDEN_DIM * sizeof(float)); - cpu_rms_norm(hidden, final_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS); - memcpy(hidden, normed, HIDDEN_DIM * sizeof(float)); + float *normed = malloc(cfg.hidden_dim * sizeof(float)); + cpu_rms_norm(hidden, final_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps); + memcpy(hidden, normed, cfg.hidden_dim * sizeof(float)); free(normed); } lm_head_forward(wf, hidden, logits); - next_token = cpu_argmax(logits, VOCAB_SIZE); + next_token = cpu_argmax(logits, cfg.vocab_size); } sse_send_done(client_fd, request_id); @@ -6874,7 +6844,7 @@ static void print_usage(const char *prog) { int main(int argc, char **argv) { @autoreleasepool { - const char *model_path = MODEL_PATH_DEFAULT; + const char *model_path = getenv("FLASH_MOE_MODEL"); const char *weights_path = NULL; const char *manifest_path = NULL; const char *vocab_path = NULL; @@ -6947,6 +6917,11 @@ int main(int argc, 
char **argv) { } } + // ---- Load model configuration from HF config.json ---- + load_model_config(model_path ? model_path : ""); + alloc_tracking_arrays(); + g_deferred.h_mid = calloc(cfg.hidden_dim, sizeof(float)); + // Build default paths char default_weights[1024], default_manifest[1024], default_vocab[1024]; @@ -7092,16 +7067,16 @@ int main(int argc, char **argv) { // Seen-expert bitset tracks which (layer, expert) pairs have been read before. // First read goes through cold fd (no page cache pollution). // Subsequent reads go through warm fd (page cache hit = 32 GB/s vs 5.5 GB/s). - int layer_fds[NUM_LAYERS]; - int layer_fds_cold[NUM_LAYERS]; - void *layer_mmaps[NUM_LAYERS]; - size_t layer_mmap_sizes[NUM_LAYERS]; + int *layer_fds = calloc(cfg.num_layers, sizeof(int)); + int *layer_fds_cold = calloc(cfg.num_layers, sizeof(int)); + void **layer_mmaps = calloc(cfg.num_layers, sizeof(void *)); + size_t *layer_mmap_sizes = calloc(cfg.num_layers, sizeof(size_t)); int expert_layers_available = 0; // Reset the global seen-expert bitset - memset(g_expert_seen, 0, sizeof(g_expert_seen)); + memset(g_expert_seen, 0, cfg.num_layers * ((cfg.num_experts + 7) / 8)); - for (int i = 0; i < NUM_LAYERS; i++) { + for (int i = 0; i < cfg.num_layers; i++) { char path[1024]; snprintf(path, sizeof(path), "%s/%s/layer_%02d.bin", model_path, g_use_2bit ? 
"packed_experts_2bit" : "packed_experts", i); @@ -7128,7 +7103,7 @@ int main(int argc, char **argv) { } } } - printf("[experts] %d/%d packed layer files available (mmap'd)\n", expert_layers_available, NUM_LAYERS); + printf("[experts] %d/%d packed layer files available (mmap'd)\n", expert_layers_available, cfg.num_layers); // ---- LZ4 compressed experts: auto-detect and load ---- { @@ -7136,16 +7111,16 @@ int main(int argc, char **argv) { snprintf(lz4_probe, sizeof(lz4_probe), "%s/packed_experts_lz4/layer_00.bin", model_path); if (!g_use_2bit && access(lz4_probe, R_OK) == 0) { int lz4_layers = 0; - for (int i = 0; i < NUM_LAYERS; i++) { + for (int i = 0; i < cfg.num_layers; i++) { char lz4_path[1024]; snprintf(lz4_path, sizeof(lz4_path), "%s/packed_experts_lz4/layer_%02d.bin", model_path, i); int lz4_fd = open(lz4_path, O_RDONLY); if (lz4_fd >= 0) { - // Load index header (NUM_EXPERTS entries × 16 bytes) - g_lz4_index[i] = malloc(NUM_EXPERTS * sizeof(LZ4IndexEntry)); + // Load index header (cfg.num_experts entries × 16 bytes) + g_lz4_index[i] = malloc(cfg.num_experts * sizeof(LZ4IndexEntry)); ssize_t nr = pread(lz4_fd, g_lz4_index[i], - NUM_EXPERTS * sizeof(LZ4IndexEntry), 0); - if (nr == NUM_EXPERTS * (ssize_t)sizeof(LZ4IndexEntry)) { + cfg.num_experts * sizeof(LZ4IndexEntry), 0); + if (nr == cfg.num_experts * (ssize_t)sizeof(LZ4IndexEntry)) { // Replace the raw fd with the LZ4 fd close(layer_fds[i]); layer_fds[i] = lz4_fd; @@ -7162,10 +7137,10 @@ int main(int argc, char **argv) { g_use_lz4 = 1; // Allocate compressed read buffers (one per expert slot) for (int k = 0; k < MAX_K; k++) { - g_lz4_comp_bufs[k] = malloc(EXPERT_SIZE + 4096); + g_lz4_comp_bufs[k] = malloc(cfg.expert_size_4bit + 4096); } printf("[lz4] %d/%d layers using LZ4 compressed experts\n", - lz4_layers, NUM_LAYERS); + lz4_layers, cfg.num_layers); } } } @@ -7178,7 +7153,7 @@ int main(int argc, char **argv) { // Warm page cache hint if (expert_layers_available > 0) { double t_warm = now_ms(); - for 
(int i = 0; i < NUM_LAYERS; i++) { + for (int i = 0; i < cfg.num_layers; i++) { if (layer_fds[i] >= 0) { char dummy[4096]; pread(layer_fds[i], dummy, sizeof(dummy), 0); @@ -7188,11 +7163,11 @@ int main(int argc, char **argv) { } // ---- Allocate per-layer state ---- - void **layer_states = calloc(NUM_LAYERS, sizeof(void *)); - KVCache **kv_caches = calloc(NUM_LAYERS, sizeof(KVCache *)); + void **layer_states = calloc(cfg.num_layers, sizeof(void *)); + KVCache **kv_caches = calloc(cfg.num_layers, sizeof(KVCache *)); - for (int i = 0; i < NUM_LAYERS; i++) { - int is_full = ((i + 1) % FULL_ATTN_INTERVAL == 0); + for (int i = 0; i < cfg.num_layers; i++) { + int is_full = cfg.is_full_attn[i]; if (is_full) { kv_caches[i] = kv_cache_new(); } else { @@ -7204,8 +7179,8 @@ int main(int argc, char **argv) { printf("[init] Setup: %.1f ms\n\n", t_init - t0); // ---- Allocate working buffers ---- - float *hidden = calloc(HIDDEN_DIM, sizeof(float)); - float *logits = calloc(VOCAB_SIZE, sizeof(float)); + float *hidden = calloc(cfg.hidden_dim, sizeof(float)); + float *logits = calloc(cfg.vocab_size, sizeof(float)); uint16_t *final_norm_w = get_tensor_ptr(wf, "model.norm.weight"); // ---- Serve mode: enter HTTP server loop (never returns) ---- @@ -7231,10 +7206,10 @@ int main(int argc, char **argv) { // embed_lookup with GPU work, and enables the optimized prefill loop below. 
float *embed_batch = NULL; if (pt->count > 1) { - embed_batch = malloc((size_t)pt->count * HIDDEN_DIM * sizeof(float)); + embed_batch = malloc((size_t)pt->count * cfg.hidden_dim * sizeof(float)); double t_embed = now_ms(); for (int i = 0; i < pt->count; i++) { - embed_lookup(wf, pt->ids[i], embed_batch + (size_t)i * HIDDEN_DIM); + embed_lookup(wf, pt->ids[i], embed_batch + (size_t)i * cfg.hidden_dim); } double embed_ms = now_ms() - t_embed; printf(" [prefill] batch embed %d tokens: %.1f ms\n", pt->count, embed_ms); @@ -7256,12 +7231,12 @@ int main(int argc, char **argv) { // Load pre-embedded token from batch buffer cache_telemetry_note_token(); - memcpy(hidden, embed_batch + (size_t)token_idx * HIDDEN_DIM, - HIDDEN_DIM * sizeof(float)); + memcpy(hidden, embed_batch + (size_t)token_idx * cfg.hidden_dim, + cfg.hidden_dim * sizeof(float)); // Run through all 60 transformer layers - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); + for (int layer = 0; layer < cfg.num_layers; layer++) { + int is_full = cfg.is_full_attn[layer]; fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? NULL : layer_states[layer], @@ -7292,14 +7267,14 @@ int main(int argc, char **argv) { { cache_telemetry_note_token(); if (embed_batch) { - memcpy(hidden, embed_batch + (size_t)(pt->count - 1) * HIDDEN_DIM, - HIDDEN_DIM * sizeof(float)); + memcpy(hidden, embed_batch + (size_t)(pt->count - 1) * cfg.hidden_dim, + cfg.hidden_dim * sizeof(float)); } else { embed_lookup(wf, pt->ids[0], hidden); } - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); + for (int layer = 0; layer < cfg.num_layers; layer++) { + int is_full = cfg.is_full_attn[layer]; fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? 
NULL : layer_states[layer], @@ -7316,9 +7291,9 @@ int main(int argc, char **argv) { // ---- Final norm ---- if (final_norm_w) { - float *normed = malloc(HIDDEN_DIM * sizeof(float)); - cpu_rms_norm(hidden, final_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS); - memcpy(hidden, normed, HIDDEN_DIM * sizeof(float)); + float *normed = malloc(cfg.hidden_dim * sizeof(float)); + cpu_rms_norm(hidden, final_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps); + memcpy(hidden, normed, cfg.hidden_dim * sizeof(float)); free(normed); } @@ -7328,7 +7303,7 @@ int main(int argc, char **argv) { double lm_ms = now_ms() - t_lm; // ---- Sample first token ---- - int next_token = cpu_argmax(logits, VOCAB_SIZE); + int next_token = cpu_argmax(logits, cfg.vocab_size); double ttft_ms = now_ms() - t0; // Debug: show top-5 logits for first token @@ -7336,7 +7311,7 @@ int main(int argc, char **argv) { // Find top 5 manually int top5[5] = {0,0,0,0,0}; float topv[5] = {-1e30f,-1e30f,-1e30f,-1e30f,-1e30f}; - for (int i = 0; i < VOCAB_SIZE; i++) { + for (int i = 0; i < cfg.vocab_size; i++) { int min_k = 0; for (int k = 1; k < 5; k++) if (topv[k] < topv[min_k]) min_k = k; if (logits[i] > topv[min_k]) { topv[min_k] = logits[i]; top5[min_k] = i; } @@ -7347,7 +7322,7 @@ int main(int argc, char **argv) { top5[i], decode_token(vocab, top5[i]), topv[i]); } fprintf(stderr, "[debug] hidden rms after final_norm=%.4f, logits rms=%.4f\n", - vec_rms(hidden, HIDDEN_DIM), vec_rms(logits, VOCAB_SIZE)); + vec_rms(hidden, cfg.hidden_dim), vec_rms(logits, cfg.vocab_size)); } printf("[ttft] %.0f ms (prefill %d tokens + lm_head %.0f ms)\n", ttft_ms, pt->count, lm_ms); @@ -7357,7 +7332,7 @@ int main(int argc, char **argv) { fflush(stdout); int total_generated = 1; - int in_think = (next_token == THINK_START_TOKEN) ? 1 : 0; + int in_think = (next_token == cfg.think_start_token) ? 
1 : 0; int think_tokens = 0; // ---- Auto-regressive generation ---- @@ -7370,14 +7345,14 @@ int main(int argc, char **argv) { double t_gen_start = now_ms(); // Check EOS - if (next_token == EOS_TOKEN_1 || next_token == EOS_TOKEN_2) { + if (next_token == cfg.eos_token_ids[0] || next_token == cfg.eos_token_ids[1]) { fprintf(stderr, "\n[eos] Token %d at position %d\n", next_token, gen); break; } // Think budget enforcement - if (next_token == THINK_START_TOKEN) in_think = 1; - if (next_token == THINK_END_TOKEN) in_think = 0; + if (next_token == cfg.think_start_token) in_think = 1; + if (next_token == cfg.think_end_token) in_think = 0; if (in_think) think_tokens++; // Embed the just-generated token (next iteration) @@ -7385,8 +7360,8 @@ int main(int argc, char **argv) { embed_lookup(wf, next_token, hidden); // Run 40 layers (fused: 1+K cmd buffers per layer) - for (int layer = 0; layer < NUM_LAYERS; layer++) { - int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0); + for (int layer = 0; layer < cfg.num_layers; layer++) { + int is_full = cfg.is_full_attn[layer]; fused_layer_forward(wf, layer, hidden, is_full ? kv_caches[layer] : NULL, is_full ? 
NULL : layer_states[layer], @@ -7400,9 +7375,9 @@ int main(int argc, char **argv) { // Final norm if (final_norm_w) { - float *normed = malloc(HIDDEN_DIM * sizeof(float)); - cpu_rms_norm(hidden, final_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS); - memcpy(hidden, normed, HIDDEN_DIM * sizeof(float)); + float *normed = malloc(cfg.hidden_dim * sizeof(float)); + cpu_rms_norm(hidden, final_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps); + memcpy(hidden, normed, cfg.hidden_dim * sizeof(float)); free(normed); } @@ -7410,11 +7385,11 @@ int main(int argc, char **argv) { lm_head_forward(wf, hidden, logits); // Greedy sample - next_token = cpu_argmax(logits, VOCAB_SIZE); + next_token = cpu_argmax(logits, cfg.vocab_size); // Think budget: force end thinking if over budget if (in_think && g_think_budget > 0 && think_tokens >= g_think_budget) { - next_token = THINK_END_TOKEN; + next_token = cfg.think_end_token; in_think = 0; } total_generated++; @@ -7442,7 +7417,7 @@ int main(int argc, char **argv) { printf("Generation: %.1f s (%.2f tok/s)\n", gen_time / 1000.0, (total_generated - 1) * 1000.0 / gen_time); } - printf("Config: K=%d experts, %d layers\n", K, NUM_LAYERS); + printf("Config: K=%d experts, %d layers\n", K, cfg.num_layers); if (g_expert_cache) { uint64_t total = g_expert_cache->hits + g_expert_cache->misses; printf("Expert cache: %llu hits, %llu misses (%.1f%% hit rate), %d/%d entries used\n", @@ -7484,7 +7459,7 @@ int main(int argc, char **argv) { expert_cache_free(g_expert_cache); g_expert_cache = NULL; } - for (int i = 0; i < NUM_LAYERS; i++) { + for (int i = 0; i < cfg.num_layers; i++) { if (kv_caches[i]) kv_cache_free(kv_caches[i]); if (layer_states[i]) linear_attn_state_free(layer_states[i]); if (layer_mmaps[i] != MAP_FAILED) munmap(layer_mmaps[i], layer_mmap_sizes[i]); From 1ae86beeeddff3db1b0879240d32f93c1039cbcd Mon Sep 17 00:00:00 2001 From: Alessio Delmonti Date: Fri, 20 Mar 2026 22:53:13 +0000 Subject: [PATCH 5/7] chore: update header and banner for runtime 
model config Generalize file header comment to describe multi-model support. Update startup banner from hardcoded model name to "Flash-MoE" with dynamic config path display. Co-Authored-By: Claude Opus 4.6 --- metal_infer/infer.m | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/metal_infer/infer.m b/metal_infer/infer.m index ef70d9d..afa69f9 100644 --- a/metal_infer/infer.m +++ b/metal_infer/infer.m @@ -1,16 +1,18 @@ /* - * infer.m — Complete Qwen3.5-35B-A3B inference engine using Metal + * infer.m — Qwen3.5 MoE inference engine using Metal * - * Full forward pass: embedding -> 40 transformer layers -> norm -> lm_head -> sample + * Full forward pass: embedding -> N transformer layers -> norm -> lm_head -> sample + * Model architecture loaded at runtime from HuggingFace config.json (--model flag). * Non-expert weights loaded from model_weights.bin (mmap'd at startup) * Expert weights loaded from packed_experts/ per layer per token (pread) * - * Architecture: Qwen3.5-35B-A3B (MoE) - * - 40 layers: 30 linear attention (GatedDeltaNet) + 10 full attention - * - hidden_size=2048, head_dim=256, num_attention_heads=16, num_kv_heads=2 - * - 256 experts/layer, 8 active (K=8) + * Supported: Qwen3.5-35B-A3B, Qwen3.5-397B-A17B, and compatible MoE variants. 
+ * Architecture auto-detected from config.json: + * - N layers: mix of linear attention (GatedDeltaNet) + full attention + * - Configurable hidden_size, head_dim, num_attention_heads, num_kv_heads + * - Variable experts/layer and active experts (K) * - Shared expert per layer (always active) - * - Linear attention: conv1d(kernel=4) + gated delta recurrence + * - Linear attention: conv1d + gated delta recurrence * - Full attention: standard QKV + scaled dot product + RoPE * * Command buffer optimization (fused_layer_forward): @@ -6974,7 +6976,8 @@ int main(int argc, char **argv) { g_expert_cache = expert_cache_new(g_metal->device, cache_entries); } - printf("=== Qwen3.5-35B-A3B Metal Inference Engine ===\n"); + printf("=== Flash-MoE Metal Inference Engine ===\n"); + printf("Config: %s/config.json\n", cfg.model_path); printf("Model: %s\n", model_path); printf("Weights: %s\n", weights_path); printf("Manifest: %s\n", manifest_path); From f19e5634733cc4141519d1fd801c74f93c8787ee Mon Sep 17 00:00:00 2001 From: Alessio Delmonti Date: Fri, 20 Mar 2026 23:06:30 +0000 Subject: [PATCH 6/7] feat: add model manager utility for listing and downloading compatible models Lists local HF-cached models with compatibility check, searches HuggingFace for compatible Qwen3.5 MoE models (35B-A3B, 122B-A10B, 397B-A17B) with MLX quantization, and supports downloading via huggingface-cli or huggingface_hub. 
Usage:
    python model_manager.py                    # list local + remote
    python model_manager.py --local            # local only
    python model_manager.py --search           # remote only
    python model_manager.py --download REPO_ID
    python model_manager.py --check MODEL_PATH

Co-Authored-By: Claude Opus 4.6
---
 model_manager.py | 398 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 398 insertions(+)
 create mode 100755 model_manager.py

diff --git a/model_manager.py b/model_manager.py
new file mode 100755
index 0000000..7a2382a
--- /dev/null
+++ b/model_manager.py
@@ -0,0 +1,398 @@
+#!/usr/bin/env python3
+"""Flash-MoE Model Manager — list, search, and download compatible models.
+
+Compatible models: Qwen3.5 MoE with MLX quantization (model_type: qwen3_5_moe).
+These include GatedDeltaNet linear attention + full attention layers with
+switch_mlp expert routing.
+
+Usage:
+    python model_manager.py                      # List local + search remote
+    python model_manager.py --local              # List local models only
+    python model_manager.py --search             # Search HuggingFace for compatible models
+    python model_manager.py --download REPO_ID   # Download a specific model
+    python model_manager.py --check MODEL_PATH   # Check if a local model is compatible
+"""
+
+import argparse
+import json
+import os
+import struct
+import subprocess
+import sys
+from pathlib import Path
+
+try:
+    import requests
+except ImportError:
+    requests = None
+
+HF_CACHE = Path(os.path.expanduser("~/.cache/huggingface/hub"))
+HF_API = "https://huggingface.co/api"
+
+# Known compatible model types
+COMPATIBLE_MODEL_TYPES = {"qwen3_5_moe"}
+
+# Search queries — MLX-quantized Qwen3.5 models
+SEARCH_QUERIES = [
+    "mlx-community Qwen3.5",
+    "mlx Qwen3.5 MoE",
+    "lmstudio-community Qwen3.5 MLX",
+]
+
+# MoE model name patterns: "35B-A3B", "122B-A10B", "397B-A17B" etc.
+# The "-A{n}B" suffix (e.g. "A3B" = 3B active) indicates active parameters = MoE architecture
+import re
+MOE_PATTERN = re.compile(r'\d+B-A\d+B')
+
+
+def find_config_json(model_path: Path) -> Path | None:
+    """Find config.json in a model directory, handling HF cache layout."""
+    direct = model_path / "config.json"
+    if direct.exists():
+        return direct
+    snapshots = model_path / "snapshots"
+    if snapshots.exists():
+        for snap in sorted(snapshots.iterdir(), reverse=True):
+            candidate = snap / "config.json"
+            if candidate.exists():
+                return candidate
+    return None
+
+
+def check_compatibility(model_path: Path) -> dict:
+    """Check if a local model is compatible with Flash-MoE.
+
+    Returns a dict with:
+        compatible: bool
+        reason: str (if not compatible)
+        info: dict (model details if compatible)
+    """
+    config_path = find_config_json(model_path)
+    if not config_path:
+        return {"compatible": False, "reason": "No config.json found"}
+
+    with open(config_path) as f:
+        config = json.load(f)
+
+    model_type = config.get("model_type", "")
+    if model_type not in COMPATIBLE_MODEL_TYPES:
+        return {
+            "compatible": False,
+            "reason": f"Incompatible model_type: {model_type} (need: {', '.join(COMPATIBLE_MODEL_TYPES)})",
+        }
+
+    tc = config.get("text_config", {})
+    if not tc:
+        return {"compatible": False, "reason": "Missing text_config in config.json"}
+
+    # Check for required fields
+    required = [
+        "hidden_size", "num_hidden_layers", "num_experts",
+        "num_experts_per_tok", "moe_intermediate_size",
+    ]
+    missing = [k for k in required if k not in tc]
+    if missing:
+        return {"compatible": False, "reason": f"Missing fields: {', '.join(missing)}"}
+
+    # Check quantization
+    qc = config.get("quantization_config", config.get("quantization", {}))
+    bits = qc.get("bits", "?")
+    group_size = qc.get("group_size", "?")
+
+    # Check for packed experts
+    model_dir = config_path.parent
+    has_packed = (model_dir / "packed_experts").exists() or any(
+        (model_dir.parent / "packed_experts").exists()
+        for _ in [None]
+    )
+
+ # Check for extracted weights + # Look relative to cwd (where infer runs from) + has_weights = Path("metal_infer/model_weights.bin").exists() or Path("model_weights.bin").exists() + + info = { + "model_type": model_type, + "hidden_size": tc.get("hidden_size"), + "num_layers": tc.get("num_hidden_layers"), + "num_experts": tc.get("num_experts"), + "experts_per_tok": tc.get("num_experts_per_tok"), + "moe_intermediate": tc.get("moe_intermediate_size"), + "vocab_size": tc.get("vocab_size"), + "bits": bits, + "group_size": group_size, + "has_packed_experts": has_packed, + "has_extracted_weights": has_weights, + "config_path": str(config_path), + } + + # Estimate sizes + ne = tc.get("num_experts", 0) + nl = tc.get("num_hidden_layers", 0) + mid = tc.get("moe_intermediate_size", 0) + hid = tc.get("hidden_size", 0) + if isinstance(bits, int) and bits > 0: + vals_per_u32 = 32 // bits + expert_bytes = 0 + # gate + up: [mid, hid] + for _ in range(2): + w = mid * ((hid + vals_per_u32 - 1) // vals_per_u32) * 4 + s = mid * ((hid + group_size - 1) // group_size) * 2 + expert_bytes += w + s + s # weight + scales + biases + # down: [hid, mid] + w = hid * ((mid + vals_per_u32 - 1) // vals_per_u32) * 4 + s = hid * ((mid + group_size - 1) // group_size) * 2 + expert_bytes += w + s + s + + total_expert_gb = ne * nl * expert_bytes / (1024**3) + active_per_token_mb = tc.get("num_experts_per_tok", 0) * expert_bytes / (1024**2) + info["expert_size_bytes"] = expert_bytes + info["total_expert_disk_gb"] = round(total_expert_gb, 1) + info["active_per_token_mb"] = round(active_per_token_mb, 1) + + # Count total params (rough estimate) + total_params_b = ne * nl * mid * hid * 3 * 2 / 1e9 # gate+up+down, *2 for bidir + info["approx_total_params"] = f"~{total_params_b:.0f}B" if total_params_b > 1 else f"~{total_params_b*1000:.0f}M" + + return {"compatible": True, "info": info} + + +def list_local_models(): + """List locally cached HuggingFace models and check compatibility.""" + if not 
HF_CACHE.exists(): + print("No HuggingFace cache found at", HF_CACHE) + return [] + + models = [] + for entry in sorted(HF_CACHE.iterdir()): + if not entry.name.startswith("models--"): + continue + # Convert models--org--name to org/name + parts = entry.name.split("--", 2) + if len(parts) >= 3: + repo_id = f"{parts[1]}/{parts[2]}" + else: + repo_id = entry.name + + result = check_compatibility(entry) + result["repo_id"] = repo_id + result["path"] = str(entry) + models.append(result) + + return models + + +def search_remote_models(): + """Search HuggingFace for compatible Qwen3.5 MoE models.""" + if not requests: + print("Install 'requests' to search HuggingFace: pip install requests") + return [] + + seen = set() + results = [] + + for query in SEARCH_QUERIES: + try: + resp = requests.get( + f"{HF_API}/models", + params={ + "search": query, + "limit": 30, + "sort": "downloads", + "direction": -1, + }, + timeout=10, + ) + resp.raise_for_status() + for model in resp.json(): + repo_id = model.get("id", "") + if repo_id in seen: + continue + seen.add(repo_id) + + # Filter: must have "qwen" and "moe" or "3.5" indicators + lower = repo_id.lower() + tags = [t.lower() for t in model.get("tags", [])] + + is_qwen35 = "qwen3.5" in lower or "qwen3_5" in lower + is_mlx = "mlx" in lower or "mlx" in " ".join(tags) + is_moe = bool(MOE_PATTERN.search(repo_id)) + + # We need: Qwen3.5 + MLX quantized + MoE architecture + if is_qwen35 and is_mlx and is_moe: + # Extract quant info from name + quant = "" + for q in ["3bit", "4bit", "6bit", "8bit"]: + if q in lower: + quant = q + break + + results.append({ + "repo_id": repo_id, + "downloads": model.get("downloads", 0), + "likes": model.get("likes", 0), + "quant": quant, + "last_modified": model.get("lastModified", "")[:10], + }) + except Exception as e: + print(f"Warning: search failed for '{query}': {e}", file=sys.stderr) + + return results + + +def download_model(repo_id: str): + """Download a model from HuggingFace.""" + # Try 
huggingface-cli first + hf_cli = None + for cmd in ["huggingface-cli", "hf"]: + try: + subprocess.run([cmd, "--help"], capture_output=True, check=True) + hf_cli = cmd + break + except (FileNotFoundError, subprocess.CalledProcessError): + continue + + if hf_cli: + print(f"Downloading {repo_id} via {hf_cli}...") + subprocess.run([hf_cli, "download", repo_id], check=True) + else: + # Fall back to Python + try: + from huggingface_hub import snapshot_download + print(f"Downloading {repo_id} via huggingface_hub...") + path = snapshot_download(repo_id) + print(f"Downloaded to: {path}") + except ImportError: + print("ERROR: No download tool available.") + print("Install one of:") + print(" pip install huggingface-hub # Python library") + print(" pip install huggingface-cli # CLI tool") + print(f"\nOr manually: git clone https://huggingface.co/{repo_id}") + sys.exit(1) + + +def format_size(gb: float) -> str: + if gb >= 1: + return f"{gb:.1f} GB" + return f"{gb * 1024:.0f} MB" + + +def print_model_info(info: dict, indent: str = " "): + """Print formatted model info.""" + print(f"{indent}Architecture: {info['num_layers']} layers, " + f"hidden={info['hidden_size']}, " + f"{info['num_experts']} experts (K={info['experts_per_tok']})") + print(f"{indent}Quantization: {info['bits']}-bit, group_size={info['group_size']}") + if "total_expert_disk_gb" in info: + print(f"{indent}Expert data: {format_size(info['total_expert_disk_gb'])} on disk, " + f"{info['active_per_token_mb']:.1f} MB active/token") + if "approx_total_params" in info: + print(f"{indent}Parameters: {info['approx_total_params']} total") + + # Readiness indicators + ready = True + if not info.get("has_packed_experts"): + print(f"{indent}Packed experts: NOT FOUND (run repack_experts.py)") + ready = False + else: + print(f"{indent}Packed experts: OK") + if not info.get("has_extracted_weights"): + print(f"{indent}Weights file: NOT FOUND (run extract_weights.py)") + ready = False + else: + print(f"{indent}Weights file: OK") 
+ + if ready: + print(f"{indent}Status: READY TO RUN") + else: + print(f"{indent}Status: NEEDS PREPARATION (see above)") + + +def main(): + parser = argparse.ArgumentParser( + description="Flash-MoE Model Manager — list, search, and download compatible models" + ) + parser.add_argument("--local", action="store_true", help="List local models only") + parser.add_argument("--search", action="store_true", help="Search HuggingFace only") + parser.add_argument("--download", type=str, metavar="REPO", help="Download a model (e.g. mlx-community/Qwen3.5-35B-A3B-4bit)") + parser.add_argument("--check", type=str, metavar="PATH", help="Check if a local model is compatible") + args = parser.parse_args() + + if args.check: + path = Path(args.check).expanduser() + result = check_compatibility(path) + if result["compatible"]: + print(f"COMPATIBLE: {path}") + print_model_info(result["info"]) + else: + print(f"NOT COMPATIBLE: {result['reason']}") + return + + if args.download: + download_model(args.download) + # Check compatibility after download + cache_name = "models--" + args.download.replace("/", "--") + cache_path = HF_CACHE / cache_name + if cache_path.exists(): + result = check_compatibility(cache_path) + if result["compatible"]: + print(f"\nModel is compatible with Flash-MoE!") + print_model_info(result["info"]) + print(f"\nNext steps:") + print(f" 1. python repack_experts.py --model {cache_path}") + print(f" 2. python metal_infer/extract_weights.py --model {cache_path}") + print(f" 3. 
./metal_infer/infer --model {cache_path} --prompt 'Hello' --tokens 20") + return + + # Default: show both local and remote + show_local = not args.search + show_remote = not args.local + + if show_local: + print("=" * 60) + print("LOCAL MODELS") + print("=" * 60) + models = list_local_models() + if not models: + print(" No models found in", HF_CACHE) + else: + compatible_count = 0 + for m in models: + if m["compatible"]: + compatible_count += 1 + print(f"\n {m['repo_id']}") + print_model_info(m["info"], indent=" ") + print(f" Path: {m['path']}") + else: + print(f"\n {m['repo_id']} (incompatible: {m.get('reason', 'unknown')})") + print(f"\n {compatible_count}/{len(models)} compatible models found") + + if show_remote: + print() + print("=" * 60) + print("AVAILABLE ON HUGGINGFACE") + print("=" * 60) + remote = search_remote_models() + if not remote: + print(" No compatible models found (or search failed)") + else: + # Mark which ones are already local + local_repos = set() + if HF_CACHE.exists(): + for entry in HF_CACHE.iterdir(): + if entry.name.startswith("models--"): + parts = entry.name.split("--", 2) + if len(parts) >= 3: + local_repos.add(f"{parts[1]}/{parts[2]}") + + for m in remote: + local_tag = " [LOCAL]" if m["repo_id"] in local_repos else "" + quant_tag = f" [{m['quant']}]" if m.get("quant") else "" + print(f"\n {m['repo_id']}{local_tag}{quant_tag}") + print(f" Downloads: {m['downloads']:,} Likes: {m['likes']}") + + print(f"\n {len(remote)} models found") + print(f"\n Download with: python model_manager.py --download ") + + +if __name__ == "__main__": + main() From 045ab55eceeb9830872f91f21ba498da49852461 Mon Sep 17 00:00:00 2001 From: Alessio Delmonti Date: Fri, 20 Mar 2026 23:13:40 +0000 Subject: [PATCH 7/7] docs: update README for multi-model support and model manager Add compatible models table, model manager usage instructions, updated quick start with --model flag and FLASH_MOE_MODEL env var, revised project structure, and generalized architecture 
description. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 91 ++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 76 insertions(+), 15 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index e3066a9..e158512 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,15 +1,32 @@ -# Flash-MoE: Running a 397B Parameter Model on a Laptop +# Flash-MoE: Running Massive MoE Models on a Laptop > **[Read the paper](paper/flash_moe.pdf)** — Full technical details, 90+ experiments, and the story of how an AI and a human built this in 24 hours. -Pure C/Metal inference engine that runs **Qwen3.5-397B-A17B** (a 397 billion parameter Mixture-of-Experts model) on a MacBook Pro with 48GB RAM at **4.4+ tokens/second** with production-quality output including tool calling. +Pure C/Metal inference engine for **Qwen3.5 Mixture-of-Experts** models on Apple Silicon. Runs models from 35B to 397B parameters on machines with as little as 24GB RAM, streaming expert weights from SSD through a custom Metal compute pipeline. -The entire 209GB model streams from SSD through a custom Metal compute pipeline. No Python. No frameworks. Just C, Objective-C, and hand-tuned Metal shaders. +No Python runtime. No frameworks. Just C, Objective-C, and hand-tuned Metal shaders. Model architecture is auto-detected from HuggingFace `config.json` — switch models with a single `--model` flag. + +## Compatible Models + +Any **Qwen3.5 MoE** model with MLX quantization (`model_type: qwen3_5_moe`) is supported. 
Use the model manager to discover and download compatible models: + +| Model | Params | Active | Quant | Disk | Min RAM | +|-------|--------|--------|-------|------|---------| +| Qwen3.5-35B-A3B | 35B | 3B | 4-bit | ~18GB | 24GB | +| Qwen3.5-35B-A3B | 35B | 3B | 8-bit | ~35GB | 48GB | +| Qwen3.5-122B-A10B | 122B | 10B | 4-bit | ~65GB | 48GB | +| Qwen3.5-397B-A17B | 397B | 17B | 4-bit | ~209GB | 48GB | +| Qwen3.5-397B-A17B | 397B | 17B | 6-bit | ~280GB | 64GB | +| Qwen3.5-397B-A17B | 397B | 17B | 8-bit | ~397GB | 96GB | + +The engine auto-detects architecture, dimensions, expert counts, quantization, and layer types from `config.json`. No recompilation needed. ## Results ![Progress](progress.png) +Results below are for Qwen3.5-397B-A17B on MacBook Pro M3 Max (48GB): + | Configuration | tok/s | Quality | Notes | |--------------|-------|---------|-------| | 4-bit experts, FMA kernel | **4.36** | Excellent | Current best. Full tool calling. 209GB on disk. | @@ -29,7 +46,7 @@ The entire 209GB model streams from SSD through a custom Metal compute pipeline. ## Architecture -The model has 60 transformer layers: 45 GatedDeltaNet (linear attention) + 15 standard full attention. Each layer has 512 experts, of which K=4 are activated per token (plus one shared expert). Hidden dimension is 4096. +Qwen3.5 MoE models use a hybrid attention architecture with GatedDeltaNet (linear attention) and standard full attention layers, each containing a Mixture-of-Experts MLP. Model dimensions, expert counts, and layer types vary per model and are read from `config.json` at startup. For example, the 397B model has 60 layers (45 linear + 15 full), 512 experts (K=4 active), hidden dim 4096; the 35B model has 40 layers (30 linear + 10 full), 256 experts (K=8 active), hidden dim 2048. 
### Key Techniques @@ -66,12 +83,52 @@ CMD3(prev) → CMD1: attention projections + delta-net [1.22ms GPU] On Apple Silicon, SSD DMA and GPU compute share the same memory controller and cannot be profitably overlapped. The GPU's dequant kernels are bandwidth-saturated at ~418 GiB/s. Even small background SSD DMA causes disproportionate GPU latency spikes through memory controller arbitration. The serial pipeline (GPU → SSD → GPU) is hardware-optimal. +## Model Manager + +The model manager helps you find, download, and validate compatible models: + +```bash +# List local models and search HuggingFace for compatible ones +python model_manager.py + +# Search HuggingFace only +python model_manager.py --search + +# List local models only +python model_manager.py --local + +# Download a specific model +python model_manager.py --download mlx-community/Qwen3.5-35B-A3B-4bit + +# Check if a local model is compatible +python model_manager.py --check /path/to/model +``` + +After downloading, prepare the model for inference: + +```bash +# 1. Pack expert weights into per-expert files +python repack_experts.py --model ~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit + +# 2. Extract non-expert weights into a single binary +python metal_infer/extract_weights.py --model ~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit + +# 3. 
Run inference +cd metal_infer && ./infer --model ~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit --prompt "Hello" --tokens 20 +``` + ## Quick Start ```bash cd metal_infer make -# 4-bit inference (needs packed_experts/ directory) + +# Run with a specific model (auto-detects architecture from config.json) +./infer --model ~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit \ + --prompt "Explain quantum computing" --tokens 100 + +# Or set FLASH_MOE_MODEL to avoid passing --model every time +export FLASH_MOE_MODEL=~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit ./infer --prompt "Explain quantum computing" --tokens 100 # 2-bit inference (faster but breaks tool calling) @@ -87,8 +144,16 @@ make ## Project Structure ``` +model_manager.py # Model discovery, download, and compatibility checking +repack_experts.py # 4-bit expert packing from safetensors +progress.py # Results visualization (Q2/Q4 tracks) +results.tsv # Experiment log (58 experiments) + metal_infer/ - infer.m # Complete inference engine (~7000 lines) + infer.m # Complete inference engine (~7500 lines) + # - ModelConfig struct + config.json parser + # - Runtime model auto-detection + # - Metal compute pipeline shaders.metal # Metal compute kernels (~1200 lines) chat.m # Interactive chat TUI with tool calling tokenizer.h # C BPE tokenizer (single-header, 449 lines) @@ -97,14 +162,10 @@ metal_infer/ extract_weights.py # Creates model_weights.bin from safetensors repack_experts_2bit.py # 4-bit → 2-bit expert requantization train_predictor.py # Expert routing prediction analysis - model_weights.bin # Non-expert weights (5.5GB, mmap'd) + model_weights.bin # Non-expert weights (model-specific, mmap'd) model_weights.json # Tensor manifest vocab.bin # Vocabulary for token decoding tokenizer.bin # Pre-exported BPE tokenizer data - -repack_experts.py # 4-bit expert packing from safetensors -progress.py # Results visualization (Q2/Q4 tracks) -results.tsv # 
Experiment log (58 experiments) ``` ## What We Tried (and What Worked) @@ -140,8 +201,8 @@ results.tsv # Experiment log (58 experiments) ## Safety This is a primary development machine. The engine explicitly controls memory: -- Non-expert weights: 5.5GB (mmap'd, read-only) +- Non-expert weights: model-dependent (e.g., 5.5GB for 397B, ~1.5GB for 35B, mmap'd read-only) - Metal scratch buffers: ~200MB -- Total: ~6GB, leaving 42GB for OS + page cache -- No OOM risk. Expert data streams from SSD on demand. -- No custom caches. Trust the OS. +- Expert data streams from SSD on demand — no full model load required +- No custom caches. Trust the OS page cache for expert LRU. +- Minimum RAM: 24GB (35B-A3B 4-bit), 48GB (397B-A17B 4-bit)