diff --git a/CLAUDE.md b/CLAUDE.md
index e3066a9..e158512 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -1,15 +1,32 @@
-# Flash-MoE: Running a 397B Parameter Model on a Laptop
+# Flash-MoE: Running Massive MoE Models on a Laptop
> **[Read the paper](paper/flash_moe.pdf)** — Full technical details, 90+ experiments, and the story of how an AI and a human built this in 24 hours.
-Pure C/Metal inference engine that runs **Qwen3.5-397B-A17B** (a 397 billion parameter Mixture-of-Experts model) on a MacBook Pro with 48GB RAM at **4.4+ tokens/second** with production-quality output including tool calling.
+Pure C/Metal inference engine for **Qwen3.5 Mixture-of-Experts** models on Apple Silicon. Runs models from 35B to 397B parameters on machines with as little as 24GB RAM, streaming expert weights from SSD through a custom Metal compute pipeline.
-The entire 209GB model streams from SSD through a custom Metal compute pipeline. No Python. No frameworks. Just C, Objective-C, and hand-tuned Metal shaders.
+No Python runtime. No frameworks. Just C, Objective-C, and hand-tuned Metal shaders. Model architecture is auto-detected from HuggingFace `config.json` — switch models with a single `--model` flag.
+
+## Compatible Models
+
+Any **Qwen3.5 MoE** model with MLX quantization (`model_type: qwen3_5_moe`) is supported; use the model manager (below) to discover and download compatible models. Tested configurations:
+
+| Model | Params | Active | Quant | Disk | Min RAM |
+|-------|--------|--------|-------|------|---------|
+| Qwen3.5-35B-A3B | 35B | 3B | 4-bit | ~18GB | 24GB |
+| Qwen3.5-35B-A3B | 35B | 3B | 8-bit | ~35GB | 48GB |
+| Qwen3.5-122B-A10B | 122B | 10B | 4-bit | ~65GB | 48GB |
+| Qwen3.5-397B-A17B | 397B | 17B | 4-bit | ~209GB | 48GB |
+| Qwen3.5-397B-A17B | 397B | 17B | 6-bit | ~280GB | 64GB |
+| Qwen3.5-397B-A17B | 397B | 17B | 8-bit | ~397GB | 96GB |
+
+The engine auto-detects architecture, dimensions, expert counts, quantization, and layer types from `config.json`. No recompilation needed.
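
For reference, a minimal sketch of the `config.json` fields the engine reads — field names follow the HuggingFace Qwen3.5 MoE layout described in the design spec below; the values are illustrative (35B-A3B dimensions from the table above), not copied from a real checkpoint:

```json
{
  "model_type": "qwen3_5_moe",
  "eos_token_id": [248046, 248044],
  "quantization": { "group_size": 64, "bits": 4 },
  "text_config": {
    "hidden_size": 2048,
    "num_hidden_layers": 40,
    "num_experts": 256,
    "num_experts_per_tok": 8,
    "moe_intermediate_size": 512,
    "layer_types": ["linear_attention", "linear_attention", "linear_attention", "full_attention", "..."]
  }
}
```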
## Results

+Results below are for Qwen3.5-397B-A17B on MacBook Pro M3 Max (48GB):
+
| Configuration | tok/s | Quality | Notes |
|--------------|-------|---------|-------|
| 4-bit experts, FMA kernel | **4.36** | Excellent | Current best. Full tool calling. 209GB on disk. |
@@ -29,7 +46,7 @@ The entire 209GB model streams from SSD through a custom Metal compute pipeline.
## Architecture
-The model has 60 transformer layers: 45 GatedDeltaNet (linear attention) + 15 standard full attention. Each layer has 512 experts, of which K=4 are activated per token (plus one shared expert). Hidden dimension is 4096.
+Qwen3.5 MoE models use a hybrid attention architecture with GatedDeltaNet (linear attention) and standard full attention layers, each containing a Mixture-of-Experts MLP. Model dimensions, expert counts, and layer types vary per model and are read from `config.json` at startup. For example, the 397B model has 60 layers (45 linear + 15 full), 512 experts (K=4 active), hidden dim 4096; the 35B model has 40 layers (30 linear + 10 full), 256 experts (K=8 active), hidden dim 2048.
### Key Techniques
@@ -66,12 +83,52 @@ CMD3(prev) → CMD1: attention projections + delta-net [1.22ms GPU]
On Apple Silicon, SSD DMA and GPU compute share the same memory controller and cannot be profitably overlapped. The GPU's dequant kernels are bandwidth-saturated at ~418 GiB/s. Even small background SSD DMA causes disproportionate GPU latency spikes through memory controller arbitration. The serial pipeline (GPU → SSD → GPU) is hardware-optimal.
+## Model Manager
+
+The model manager helps you find, download, and validate compatible models:
+
+```bash
+# List local models and search HuggingFace for compatible ones
+python model_manager.py
+
+# Search HuggingFace only
+python model_manager.py --search
+
+# List local models only
+python model_manager.py --local
+
+# Download a specific model
+python model_manager.py --download mlx-community/Qwen3.5-35B-A3B-4bit
+
+# Check if a local model is compatible
+python model_manager.py --check /path/to/model
+```
+
+After downloading, prepare the model for inference:
+
+```bash
+# 1. Pack expert weights into per-expert files
+python repack_experts.py --model ~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit
+
+# 2. Extract non-expert weights into a single binary
+python metal_infer/extract_weights.py --model ~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit
+
+# 3. Run inference
+cd metal_infer && ./infer --model ~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit --prompt "Hello" --tokens 20
+```
+
## Quick Start
```bash
cd metal_infer
make
-# 4-bit inference (needs packed_experts/ directory)
+
+# Run with a specific model (auto-detects architecture from config.json)
+./infer --model ~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit \
+ --prompt "Explain quantum computing" --tokens 100
+
+# Or set FLASH_MOE_MODEL to avoid passing --model every time
+export FLASH_MOE_MODEL=~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-35B-A3B-4bit
./infer --prompt "Explain quantum computing" --tokens 100
# 2-bit inference (faster but breaks tool calling)
@@ -87,8 +144,16 @@ make
## Project Structure
```
+model_manager.py # Model discovery, download, and compatibility checking
+repack_experts.py # 4-bit expert packing from safetensors
+progress.py # Results visualization (Q2/Q4 tracks)
+results.tsv # Experiment log (58 experiments)
+
metal_infer/
- infer.m # Complete inference engine (~7000 lines)
+ infer.m # Complete inference engine (~7500 lines)
+ # - ModelConfig struct + config.json parser
+ # - Runtime model auto-detection
+ # - Metal compute pipeline
shaders.metal # Metal compute kernels (~1200 lines)
chat.m # Interactive chat TUI with tool calling
tokenizer.h # C BPE tokenizer (single-header, 449 lines)
@@ -97,14 +162,10 @@ metal_infer/
extract_weights.py # Creates model_weights.bin from safetensors
repack_experts_2bit.py # 4-bit → 2-bit expert requantization
train_predictor.py # Expert routing prediction analysis
- model_weights.bin # Non-expert weights (5.5GB, mmap'd)
+ model_weights.bin # Non-expert weights (model-specific, mmap'd)
model_weights.json # Tensor manifest
vocab.bin # Vocabulary for token decoding
tokenizer.bin # Pre-exported BPE tokenizer data
-
-repack_experts.py # 4-bit expert packing from safetensors
-progress.py # Results visualization (Q2/Q4 tracks)
-results.tsv # Experiment log (58 experiments)
```
## What We Tried (and What Worked)
@@ -140,8 +201,8 @@ results.tsv # Experiment log (58 experiments)
## Safety
This is a primary development machine. The engine explicitly controls memory:
-- Non-expert weights: 5.5GB (mmap'd, read-only)
+- Non-expert weights: model-dependent (e.g., 5.5GB for 397B, ~1.5GB for 35B, mmap'd read-only)
- Metal scratch buffers: ~200MB
-- Total: ~6GB, leaving 42GB for OS + page cache
-- No OOM risk. Expert data streams from SSD on demand.
-- No custom caches. Trust the OS.
+- Expert data streams from SSD on demand — no full model load required
+- No custom caches. Trust the OS page cache for expert LRU.
+- Minimum RAM: 24GB (35B-A3B 4-bit), 48GB (397B-A17B 4-bit)
diff --git a/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md b/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md
new file mode 100644
index 0000000..d967f8b
--- /dev/null
+++ b/docs/superpowers/specs/2026-03-20-runtime-model-config-design.md
@@ -0,0 +1,289 @@
+# Runtime Model Config from HuggingFace config.json
+
+**Date:** 2026-03-20
+**Status:** Approved
+
+## Problem
+
+Switching Flash-MoE between models (e.g., Qwen3.5-397B-A17B vs Qwen3.5-35B-A3B) requires manually editing ~40 `#define` constants and recompiling. This is error-prone (the NaN bug was caused by stale expert offsets) and prevents runtime model selection.
+
+## Solution
+
+Replace all model-specific `#define` constants with a global `ModelConfig` struct populated at startup by parsing the HuggingFace `config.json` and `tokenizer_config.json` using NSJSONSerialization. Expert byte offsets are computed from dimensions and quantization parameters. The `--model` CLI flag selects which model to load.
+
+## Design
+
+### ModelConfig struct
+
+A single global `ModelConfig cfg` replaces all model `#define`s:
+
+```c
+typedef struct {
+ // Core architecture (from config.json -> text_config)
+ int hidden_dim; // text_config.hidden_size
+ int num_layers; // text_config.num_hidden_layers
+ int num_attn_heads; // text_config.num_attention_heads
+ int num_kv_heads; // text_config.num_key_value_heads
+ int head_dim; // text_config.head_dim
+ int vocab_size; // text_config.vocab_size
+ float rms_norm_eps; // text_config.rms_norm_eps
+
+ // MoE (from config.json -> text_config)
+ int num_experts; // text_config.num_experts
+ int num_experts_per_tok; // text_config.num_experts_per_tok
+ int moe_intermediate; // text_config.moe_intermediate_size
+ int shared_intermediate; // text_config.shared_expert_intermediate_size
+ int group_size; // quantization.group_size (or quantization_config)
+ int bits; // quantization.bits (or quantization_config)
+
+ // Linear attention / GatedDeltaNet (from config.json -> text_config)
+ int linear_num_v_heads; // text_config.linear_num_value_heads
+ int linear_num_k_heads; // text_config.linear_num_key_heads
+ int linear_key_dim; // text_config.linear_key_head_dim
+ int linear_value_dim; // text_config.linear_value_head_dim
+ int conv_kernel_size; // text_config.linear_conv_kernel_dim
+
+ // Full attention (from config.json -> text_config)
+ float rope_theta; // text_config.rope_parameters.rope_theta
+ float partial_rotary; // text_config.rope_parameters.partial_rotary_factor
+
+ // Layer type map (from config.json -> text_config.layer_types)
+ int num_full_attn_layers;
+ int num_linear_layers;
+ bool *is_full_attn; // [num_layers] — true if full attention
+ int *full_attn_index; // [num_layers] — index into full-attn arrays, or -1
+ int *linear_index; // [num_layers] — index into linear-attn arrays, or -1
+
+ // Derived: expert byte offsets (computed from dims + quantization)
+ size_t expert_size_4bit;
+ size_t gate_w_off_4, gate_s_off_4, gate_b_off_4;
+ size_t up_w_off_4, up_s_off_4, up_b_off_4;
+ size_t down_w_off_4, down_s_off_4, down_b_off_4;
+ size_t expert_size_2bit;
+ size_t gate_w_off_2, gate_s_off_2, gate_b_off_2;
+ size_t up_w_off_2, up_s_off_2, up_b_off_2;
+ size_t down_w_off_2, down_s_off_2, down_b_off_2;
+
+ // Derived dimensions (computed from above)
+ int linear_total_key; // linear_num_k_heads * linear_key_dim
+ int linear_total_value; // linear_num_v_heads * linear_value_dim
+ int linear_conv_dim; // linear_total_key * 2 + linear_total_value
+ int rotary_dim; // (int)(head_dim * partial_rotary)
+
+ // Special tokens
+ int eos_token_ids[8]; // from config.json eos_token_id (can be array)
+ int num_eos_tokens;
+ int think_start_token; // from tokenizer_config.json added_tokens_decoder
+ int think_end_token; // from tokenizer_config.json added_tokens_decoder
+
+ // Context limits (kept as defaults, could be overridden by CLI)
+ int max_seq_len; // default: text_config.max_position_embeddings
+ int gpu_kv_seq; // default: 8192 (pre-allocation, not model-specific)
+
+ // Model path (resolved HF snapshot dir)
+ char model_path[1024];
+} ModelConfig;
+
+static ModelConfig cfg;
+```
+
+### Config loading function
+
+```c
+static void load_model_config(const char *model_dir);
+```
+
+Steps:
+1. Resolve HF snapshot directory (walk `snapshots/` if needed, same logic as existing code)
+2. Read and parse `config.json` via NSJSONSerialization
+3. Extract `text_config` sub-dictionary for architecture params
+4. Read `layer_types` array to build `is_full_attn[]`, `full_attn_index[]`, `linear_index[]`. Fallback: if `layer_types` is absent but `full_attention_interval` exists, compute `is_full[i] = ((i+1) % interval == 0)`. If neither exists, fatal error.
+5. Read `quantization` or `quantization_config` for group_size and bits
+6. Read `eos_token_id` (handles both single int and array)
+7. Read `rope_parameters` sub-dict for rope_theta and partial_rotary_factor
+8. Read and parse `tokenizer_config.json` for think tokens:
+ - Walk `added_tokens_decoder` object, match entries where `content` == `<think>` or `</think>`
+ - Extract their integer key as the token ID
+9. Compute expert byte offsets via `compute_expert_offsets()`
+10. Compute derived dimensions (linear_total_key, etc.)
+11. Print summary to stderr for verification
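
Steps 5 and 6 carry the only two shape quirks worth calling out: the quantization dict can live under either of two keys, and `eos_token_id` may be a scalar or a list. A hedged Python sketch (the engine does this in Objective-C via NSJSONSerialization; `read_quant_and_eos` is a hypothetical helper, not project code):

```python
def read_quant_and_eos(config):
    # config: parsed config.json as a dict
    qc = config.get("quantization_config") or config.get("quantization") or {}
    group_size = qc.get("group_size", 64)   # defaults mirror the spec's fallbacks
    bits = qc.get("bits", 4)
    eos = config.get("eos_token_id")
    if eos is None:
        eos_ids = []
    elif isinstance(eos, list):
        eos_ids = [int(t) for t in eos[:8]]  # engine caps at 8 EOS tokens
    else:
        eos_ids = [int(eos)]
    return group_size, bits, eos_ids

print(read_quant_and_eos({"quantization": {"group_size": 64, "bits": 4},
                          "eos_token_id": [248046, 248044]}))
# (64, 4, [248046, 248044])
```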
+
+### Expert offset computation
+
+Offsets are deterministic from `moe_intermediate`, `hidden_dim`, `group_size`, and `bits`:
+
+```c
+static void compute_expert_offsets(ModelConfig *c) {
+ int mid = c->moe_intermediate; // e.g. 512
+ int hid = c->hidden_dim; // e.g. 2048
+ int gs = c->group_size; // e.g. 64
+ int bits = c->bits; // 4 or 2
+
+ // For a [out_dim, in_dim] weight at N bits, group_size gs:
+ // weight_bytes = out_dim * ceil(in_dim / (32/bits)) * 4
+ // scales_bytes = out_dim * ceil(in_dim / gs) * 2 (bf16)
+ // biases_bytes = scales_bytes
+
+ // gate_proj: [mid, hid], up_proj: [mid, hid], down_proj: [hid, mid]
+ // Compute for 4-bit and 2-bit layouts
+ for (int b = 4; b >= 2; b -= 2) {
+ int vals_per_u32 = 32 / b;
+ // gate_proj [mid, hid]
+ size_t gw = (size_t)mid * ((hid + vals_per_u32 - 1) / vals_per_u32) * 4;
+ size_t gs_bytes = (size_t)mid * ((hid + gs - 1) / gs) * 2;
+ size_t gb = gs_bytes;
+ // up_proj [mid, hid] — same shape as gate
+ size_t uw = gw, us = gs_bytes, ub = gb;
+ // down_proj [hid, mid]
+ size_t dw = (size_t)hid * ((mid + vals_per_u32 - 1) / vals_per_u32) * 4;
+ size_t ds = (size_t)hid * ((mid + gs - 1) / gs) * 2;
+ size_t db = ds;
+
+ size_t off = 0;
+ if (b == 4) {
+ c->gate_w_off_4 = off; off += gw;
+ c->gate_s_off_4 = off; off += gs_bytes;
+ c->gate_b_off_4 = off; off += gb;
+ c->up_w_off_4 = off; off += uw;
+ c->up_s_off_4 = off; off += us;
+ c->up_b_off_4 = off; off += ub;
+ c->down_w_off_4 = off; off += dw;
+ c->down_s_off_4 = off; off += ds;
+ c->down_b_off_4 = off; off += db;
+ c->expert_size_4bit = off;
+ } else {
+ c->gate_w_off_2 = off; off += gw;
+ c->gate_s_off_2 = off; off += gs_bytes;
+ c->gate_b_off_2 = off; off += gb;
+ c->up_w_off_2 = off; off += uw;
+ c->up_s_off_2 = off; off += us;
+ c->up_b_off_2 = off; off += ub;
+ c->down_w_off_2 = off; off += dw;
+ c->down_s_off_2 = off; off += ds;
+ c->down_b_off_2 = off; off += db;
+ c->expert_size_2bit = off;
+ }
+ }
+}
+```
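
As a sanity check, the layout above can be reproduced in a few lines of Python (a sketch mirroring `compute_expert_offsets()`, not project code; the dimensions are the spec's own examples — mid=512, hid=2048 for the 35B model, mid=1024, hid=4096 for the 397B model):

```python
def expert_layout(mid, hid, group_size, bits):
    """Return (offsets, total_size) for one bit-width, in weight/scales/biases order."""
    vals_per_u32 = 32 // bits
    def proj(out_dim, in_dim):
        w = out_dim * ((in_dim + vals_per_u32 - 1) // vals_per_u32) * 4  # packed u32 words
        s = out_dim * ((in_dim + group_size - 1) // group_size) * 2      # bf16 scales
        return [w, s, s]                                                 # biases mirror scales
    sizes = proj(mid, hid) + proj(mid, hid) + proj(hid, mid)  # gate, up, down
    offsets, off = [], 0
    for sz in sizes:
        offsets.append(off)
        off += sz
    return offsets, off

# 35B-A3B dims: total matches the spec's validation output (1769472 bytes)
print(expert_layout(512, 2048, 64, 4)[1])   # 1769472
# 397B dims: totals match the old EXPERT_SIZE / EXPERT_SIZE_2BIT #defines
print(expert_layout(1024, 4096, 64, 4)[1])  # 7077888
print(expert_layout(1024, 4096, 64, 2)[1])  # 3932160
```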
+
+### Static arrays to dynamic allocation
+
+These static arrays use compile-time `NUM_LAYERS`/`NUM_EXPERTS` and must become dynamically allocated after config loading:
+
+| Current | New |
+|---------|-----|
+| `static int g_expert_freq[NUM_LAYERS][NUM_EXPERTS]` | `int *g_expert_freq` (malloc `num_layers * num_experts * sizeof(int)`) |
+| `static uint8_t g_expert_seen[NUM_LAYERS][NUM_EXPERTS/8]` | `uint8_t *g_expert_seen` (malloc `num_layers * ceil(num_experts/8)`) |
+| `static uint8_t g_cache_seen[NUM_LAYERS][NUM_EXPERTS]` | `uint8_t *g_cache_seen` (malloc) |
+| `static uint64_t g_cache_last_touch_token[NUM_LAYERS][NUM_EXPERTS]` | `uint64_t *g_cache_last_touch_token` (malloc) |
+| `static uint64_t g_cache_last_evict_token[NUM_LAYERS][NUM_EXPERTS]` | `uint64_t *g_cache_last_evict_token` (malloc) |
+| `static LayerWeightCache layer_cache[NUM_LAYERS]` | `LayerWeightCache *layer_cache` (malloc `num_layers * sizeof(...)`) |
+| `LZ4IndexEntry *g_lz4_index[NUM_LAYERS]` | `LZ4IndexEntry **g_lz4_index` (malloc `num_layers` pointers) |
+| `g_pred_experts[60][MAX_K]` | `int *g_pred_experts` (malloc `num_layers * MAX_K`, note: currently hardcoded to 60, a latent bug) |
+| `g_pred_count[60]` | `int *g_pred_count` (malloc `num_layers`) |
+| `id<MTLBuffer> buf_kv_k[NUM_FULL_ATTN_LAYERS]` | Dynamically allocated array in MetalCtx |
+| `id<MTLBuffer> buf_kv_v[NUM_FULL_ATTN_LAYERS]` | Same |
+| `id<MTLBuffer> buf_delta_state[NUM_LINEAR_LAYERS]` | Same |
+| `id<MTLBuffer> buf_conv_state[NUM_LINEAR_LAYERS]` | Same |
+| `id<MTLBuffer> buf_multi_expert_data[MAX_K]` | Stays `MAX_K` (hardware limit, not model-specific) |
+| Stack VLAs `gpu_delta_snapshots[NUM_LINEAR_LAYERS]` (serve loop) | `malloc`'d at serve entry, freed at exit |
+| Stack VLAs `gpu_conv_snapshots[NUM_LINEAR_LAYERS]` (serve loop) | Same |
+
+Access pattern changes from `g_expert_freq[layer][expert]` to `g_expert_freq[layer * cfg.num_experts + expert]` (flattened 2D indexing). Helper macros can simplify this:
+
+```c
+#define FREQ(l, e) g_expert_freq[(l) * cfg.num_experts + (e)]
+#define CACHE_SEEN(l, e) g_cache_seen[(l) * cfg.num_experts + (e)]
+```
+
+### MetalCtx changes
+
+The `MetalCtx` struct's fixed-size arrays become pointers, allocated in `metal_setup()` after config is loaded:
+
+```c
+typedef struct {
+ // ...existing fields...
+ id<MTLBuffer> *buf_kv_k; // [num_full_attn_layers]
+ id<MTLBuffer> *buf_kv_v; // [num_full_attn_layers]
+ id<MTLBuffer> *buf_delta_state; // [num_linear_layers]
+ id<MTLBuffer> *buf_conv_state; // [num_linear_layers]
+ // buf_multi_expert_data[MAX_K] stays fixed (MAX_K is hardware limit)
+} MetalCtx;
+```
+
+### What remains as #define
+
+Only non-model-specific constants:
+
+- `MAX_K 8` — maximum supported experts per token (array sizing). Note: `MAX_BATCH_SLOTS` (currently 8) is coupled to `MAX_K` — keep them in sync.
+- `GPU_KV_SEQ 8192` — GPU KV pre-allocation (tuning parameter)
+- `TENSOR_HT_SIZE 8192` — hash table size (implementation detail)
+- `NUM_IO_THREADS 8` — I/O thread pool size (hardware tuning)
+- Metal threadgroup sizes (256, 64, 128) — GPU hardware tuning
+
+### Layer type detection
+
+Currently uses `(i + 1) % FULL_ATTN_INTERVAL == 0`. Replaced with config-driven lookup:
+
+```c
+// Old
+int is_full = ((i + 1) % FULL_ATTN_INTERVAL == 0);
+
+// New
+int is_full = cfg.is_full_attn[i];
+```
+
+The `full_attn_index[]` and `linear_index[]` arrays map global layer index to per-type index (used for KV cache / delta-net state buffer indexing). All arithmetic formulas like `(layer_idx + 1) / FULL_ATTN_INTERVAL - 1` must be replaced with `cfg.full_attn_index[i]` / `cfg.linear_index[i]` lookups. This includes `build_layer_cache()`, `fused_layer_forward()`, GPU snapshot save/restore, and any other site using the interval formula.
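
The construction of these maps can be sketched as follows (a Python illustration of the prose above, not project code):

```python
def build_layer_maps(layer_types):
    # layer_types: e.g. ["linear_attention", ..., "full_attention"] from config.json
    is_full = [t == "full_attention" for t in layer_types]
    full_idx = [-1] * len(layer_types)
    lin_idx = [-1] * len(layer_types)
    nf = nl = 0
    for i, full in enumerate(is_full):
        if full:
            full_idx[i], nf = nf, nf + 1
        else:
            lin_idx[i], nl = nl, nl + 1
    return is_full, full_idx, lin_idx

# 397B layout: full attention every 4th layer out of 60 -> 15 full + 45 linear
types = ["full_attention" if (i + 1) % 4 == 0 else "linear_attention" for i in range(60)]
is_full, full_idx, lin_idx = build_layer_maps(types)
print(sum(is_full), full_idx[3], lin_idx[0])  # 15 0 0
```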
+
+### Startup sequence
+
+```
+main()
+ -> parse CLI args (--model path)
+ -> load_model_config(model_path) // NEW: populates cfg
+ -> alloc_tracking_arrays() // NEW: malloc g_expert_freq etc.
+ -> metal_setup() // uses cfg.* for buffer sizes
+ -> load_weights() // uses cfg.model_path
+ -> inference loop // uses cfg.* throughout
+```
+
+### Thread safety invariant
+
+`cfg` is immutable after `load_model_config()` returns. It is populated once at startup before any threads are spawned, and is read-only for the entire lifetime of the process. No locking required.
+
+### CLI change
+
+The existing `--model` flag already accepts a path. No new flags needed. The only change is that `load_model_config()` is called with this path before any other initialization. If `--model` is omitted, `load_model_config()` searches for any Qwen model in `~/.cache/huggingface/hub/` or prints a clear error asking the user to provide `--model`.
+
+### Files changed
+
+| File | Change |
+|------|--------|
+| `metal_infer/infer.m` | Remove ~40 `#define`s, add `ModelConfig` struct + `load_model_config()` (~150 lines), convert ~13 static/stack arrays to malloc, update ~200 references from `DEFINE` to `cfg.field`, replace all `FULL_ATTN_INTERVAL` formula sites with array lookups |
+| `metal_infer/chat.m` | No changes needed (pure HTTP/SSE client, references no model constants) |
+| `metal_infer/shaders.metal` | No changes (already parameterized via kernel arguments) |
+
+### Validation
+
+`load_model_config()` prints a summary to stderr on startup:
+
+```
+[config] Qwen3.5-35B-A3B: 40 layers (30 linear + 10 full), 2048 hidden, 256 experts (K=8)
+[config] 4-bit expert size: 1769472 bytes, group_size=64
+[config] EOS tokens: [248046, 248044], think: 248068/248069
+```
+
+This makes it immediately visible which model is loaded and whether the config was parsed correctly.
+
+### Error handling
+
+- Missing `config.json` → fatal error with clear message
+- Missing `tokenizer_config.json` → warning, think tokens default to -1 (disabled)
+- Missing `text_config` key → fatal error
+- Missing optional keys (e.g., `linear_conv_kernel_dim`) → use sensible defaults with warning
+- Computed expert offsets are validated against `expert_index.json` if present
+
+### Note on extract_weights.py
+
+`extract_weights.py` currently hardcodes model parameters (layer types, dimensions) when generating the weight manifest. This is acceptable for now — the manifest is generated once per model. A future improvement could make it config-driven too, but it's out of scope for this spec since the runtime engine is the priority.
diff --git a/metal_infer/infer.m b/metal_infer/infer.m
index 5d2a946..afa69f9 100644
--- a/metal_infer/infer.m
+++ b/metal_infer/infer.m
@@ -1,16 +1,18 @@
/*
- * infer.m — Complete Qwen3.5-397B inference engine using Metal
+ * infer.m — Qwen3.5 MoE inference engine using Metal
*
- * Full forward pass: embedding -> 60 transformer layers -> norm -> lm_head -> sample
+ * Full forward pass: embedding -> N transformer layers -> norm -> lm_head -> sample
+ * Model architecture loaded at runtime from HuggingFace config.json (--model flag).
* Non-expert weights loaded from model_weights.bin (mmap'd at startup)
* Expert weights loaded from packed_experts/ per layer per token (pread)
*
- * Architecture: Qwen3.5-397B-A17B (MoE)
- * - 60 layers: 45 linear attention (GatedDeltaNet) + 15 full attention
- * - hidden_size=4096, head_dim=256, num_attention_heads=32, num_kv_heads=2
- * - 512 experts/layer, 10 active (we use K=4 for speed)
+ * Supported: Qwen3.5-35B-A3B, Qwen3.5-397B-A17B, and compatible MoE variants.
+ * Architecture auto-detected from config.json:
+ * - N layers: mix of linear attention (GatedDeltaNet) + full attention
+ * - Configurable hidden_size, head_dim, num_attention_heads, num_kv_heads
+ * - Variable experts/layer and active experts (K)
* - Shared expert per layer (always active)
- * - Linear attention: conv1d(kernel=4) + gated delta recurrence
+ * - Linear attention: conv1d + gated delta recurrence
* - Full attention: standard QKV + scaled dot product + RoPE
*
* Command buffer optimization (fused_layer_forward):
@@ -66,65 +68,328 @@
#include
// ============================================================================
-// Model constants
+// Runtime model configuration (populated from HuggingFace config.json)
// ============================================================================
-#define HIDDEN_DIM 4096
-#define NUM_LAYERS 60
-#define NUM_ATTN_HEADS 32
-#define NUM_KV_HEADS 2
-#define HEAD_DIM 256
-#define VOCAB_SIZE 248320
-#define RMS_NORM_EPS 1e-6f
-#define NUM_EXPERTS 512
-#define NUM_EXPERTS_PER_TOK 10
-#define MOE_INTERMEDIATE 1024
-#define SHARED_INTERMEDIATE 1024
-#define FULL_ATTN_INTERVAL 4
-#define GROUP_SIZE 64
-#define BITS 4
-
-// Linear attention (GatedDeltaNet) constants
-#define LINEAR_NUM_V_HEADS 64
-#define LINEAR_NUM_K_HEADS 16
-#define LINEAR_KEY_DIM 128 // head_k_dim
-#define LINEAR_VALUE_DIM 128 // head_v_dim
-#define LINEAR_TOTAL_KEY (LINEAR_NUM_K_HEADS * LINEAR_KEY_DIM) // 2048
-#define LINEAR_TOTAL_VALUE (LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM) // 8192
-#define LINEAR_CONV_DIM (LINEAR_TOTAL_KEY * 2 + LINEAR_TOTAL_VALUE) // 12288
-#define CONV_KERNEL_SIZE 4
-
-// Full attention constants
-#define ROPE_THETA 10000000.0f
-#define PARTIAL_ROTARY 0.25f
-#define ROTARY_DIM (int)(HEAD_DIM * PARTIAL_ROTARY) // 64
-
-// Expert packed binary layout (from existing code)
-#define EXPERT_SIZE 7077888
-
-// 2-bit expert layout (from repack_experts_2bit.py)
-#define EXPERT_SIZE_2BIT 3932160
-#define GATE_W_OFF_2 0
-#define GATE_S_OFF_2 1048576
-#define GATE_B_OFF_2 1179648
-#define UP_W_OFF_2 1310720
-#define UP_S_OFF_2 2359296
-#define UP_B_OFF_2 2490368
-#define DOWN_W_OFF_2 2621440
-#define DOWN_S_OFF_2 3670016
-#define DOWN_B_OFF_2 3801088
-
-// KV cache maximum context length
-#define MAX_SEQ_LEN 1048576 // 1M context — only 15 full-attn layers need KV cache, ~15GB at max
-#define GPU_KV_SEQ 8192 // GPU KV buffer pre-allocation (grows if exceeded, falls back to CPU attn)
-
-// Special tokens
-#define EOS_TOKEN_1 248046
-#define EOS_TOKEN_2 248044
-#define THINK_START_TOKEN 248068 // <think>
-#define THINK_END_TOKEN 248069 // </think>
-
-#define MODEL_PATH_DEFAULT "/Users/danielwoods/.cache/huggingface/hub/models--mlx-community--Qwen3.5-397B-A17B-4bit/snapshots/39159bd8aa74f5c8446d2b2dc584f62bb51cb0d3"
+typedef struct {
+ // Core architecture
+ int hidden_dim;
+ int num_layers;
+ int num_attn_heads;
+ int num_kv_heads;
+ int head_dim;
+ int vocab_size;
+ float rms_norm_eps;
+
+ // MoE
+ int num_experts;
+ int num_experts_per_tok;
+ int moe_intermediate;
+ int shared_intermediate;
+ int group_size;
+ int bits;
+
+ // Linear attention (GatedDeltaNet)
+ int linear_num_v_heads;
+ int linear_num_k_heads;
+ int linear_key_dim;
+ int linear_value_dim;
+ int conv_kernel_size;
+
+ // Full attention
+ float rope_theta;
+ float partial_rotary;
+
+ // Layer type map
+ int num_full_attn_layers;
+ int num_linear_layers;
+ bool *is_full_attn; // [num_layers]
+ int *full_attn_index; // [num_layers] — index into full-attn buffers, or -1
+ int *linear_index; // [num_layers] — index into linear-attn buffers, or -1
+
+ // Derived: expert byte offsets (4-bit)
+ size_t expert_size_4bit;
+ size_t gate_w_off_4, gate_s_off_4, gate_b_off_4;
+ size_t up_w_off_4, up_s_off_4, up_b_off_4;
+ size_t down_w_off_4, down_s_off_4, down_b_off_4;
+
+ // Derived: expert byte offsets (2-bit)
+ size_t expert_size_2bit;
+ size_t gate_w_off_2, gate_s_off_2, gate_b_off_2;
+ size_t up_w_off_2, up_s_off_2, up_b_off_2;
+ size_t down_w_off_2, down_s_off_2, down_b_off_2;
+
+ // Derived dimensions
+ int linear_total_key;
+ int linear_total_value;
+ int linear_conv_dim;
+ int rotary_dim;
+
+ // Special tokens
+ int eos_token_ids[8];
+ int num_eos_tokens;
+ int think_start_token;
+ int think_end_token;
+
+ // Context limits
+ int max_seq_len;
+ int gpu_kv_seq;
+
+ // Model path (resolved)
+ char model_path[1024];
+} ModelConfig;
+
+static ModelConfig cfg;
+
+static void compute_expert_offsets(ModelConfig *c) {
+ int mid = c->moe_intermediate;
+ int hid = c->hidden_dim;
+ int gs = c->group_size;
+
+ for (int b = 4; b >= 2; b -= 2) {
+ int vals_per_u32 = 32 / b;
+ // gate_proj [mid, hid]
+ size_t gw = (size_t)mid * ((hid + vals_per_u32 - 1) / vals_per_u32) * 4;
+ size_t gs_sz = (size_t)mid * ((hid + gs - 1) / gs) * 2;
+ size_t gb = gs_sz;
+ // up_proj [mid, hid] — same shape
+ size_t uw = gw, us = gs_sz, ub = gb;
+ // down_proj [hid, mid]
+ size_t dw = (size_t)hid * ((mid + vals_per_u32 - 1) / vals_per_u32) * 4;
+ size_t ds = (size_t)hid * ((mid + gs - 1) / gs) * 2;
+ size_t db = ds;
+
+ size_t off = 0;
+ if (b == 4) {
+ c->gate_w_off_4 = off; off += gw;
+ c->gate_s_off_4 = off; off += gs_sz;
+ c->gate_b_off_4 = off; off += gb;
+ c->up_w_off_4 = off; off += uw;
+ c->up_s_off_4 = off; off += us;
+ c->up_b_off_4 = off; off += ub;
+ c->down_w_off_4 = off; off += dw;
+ c->down_s_off_4 = off; off += ds;
+ c->down_b_off_4 = off; off += db;
+ c->expert_size_4bit = off;
+ } else {
+ c->gate_w_off_2 = off; off += gw;
+ c->gate_s_off_2 = off; off += gs_sz;
+ c->gate_b_off_2 = off; off += gb;
+ c->up_w_off_2 = off; off += uw;
+ c->up_s_off_2 = off; off += us;
+ c->up_b_off_2 = off; off += ub;
+ c->down_w_off_2 = off; off += dw;
+ c->down_s_off_2 = off; off += ds;
+ c->down_b_off_2 = off; off += db;
+ c->expert_size_2bit = off;
+ }
+ }
+}
+
+static void load_model_config(const char *model_dir) {
+ memset(&cfg, 0, sizeof(cfg));
+ cfg.think_start_token = -1;
+ cfg.think_end_token = -1;
+ cfg.gpu_kv_seq = 8192;
+
+ if (!model_dir || !model_dir[0]) {
+ fprintf(stderr, "FATAL: --model path required\n");
+ exit(1);
+ }
+
+ // Resolve HF snapshot directory
+ NSString *base = [NSString stringWithUTF8String:model_dir];
+ NSString *configPath = [base stringByAppendingPathComponent:@"config.json"];
+ NSFileManager *fm = [NSFileManager defaultManager];
+
+ if (![fm fileExistsAtPath:configPath]) {
+ NSString *snapDir = [base stringByAppendingPathComponent:@"snapshots"];
+ if ([fm fileExistsAtPath:snapDir]) {
+ NSArray *snaps = [[fm contentsOfDirectoryAtPath:snapDir error:nil]
+ sortedArrayUsingSelector:@selector(compare:)];
+ for (NSString *snap in snaps) {
+ NSString *candidate = [[snapDir stringByAppendingPathComponent:snap]
+ stringByAppendingPathComponent:@"config.json"];
+ if ([fm fileExistsAtPath:candidate]) {
+ base = [snapDir stringByAppendingPathComponent:snap];
+ configPath = candidate;
+ break;
+ }
+ }
+ }
+ }
+
+ if (![fm fileExistsAtPath:configPath]) {
+ fprintf(stderr, "FATAL: config.json not found in %s\n", model_dir);
+ exit(1);
+ }
+
+ strlcpy(cfg.model_path, [base UTF8String], sizeof(cfg.model_path));
+
+ // Parse config.json
+ NSData *data = [NSData dataWithContentsOfFile:configPath];
+ NSError *jsonErr = nil;
+ NSDictionary *root = [NSJSONSerialization JSONObjectWithData:data options:0 error:&jsonErr];
+ if (!root) {
+ fprintf(stderr, "FATAL: failed to parse config.json: %s\n", [[jsonErr localizedDescription] UTF8String]);
+ exit(1);
+ }
+ NSDictionary *tc = root[@"text_config"];
+ if (!tc) { fprintf(stderr, "FATAL: config.json missing text_config\n"); exit(1); }
+
+ cfg.hidden_dim = [tc[@"hidden_size"] intValue];
+ cfg.num_layers = [tc[@"num_hidden_layers"] intValue];
+ cfg.num_attn_heads = [tc[@"num_attention_heads"] intValue];
+ cfg.num_kv_heads = [tc[@"num_key_value_heads"] intValue];
+ cfg.head_dim = tc[@"head_dim"] ? [tc[@"head_dim"] intValue] : (cfg.hidden_dim / cfg.num_attn_heads);
+ cfg.vocab_size = [tc[@"vocab_size"] intValue];
+ cfg.rms_norm_eps = [tc[@"rms_norm_eps"] floatValue];
+ cfg.num_experts = [tc[@"num_experts"] intValue];
+ cfg.num_experts_per_tok = [tc[@"num_experts_per_tok"] intValue];
+ cfg.moe_intermediate = [tc[@"moe_intermediate_size"] intValue];
+ cfg.shared_intermediate = [tc[@"shared_expert_intermediate_size"] intValue];
+ cfg.linear_num_v_heads = [tc[@"linear_num_value_heads"] intValue];
+ cfg.linear_num_k_heads = [tc[@"linear_num_key_heads"] intValue];
+ cfg.linear_key_dim = tc[@"linear_key_head_dim"] ? [tc[@"linear_key_head_dim"] intValue] : 128;
+ cfg.linear_value_dim = tc[@"linear_value_head_dim"] ? [tc[@"linear_value_head_dim"] intValue] : 128;
+ cfg.conv_kernel_size = tc[@"linear_conv_kernel_dim"] ? [tc[@"linear_conv_kernel_dim"] intValue] : 4;
+ cfg.max_seq_len = [tc[@"max_position_embeddings"] intValue];
+
+ // Quantization
+ NSDictionary *qc = root[@"quantization_config"] ?: root[@"quantization"];
+ if (qc) {
+ cfg.group_size = [qc[@"group_size"] intValue];
+ cfg.bits = [qc[@"bits"] intValue];
+ } else {
+ cfg.group_size = 64;
+ cfg.bits = 4;
+ fprintf(stderr, "[config] WARNING: no quantization_config, defaulting to 4-bit group_size=64\n");
+ }
+
+ // RoPE parameters
+ NSDictionary *rope = tc[@"rope_parameters"];
+ if (rope) {
+ cfg.rope_theta = [rope[@"rope_theta"] floatValue];
+ cfg.partial_rotary = [rope[@"partial_rotary_factor"] floatValue];
+ } else {
+ cfg.rope_theta = 10000000.0f;
+ cfg.partial_rotary = 0.25f;
+ }
+
+ // Layer types
+ NSArray *layerTypes = tc[@"layer_types"];
+ cfg.is_full_attn = calloc(cfg.num_layers, sizeof(bool));
+ cfg.full_attn_index = malloc(cfg.num_layers * sizeof(int));
+ cfg.linear_index = malloc(cfg.num_layers * sizeof(int));
+
+ if (layerTypes && [layerTypes count] == (NSUInteger)cfg.num_layers) {
+ for (int i = 0; i < cfg.num_layers; i++) {
+ cfg.is_full_attn[i] = [layerTypes[i] isEqualToString:@"full_attention"];
+ }
+ } else {
+ int interval = tc[@"full_attention_interval"] ? [tc[@"full_attention_interval"] intValue] : 4;
+ for (int i = 0; i < cfg.num_layers; i++) {
+ cfg.is_full_attn[i] = ((i + 1) % interval == 0);
+ }
+ fprintf(stderr, "[config] Using full_attn_interval=%d (no explicit layer_types)\n", interval);
+ }
+
+ int full_count = 0, linear_count = 0;
+ for (int i = 0; i < cfg.num_layers; i++) {
+ if (cfg.is_full_attn[i]) {
+ cfg.full_attn_index[i] = full_count++;
+ cfg.linear_index[i] = -1;
+ } else {
+ cfg.linear_index[i] = linear_count++;
+ cfg.full_attn_index[i] = -1;
+ }
+ }
+ cfg.num_full_attn_layers = full_count;
+ cfg.num_linear_layers = linear_count;
+
+ // EOS tokens (can be int or array in config.json)
+ id eosVal = root[@"eos_token_id"];
+ if ([eosVal isKindOfClass:[NSArray class]]) {
+ NSArray *arr = (NSArray *)eosVal;
+ cfg.num_eos_tokens = (int)[arr count];
+ if (cfg.num_eos_tokens > 8) cfg.num_eos_tokens = 8;
+ for (int i = 0; i < cfg.num_eos_tokens; i++)
+ cfg.eos_token_ids[i] = [arr[i] intValue];
+ } else if (eosVal) {
+ cfg.num_eos_tokens = 1;
+ cfg.eos_token_ids[0] = [eosVal intValue];
+ }
+
+ // Think tokens from tokenizer.json added_tokens
+ NSString *tokPath = [base stringByAppendingPathComponent:@"tokenizer.json"];
+ if ([fm fileExistsAtPath:tokPath]) {
+ NSData *tokData = [NSData dataWithContentsOfFile:tokPath];
+ NSDictionary *tokRoot = [NSJSONSerialization JSONObjectWithData:tokData options:0 error:nil];
+ NSArray *addedTokens = tokRoot[@"added_tokens"];
+ if (addedTokens) {
+ for (NSDictionary *tok in addedTokens) {
+ NSString *content = tok[@"content"];
+ int tid = [tok[@"id"] intValue];
+ if ([content isEqualToString:@"<think>"]) cfg.think_start_token = tid;
+ else if ([content isEqualToString:@"</think>"]) cfg.think_end_token = tid;
+ }
+ }
+ } else {
+ fprintf(stderr, "[config] WARNING: tokenizer.json not found, think tokens disabled\n");
+ }
+
+ // Derived dimensions
+ cfg.linear_total_key = cfg.linear_num_k_heads * cfg.linear_key_dim;
+ cfg.linear_total_value = cfg.linear_num_v_heads * cfg.linear_value_dim;
+ cfg.linear_conv_dim = cfg.linear_total_key * 2 + cfg.linear_total_value;
+ cfg.rotary_dim = (int)(cfg.head_dim * cfg.partial_rotary);
+
+ // Expert byte offsets
+ compute_expert_offsets(&cfg);
+
+ // Summary
+ fprintf(stderr, "[config] %d layers (%d linear + %d full), hidden=%d, heads=%d, kv_heads=%d, head_dim=%d\n",
+ cfg.num_layers, cfg.num_linear_layers, cfg.num_full_attn_layers,
+ cfg.hidden_dim, cfg.num_attn_heads, cfg.num_kv_heads, cfg.head_dim);
+ fprintf(stderr, "[config] %d experts (K=%d), moe_intermediate=%d, shared=%d\n",
+ cfg.num_experts, cfg.num_experts_per_tok, cfg.moe_intermediate, cfg.shared_intermediate);
+ fprintf(stderr, "[config] %d-bit quantization, group_size=%d, expert_size=%zu bytes\n",
+ cfg.bits, cfg.group_size, cfg.expert_size_4bit);
+ fprintf(stderr, "[config] EOS tokens: [");
+ for (int i = 0; i < cfg.num_eos_tokens; i++)
+ fprintf(stderr, "%s%d", i ? ", " : "", cfg.eos_token_ids[i]);
+ fprintf(stderr, "], think: %d/%d\n", cfg.think_start_token, cfg.think_end_token);
+}
+
+// ============================================================================
+// Dynamic tracking arrays (allocated after config is loaded)
+// Declarations here, alloc_tracking_arrays() defined after types below.
+// ============================================================================
+
+static int *g_expert_freq = NULL;
+static uint8_t *g_expert_seen = NULL;
+static void **g_lz4_index = NULL; // actually LZ4IndexEntry**, cast at use site
+static uint8_t *g_cache_seen = NULL;
+static uint64_t *g_cache_last_touch_token = NULL;
+static uint64_t *g_cache_last_evict_token = NULL;
+static int *g_pred_experts = NULL;
+static int *g_pred_count = NULL;
+
+// Hardware tuning constants (not model-specific)
+#define GPU_KV_SEQ 8192
+
+// Helper macros for flattened 2D access
+#define FREQ(l, e) g_expert_freq[(l) * cfg.num_experts + (e)]
+#define EXPERT_SEEN_BYTE(l, e) g_expert_seen[(l) * ((cfg.num_experts + 7) / 8) + ((e) >> 3)]
+#define CACHE_SEEN(l, e) g_cache_seen[(l) * cfg.num_experts + (e)]
+#define CACHE_TOUCH(l, e) g_cache_last_touch_token[(l) * cfg.num_experts + (e)]
+#define CACHE_EVICT(l, e) g_cache_last_evict_token[(l) * cfg.num_experts + (e)]
+#define PRED_EXPERT(l, k) g_pred_experts[(l) * MAX_K + (k)]
+#define PRED_COUNT(l) g_pred_count[(l)]
+
+// Forward declaration — defined after LayerWeightCache and LZ4IndexEntry
+static void alloc_tracking_arrays(void);
+
// ============================================================================
// Timing helper
@@ -182,7 +447,6 @@ static double now_ms(void) {
uint32_t raw_size;
} LZ4IndexEntry;
-static LZ4IndexEntry *g_lz4_index[NUM_LAYERS]; // per-layer index (NULL if not using LZ4)
static void *g_lz4_comp_bufs[8]; // pre-allocated compressed read buffers (MAX_K=8)
static int g_use_lz4 = 0; // auto-detected from packed_experts_lz4/
@@ -190,23 +454,21 @@ static double now_ms(void) {
// Expert frequency tracking (diagnostic: --freq flag)
// ============================================================================
-static int g_expert_freq[NUM_LAYERS][NUM_EXPERTS]; // activation count per (layer, expert)
static int g_freq_tracking = 0; // enabled by --freq flag
static int g_use_2bit = 0; // enabled by --2bit flag: use packed_experts_2bit/ + 2-bit kernel
static int g_cache_telemetry_enabled = 0; // enabled by --cache-telemetry flag
static int g_think_budget = 2048; // max thinking tokens before force-emitting
// Tiered I/O: cold fds (F_NOCACHE) for first reads, warm fds (page cached) for repeats
-static int *g_layer_fds_cold = NULL; // [NUM_LAYERS] cold fds (set in main)
-static uint8_t g_expert_seen[NUM_LAYERS][NUM_EXPERTS / 8]; // bitset: seen before?
+static int *g_layer_fds_cold = NULL; // [cfg.num_layers] cold fds (set in main)
// Async pread state defined after InferPreadTask (see below)
static inline int expert_is_seen(int layer, int expert) {
- return (g_expert_seen[layer][expert >> 3] >> (expert & 7)) & 1;
+ return (EXPERT_SEEN_BYTE(layer, expert) >> (expert & 7)) & 1;
}
static inline void expert_mark_seen(int layer, int expert) {
- g_expert_seen[layer][expert >> 3] |= (1 << (expert & 7));
+ EXPERT_SEEN_BYTE(layer, expert) |= (1 << (expert & 7));
}
// Pick fd for expert read. Currently: always use warm fd (OS page cache).
// Tiered I/O (cold F_NOCACHE for first reads) was tested but OS page cache
@@ -218,7 +480,7 @@ static inline int expert_pick_fd(int layer, int expert, int warm_fd) {
// Active expert size based on quantization mode
static inline size_t active_expert_size(void) {
- return g_use_2bit ? EXPERT_SIZE_2BIT : EXPERT_SIZE;
+ return g_use_2bit ? cfg.expert_size_2bit : cfg.expert_size_4bit;
}
static int g_freq_total_tokens = 0; // total tokens processed while tracking
@@ -238,15 +500,12 @@ static inline size_t active_expert_size(void) {
} CacheTelemetry;
static CacheTelemetry g_cache_telemetry = {0};
-static uint8_t g_cache_seen[NUM_LAYERS][NUM_EXPERTS];
-static uint64_t g_cache_last_touch_token[NUM_LAYERS][NUM_EXPERTS];
-static uint64_t g_cache_last_evict_token[NUM_LAYERS][NUM_EXPERTS];
static void cache_telemetry_reset(void) {
memset(&g_cache_telemetry, 0, sizeof(g_cache_telemetry));
- memset(g_cache_seen, 0, sizeof(g_cache_seen));
- memset(g_cache_last_touch_token, 0, sizeof(g_cache_last_touch_token));
- memset(g_cache_last_evict_token, 0, sizeof(g_cache_last_evict_token));
+ memset(g_cache_seen, 0, cfg.num_layers * cfg.num_experts * sizeof(uint8_t));
+ memset(g_cache_last_touch_token, 0, cfg.num_layers * cfg.num_experts * sizeof(uint64_t));
+ memset(g_cache_last_evict_token, 0, cfg.num_layers * cfg.num_experts * sizeof(uint64_t));
}
static void cache_telemetry_note_token(void) {
@@ -256,27 +515,27 @@ static void cache_telemetry_note_token(void) {
static void cache_telemetry_touch(int layer_idx, int expert_idx) {
if (!g_cache_telemetry_enabled) return;
- if (layer_idx < 0 || layer_idx >= NUM_LAYERS || expert_idx < 0 || expert_idx >= NUM_EXPERTS) return;
- if (!g_cache_seen[layer_idx][expert_idx]) {
- g_cache_seen[layer_idx][expert_idx] = 1;
+ if (layer_idx < 0 || layer_idx >= cfg.num_layers || expert_idx < 0 || expert_idx >= cfg.num_experts) return;
+ if (!CACHE_SEEN(layer_idx, expert_idx)) {
+ CACHE_SEEN(layer_idx, expert_idx) = 1;
g_cache_telemetry.unique_experts_touched++;
}
- g_cache_last_touch_token[layer_idx][expert_idx] = g_cache_telemetry.token_clock;
+ CACHE_TOUCH(layer_idx, expert_idx) = g_cache_telemetry.token_clock;
}
static void cache_telemetry_miss(int layer_idx, int expert_idx) {
if (!g_cache_telemetry_enabled) return;
- if (layer_idx < 0 || layer_idx >= NUM_LAYERS || expert_idx < 0 || expert_idx >= NUM_EXPERTS) return;
- if (!g_cache_seen[layer_idx][expert_idx]) {
+ if (layer_idx < 0 || layer_idx >= cfg.num_layers || expert_idx < 0 || expert_idx >= cfg.num_experts) return;
+ if (!CACHE_SEEN(layer_idx, expert_idx)) {
g_cache_telemetry.cold_misses++;
- g_cache_seen[layer_idx][expert_idx] = 1;
+ CACHE_SEEN(layer_idx, expert_idx) = 1;
g_cache_telemetry.unique_experts_touched++;
} else {
g_cache_telemetry.eviction_misses++;
uint64_t dist = 0;
- if (g_cache_last_evict_token[layer_idx][expert_idx] > 0 &&
- g_cache_telemetry.token_clock >= g_cache_last_evict_token[layer_idx][expert_idx]) {
- dist = g_cache_telemetry.token_clock - g_cache_last_evict_token[layer_idx][expert_idx];
+ if (CACHE_EVICT(layer_idx, expert_idx) > 0 &&
+ g_cache_telemetry.token_clock >= CACHE_EVICT(layer_idx, expert_idx)) {
+ dist = g_cache_telemetry.token_clock - CACHE_EVICT(layer_idx, expert_idx);
}
if (dist <= 1) g_cache_telemetry.reuse_le_1++;
else if (dist <= 4) g_cache_telemetry.reuse_le_4++;
@@ -286,14 +545,14 @@ static void cache_telemetry_miss(int layer_idx, int expert_idx) {
g_cache_telemetry.reuse_distance_sum += dist;
g_cache_telemetry.reuse_distance_samples++;
}
- g_cache_last_touch_token[layer_idx][expert_idx] = g_cache_telemetry.token_clock;
+ CACHE_TOUCH(layer_idx, expert_idx) = g_cache_telemetry.token_clock;
}
static void cache_telemetry_evict(int layer_idx, int expert_idx) {
if (!g_cache_telemetry_enabled) return;
- if (layer_idx < 0 || layer_idx >= NUM_LAYERS || expert_idx < 0 || expert_idx >= NUM_EXPERTS) return;
+ if (layer_idx < 0 || layer_idx >= cfg.num_layers || expert_idx < 0 || expert_idx >= cfg.num_experts) return;
g_cache_telemetry.evictions++;
- g_cache_last_evict_token[layer_idx][expert_idx] = g_cache_telemetry.token_clock;
+ CACHE_EVICT(layer_idx, expert_idx) = g_cache_telemetry.token_clock;
}
static void cache_telemetry_print(uint64_t hits, uint64_t misses) {
@@ -303,8 +562,8 @@ static void cache_telemetry_print(uint64_t hits, uint64_t misses) {
fprintf(stderr, "Tokens tracked: %llu\n", g_cache_telemetry.token_clock);
fprintf(stderr, "Unique experts touched: %llu / %d (%.1f%%)\n",
g_cache_telemetry.unique_experts_touched,
- NUM_LAYERS * NUM_EXPERTS,
- 100.0 * g_cache_telemetry.unique_experts_touched / (NUM_LAYERS * NUM_EXPERTS));
+ cfg.num_layers * cfg.num_experts,
+ 100.0 * g_cache_telemetry.unique_experts_touched / (cfg.num_layers * cfg.num_experts));
fprintf(stderr, "Miss breakdown: cold %llu (%.1f%% of misses), eviction %llu (%.1f%% of misses)\n",
g_cache_telemetry.cold_misses,
misses > 0 ? 100.0 * g_cache_telemetry.cold_misses / misses : 0.0,
@@ -578,6 +837,47 @@ static void build_tensor_ht(TensorManifest *m) {
int num_tokens;
} Vocabulary;
+// GPT-2 BPE byte decoder: convert BPE Unicode chars back to raw bytes.
+// In GPT-2 BPE, bytes 0x00-0xFF are mapped to Unicode codepoints:
+// printable bytes (0x21-0x7E, 0xA1-0xAC, 0xAE-0xFF) map to their own codepoints
+// everything else maps to U+0100 + offset (e.g., space 0x20 → U+0120 'Ġ')
+// This function decodes a UTF-8 BPE string back to raw bytes in-place.
+static int bpe_decode_inplace(char *s, int len) {
+ int out = 0;
+ int i = 0;
+ while (i < len) {
+ unsigned char c = (unsigned char)s[i];
+ if (c < 0x80) {
+ // ASCII byte — pass through
+ s[out++] = s[i++];
+ } else if ((c & 0xE0) == 0xC0 && i + 1 < len) {
+ // 2-byte UTF-8: U+0080 to U+07FF
+ unsigned int cp = ((c & 0x1F) << 6) | ((unsigned char)s[i+1] & 0x3F);
+ if (cp >= 0x100 && cp <= 0x1FF) {
+ // GPT-2 BPE mapped byte: U+0100+byte → original byte
+ s[out++] = (char)(cp - 0x100);
+ } else {
+ // Regular Unicode char — keep UTF-8 encoding
+ s[out++] = s[i];
+ s[out++] = s[i+1];
+ }
+ i += 2;
+ } else if ((c & 0xF0) == 0xE0 && i + 2 < len) {
+ // 3-byte UTF-8
+ s[out++] = s[i]; s[out++] = s[i+1]; s[out++] = s[i+2];
+ i += 3;
+ } else if ((c & 0xF8) == 0xF0 && i + 3 < len) {
+ // 4-byte UTF-8
+ s[out++] = s[i]; s[out++] = s[i+1]; s[out++] = s[i+2]; s[out++] = s[i+3];
+ i += 4;
+ } else {
+ s[out++] = s[i++];
+ }
+ }
+ s[out] = '\0';
+ return out;
+}
+
static Vocabulary *load_vocab(const char *path) {
FILE *f = fopen(path, "rb");
if (!f) {
@@ -601,7 +901,8 @@ static void build_tensor_ht(TensorManifest *m) {
v->tokens[i] = malloc(byte_len + 1);
fread(v->tokens[i], 1, byte_len, f);
v->tokens[i][byte_len] = '\0';
- v->lengths[i] = byte_len;
+ // Decode GPT-2 BPE byte encoding (Ġ→space, Ċ→newline, etc.)
+ v->lengths[i] = bpe_decode_inplace(v->tokens[i], byte_len);
}
}
@@ -916,53 +1217,52 @@ static void cpu_conv1d_step(
id attn_values_pipe;
id sigmoid_gate_pipe;
// Reusable buffers for attention matmuls
- id buf_input; // input vector [HIDDEN_DIM or max projection input]
+ id buf_input; // input vector [cfg.hidden_dim or max projection input]
id buf_output; // output vector [max projection output]
id wf_buf; // the mmap'd weight file as a Metal buffer
// Batched matmul output slots (preallocated, reused across dispatches)
id batch_out[MAX_BATCH_SLOTS];
// Reusable buffers for expert computation (avoids per-expert alloc)
// Legacy single-expert buffers (kept for gpu_expert_forward compat)
- id buf_expert_data; // holds one expert's packed weights (EXPERT_SIZE bytes)
- id buf_expert_input; // h_post input [HIDDEN_DIM floats]
- id buf_expert_gate; // gate_proj output [MOE_INTERMEDIATE floats]
- id buf_expert_up; // up_proj output [MOE_INTERMEDIATE floats]
- id buf_expert_act; // SwiGLU output [MOE_INTERMEDIATE floats]
- id buf_expert_out; // down_proj output [HIDDEN_DIM floats]
+ id buf_expert_data; // holds one expert's packed weights (cfg.expert_size_4bit bytes)
+ id buf_expert_input; // h_post input [cfg.hidden_dim floats]
+ id buf_expert_gate; // gate_proj output [cfg.moe_intermediate floats]
+ id buf_expert_up; // up_proj output [cfg.moe_intermediate floats]
+ id buf_expert_act; // SwiGLU output [cfg.moe_intermediate floats]
+ id buf_expert_out; // down_proj output [cfg.hidden_dim floats]
// Multi-expert buffers: K independent sets so all experts can be encoded
// into a SINGLE command buffer (no per-expert commit+wait).
// Each expert k uses slot [k].
// Double-buffered: set A (data) for GPU compute, set B (data_B) for background pread.
// Gate/up/act/out only need one set (GPU uses them after pread completes).
#define MAX_K 8
- id buf_multi_expert_data[MAX_K]; // [EXPERT_SIZE bytes] each — buffer set A
- id buf_multi_expert_data_B[MAX_K]; // [EXPERT_SIZE bytes] each — buffer set B (prefetch)
- id buf_multi_expert_gate[MAX_K]; // [MOE_INTERMEDIATE floats]
- id buf_multi_expert_up[MAX_K]; // [MOE_INTERMEDIATE floats]
- id buf_multi_expert_act[MAX_K]; // [MOE_INTERMEDIATE floats]
- id buf_multi_expert_out[MAX_K]; // [HIDDEN_DIM floats]
- id buf_multi_expert_input; // [HIDDEN_DIM floats] (shared, read-only during dispatch)
+ id buf_multi_expert_data[MAX_K]; // [cfg.expert_size_4bit bytes] each — buffer set A
+ id buf_multi_expert_data_B[MAX_K]; // [cfg.expert_size_4bit bytes] each — buffer set B (prefetch)
+ id buf_multi_expert_gate[MAX_K]; // [cfg.moe_intermediate floats]
+ id buf_multi_expert_up[MAX_K]; // [cfg.moe_intermediate floats]
+ id buf_multi_expert_act[MAX_K]; // [cfg.moe_intermediate floats]
+ id buf_multi_expert_out[MAX_K]; // [cfg.hidden_dim floats]
+ id buf_multi_expert_input; // [cfg.hidden_dim floats] (shared, read-only during dispatch)
// Shared expert buffers for fused CMD2 (shared gate/up computed in CMD1,
// SwiGLU + down_proj in CMD2 alongside routed experts)
- id buf_shared_gate; // [SHARED_INTERMEDIATE floats]
- id buf_shared_up; // [SHARED_INTERMEDIATE floats]
- id buf_shared_act; // [SHARED_INTERMEDIATE floats] (SwiGLU output)
- id buf_shared_out; // [HIDDEN_DIM floats] (down_proj output)
+ id buf_shared_gate; // [cfg.shared_intermediate floats]
+ id buf_shared_up; // [cfg.shared_intermediate floats]
+ id buf_shared_act; // [cfg.shared_intermediate floats] (SwiGLU output)
+ id buf_shared_out; // [cfg.hidden_dim floats] (down_proj output)
// Fused o_proj+norm+routing buffers (eliminates 1 cmd buffer per layer)
- id buf_residual; // [HIDDEN_DIM floats] holds residual for GPU add
- id buf_h_mid; // [HIDDEN_DIM floats] residual+oproj result
+ id buf_residual; // [cfg.hidden_dim floats] holds residual for GPU add
+ id buf_h_mid; // [cfg.hidden_dim floats] residual+oproj result
id buf_sum_sq; // [1 float] for RMS norm reduction
// GPU attention buffers (for full attention layers)
- #define NUM_FULL_ATTN_LAYERS 15
- id buf_kv_k[NUM_FULL_ATTN_LAYERS]; // K cache per full-attn layer
- id buf_kv_v[NUM_FULL_ATTN_LAYERS]; // V cache per full-attn layer
- id buf_attn_q; // [NUM_ATTN_HEADS * HEAD_DIM floats] all query heads
- id buf_attn_scores; // [NUM_ATTN_HEADS * MAX_SEQ_LEN floats] all heads' scores
- id buf_attn_out; // [NUM_ATTN_HEADS * HEAD_DIM floats] full attention output
- id buf_attn_gate; // [NUM_ATTN_HEADS * HEAD_DIM floats] sigmoid gate
+ id __strong *buf_kv_k; // K cache per full-attn layer
+ id __strong *buf_kv_v; // V cache per full-attn layer
+ id buf_attn_q; // [cfg.num_attn_heads * cfg.head_dim floats] all query heads
+ id buf_attn_scores; // [cfg.num_attn_heads * GPU_KV_SEQ floats] all heads' scores
+ id buf_attn_out; // [cfg.num_attn_heads * cfg.head_dim floats] full attention output
+ id buf_attn_gate; // [cfg.num_attn_heads * cfg.head_dim floats] sigmoid gate
// CMD3 GPU-side combine buffers (weighted_sum + residual + norm on GPU)
id moe_combine_residual; // fused combine kernel
- id buf_moe_hidden; // [HIDDEN_DIM floats] GPU combine output (hidden state)
+ id buf_moe_hidden; // [cfg.hidden_dim floats] GPU combine output (hidden state)
id buf_combine_params; // [10 floats] expert weights[8] + shared_gate_score + padding
id buf_cmd3_sum_sq; // [1 float] for RMS norm reduction in CMD3
// Shared event for CPU-GPU synchronization (async pipeline)
@@ -975,24 +1275,28 @@ static void cpu_conv1d_step(
id compute_decay_beta; // g_decay and beta_gate for delta-net
id gated_rms_norm; // z-gated output normalization
// Persistent GPU state buffers for linear attention layers
- #define NUM_LINEAR_LAYERS 45
- id buf_delta_state[NUM_LINEAR_LAYERS]; // [64*128*128] float per layer
- id buf_conv_state[NUM_LINEAR_LAYERS]; // [3*12288] float per layer
+ id __strong *buf_delta_state; // [v_heads*v_dim*k_dim] float per layer
+ id __strong *buf_conv_state; // [(kernel-1)*conv_dim] float per layer
// Scratch buffers for delta-net inputs/outputs
- id buf_delta_q; // [2048] float
- id buf_delta_k; // [2048] float
- id buf_delta_v; // [8192] float
- id buf_delta_g_decay; // [64] float
- id buf_delta_beta; // [64] float
- id buf_delta_output; // [8192] float
- id buf_conv_input; // [12288] float
- id buf_conv_output; // [12288] float
+ id buf_delta_q; // [cfg.linear_total_key=2048] float
+ id buf_delta_k; // [cfg.linear_total_key=2048] float
+ id buf_delta_v; // [cfg.linear_total_value=8192] float
+ id buf_delta_g_decay; // [cfg.linear_num_v_heads=64] float
+ id buf_delta_beta; // [cfg.linear_num_v_heads=64] float
+ id buf_delta_output; // [cfg.linear_total_value=8192] float
+ id buf_conv_input; // [cfg.linear_conv_dim=12288] float
+ id buf_conv_output; // [cfg.linear_conv_dim=12288] float
} MetalCtx;
static MetalCtx *g_metal = NULL;
static MetalCtx *metal_setup(void) {
MetalCtx *ctx = calloc(1, sizeof(MetalCtx));
+ // Allocate dynamic buffer arrays based on config
+ ctx->buf_kv_k = (__strong id *)calloc(cfg.num_full_attn_layers, sizeof(id));
+ ctx->buf_kv_v = (__strong id *)calloc(cfg.num_full_attn_layers, sizeof(id));
+ ctx->buf_delta_state = (__strong id *)calloc(cfg.num_linear_layers, sizeof(id));
+ ctx->buf_conv_state = (__strong id *)calloc(cfg.num_linear_layers, sizeof(id));
ctx->device = MTLCreateSystemDefaultDevice();
if (!ctx->device) {
fprintf(stderr, "ERROR: No Metal device\n");
@@ -1075,10 +1379,10 @@ static void cpu_conv1d_step(
// Allocate reusable buffers (large enough for biggest projection)
// Q proj output is 16384 floats, lm_head output is 248320 floats
// o_proj input is 8192, linear attn out_proj input is 8192
- size_t max_out = VOCAB_SIZE * sizeof(float); // lm_head is largest
- size_t max_in = LINEAR_TOTAL_VALUE * sizeof(float); // 8192 floats (linear_attn out_proj)
- if (max_in < (size_t)(NUM_ATTN_HEADS * HEAD_DIM) * sizeof(float)) {
- max_in = (size_t)(NUM_ATTN_HEADS * HEAD_DIM) * sizeof(float); // o_proj input = 8192
+ size_t max_out = cfg.vocab_size * sizeof(float); // lm_head is largest
+ size_t max_in = cfg.linear_total_value * sizeof(float); // 8192 floats (linear_attn out_proj)
+ if (max_in < (size_t)(cfg.num_attn_heads * cfg.head_dim) * sizeof(float)) {
+ max_in = (size_t)(cfg.num_attn_heads * cfg.head_dim) * sizeof(float); // o_proj input = 8192
}
ctx->buf_input = [ctx->device newBufferWithLength:max_in options:MTLResourceStorageModeShared];
ctx->buf_output = [ctx->device newBufferWithLength:max_out options:MTLResourceStorageModeShared];
@@ -1087,9 +1391,9 @@ static void cpu_conv1d_step(
// q_proj = 16384 floats, qkv_proj = 12288, z_proj = 8192, o_proj = 4096
// lm_head (248320) uses buf_output directly, not batched.
{
- size_t slot_size = (size_t)(NUM_ATTN_HEADS * HEAD_DIM * 2) * sizeof(float); // 16384 floats
- if (slot_size < (size_t)LINEAR_CONV_DIM * sizeof(float))
- slot_size = (size_t)LINEAR_CONV_DIM * sizeof(float); // 12288 floats
+ size_t slot_size = (size_t)(cfg.num_attn_heads * cfg.head_dim * 2) * sizeof(float); // 16384 floats
+ if (slot_size < (size_t)cfg.linear_conv_dim * sizeof(float))
+ slot_size = (size_t)cfg.linear_conv_dim * sizeof(float); // 12288 floats
for (int i = 0; i < MAX_BATCH_SLOTS; i++) {
ctx->batch_out[i] = [ctx->device newBufferWithLength:slot_size
options:MTLResourceStorageModeShared];
@@ -1097,25 +1401,25 @@ static void cpu_conv1d_step(
}
// Expert computation buffers (reused across all experts and layers)
- ctx->buf_expert_data = [ctx->device newBufferWithLength:EXPERT_SIZE
+ ctx->buf_expert_data = [ctx->device newBufferWithLength:cfg.expert_size_4bit
options:MTLResourceStorageModeShared];
- ctx->buf_expert_input = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float)
+ ctx->buf_expert_input = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_expert_gate = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float)
+ ctx->buf_expert_gate = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_expert_up = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float)
+ ctx->buf_expert_up = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_expert_act = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float)
+ ctx->buf_expert_act = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_expert_out = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float)
+ ctx->buf_expert_out = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float)
options:MTLResourceStorageModeShared];
// Multi-expert buffers: K independent slots (double-buffered data)
// Expert data buffers use 2MB-aligned backing memory for DMA efficiency.
// The pread DMA controller transfers 3.6x faster with 2MB alignment vs 16KB.
- ctx->buf_multi_expert_input = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float)
+ ctx->buf_multi_expert_input = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float)
options:MTLResourceStorageModeShared];
- size_t expert_alloc_size = (EXPERT_SIZE + 2*1024*1024 - 1) & ~(2*1024*1024 - 1); // round up to 2MB
+ size_t expert_alloc_size = (cfg.expert_size_4bit + 2*1024*1024 - 1) & ~(2*1024*1024 - 1); // round up to 2MB
for (int k = 0; k < MAX_K; k++) {
// 2MB-aligned allocation for optimal DMA throughput
void *aligned_data = NULL, *aligned_data_b = NULL;
@@ -1131,36 +1435,36 @@ static void cpu_conv1d_step(
length:expert_alloc_size
options:MTLResourceStorageModeShared
deallocator:nil];
- ctx->buf_multi_expert_gate[k] = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float)
+ ctx->buf_multi_expert_gate[k] = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_multi_expert_up[k] = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float)
+ ctx->buf_multi_expert_up[k] = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_multi_expert_act[k] = [ctx->device newBufferWithLength:MOE_INTERMEDIATE * sizeof(float)
+ ctx->buf_multi_expert_act[k] = [ctx->device newBufferWithLength:cfg.moe_intermediate * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_multi_expert_out[k] = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float)
+ ctx->buf_multi_expert_out[k] = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float)
options:MTLResourceStorageModeShared];
}
// Shared expert buffers (for fused CMD2)
- ctx->buf_shared_gate = [ctx->device newBufferWithLength:SHARED_INTERMEDIATE * sizeof(float)
+ ctx->buf_shared_gate = [ctx->device newBufferWithLength:cfg.shared_intermediate * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_shared_up = [ctx->device newBufferWithLength:SHARED_INTERMEDIATE * sizeof(float)
+ ctx->buf_shared_up = [ctx->device newBufferWithLength:cfg.shared_intermediate * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_shared_act = [ctx->device newBufferWithLength:SHARED_INTERMEDIATE * sizeof(float)
+ ctx->buf_shared_act = [ctx->device newBufferWithLength:cfg.shared_intermediate * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_shared_out = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float)
+ ctx->buf_shared_out = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float)
options:MTLResourceStorageModeShared];
// Fused o_proj+norm+routing buffers
- ctx->buf_residual = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float)
+ ctx->buf_residual = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_h_mid = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float)
+ ctx->buf_h_mid = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float)
options:MTLResourceStorageModeShared];
ctx->buf_sum_sq = [ctx->device newBufferWithLength:sizeof(float)
options:MTLResourceStorageModeShared];
// CMD3 GPU-side combine buffers
- ctx->buf_moe_hidden = [ctx->device newBufferWithLength:HIDDEN_DIM * sizeof(float)
+ ctx->buf_moe_hidden = [ctx->device newBufferWithLength:cfg.hidden_dim * sizeof(float)
options:MTLResourceStorageModeShared];
ctx->buf_combine_params = [ctx->device newBufferWithLength:10 * sizeof(float)
options:MTLResourceStorageModeShared];
@@ -1169,50 +1473,52 @@ static void cpu_conv1d_step(
// GPU attention buffers
{
- size_t kv_dim = NUM_KV_HEADS * HEAD_DIM; // 512
+ size_t kv_dim = cfg.num_kv_heads * cfg.head_dim; // 512
size_t kv_cache_size = GPU_KV_SEQ * kv_dim * sizeof(float);
- for (int i = 0; i < NUM_FULL_ATTN_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_full_attn_layers; i++) {
ctx->buf_kv_k[i] = [ctx->device newBufferWithLength:kv_cache_size
options:MTLResourceStorageModeShared];
ctx->buf_kv_v[i] = [ctx->device newBufferWithLength:kv_cache_size
options:MTLResourceStorageModeShared];
}
- ctx->buf_attn_q = [ctx->device newBufferWithLength:NUM_ATTN_HEADS * HEAD_DIM * sizeof(float)
+ ctx->buf_attn_q = [ctx->device newBufferWithLength:cfg.num_attn_heads * cfg.head_dim * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_attn_scores = [ctx->device newBufferWithLength:(size_t)NUM_ATTN_HEADS * GPU_KV_SEQ * sizeof(float)
+ ctx->buf_attn_scores = [ctx->device newBufferWithLength:(size_t)cfg.num_attn_heads * GPU_KV_SEQ * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_attn_out = [ctx->device newBufferWithLength:NUM_ATTN_HEADS * HEAD_DIM * sizeof(float)
+ ctx->buf_attn_out = [ctx->device newBufferWithLength:cfg.num_attn_heads * cfg.head_dim * sizeof(float)
options:MTLResourceStorageModeShared];
- ctx->buf_attn_gate = [ctx->device newBufferWithLength:NUM_ATTN_HEADS * HEAD_DIM * sizeof(float)
+ ctx->buf_attn_gate = [ctx->device newBufferWithLength:cfg.num_attn_heads * cfg.head_dim * sizeof(float)
options:MTLResourceStorageModeShared];
printf("[metal] GPU attention buffers: %d KV caches (%.1f MB each), scores buf %.1f MB\n",
- NUM_FULL_ATTN_LAYERS, kv_cache_size / 1e6,
- (double)(NUM_ATTN_HEADS * MAX_SEQ_LEN * sizeof(float)) / 1e6);
+ cfg.num_full_attn_layers, kv_cache_size / 1e6,
+ (double)((size_t)cfg.num_attn_heads * GPU_KV_SEQ * sizeof(float)) / 1e6);
}
// Persistent GPU state buffers for delta-net (linear attention layers)
if (ctx->delta_net_step) {
- for (int i = 0; i < NUM_LINEAR_LAYERS; i++) {
- ctx->buf_delta_state[i] = [ctx->device newBufferWithLength:64*128*128*sizeof(float)
+ for (int i = 0; i < cfg.num_linear_layers; i++) {
+ ctx->buf_delta_state[i] = [ctx->device newBufferWithLength:(size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float)
options:MTLResourceStorageModeShared];
- memset([ctx->buf_delta_state[i] contents], 0, 64*128*128*sizeof(float));
- ctx->buf_conv_state[i] = [ctx->device newBufferWithLength:3*12288*sizeof(float)
+ memset([ctx->buf_delta_state[i] contents], 0, (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float));
+ ctx->buf_conv_state[i] = [ctx->device newBufferWithLength:(cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float)
options:MTLResourceStorageModeShared];
- memset([ctx->buf_conv_state[i] contents], 0, 3*12288*sizeof(float));
+ memset([ctx->buf_conv_state[i] contents], 0, (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float));
}
// Scratch buffers for delta-net inputs/outputs (allocated once, reused)
- ctx->buf_delta_q = [ctx->device newBufferWithLength:2048*sizeof(float) options:MTLResourceStorageModeShared];
- ctx->buf_delta_k = [ctx->device newBufferWithLength:2048*sizeof(float) options:MTLResourceStorageModeShared];
- ctx->buf_delta_v = [ctx->device newBufferWithLength:8192*sizeof(float) options:MTLResourceStorageModeShared];
- ctx->buf_delta_g_decay = [ctx->device newBufferWithLength:64*sizeof(float) options:MTLResourceStorageModeShared];
- ctx->buf_delta_beta = [ctx->device newBufferWithLength:64*sizeof(float) options:MTLResourceStorageModeShared];
- ctx->buf_delta_output = [ctx->device newBufferWithLength:8192*sizeof(float) options:MTLResourceStorageModeShared];
- ctx->buf_conv_input = [ctx->device newBufferWithLength:12288*sizeof(float) options:MTLResourceStorageModeShared];
- ctx->buf_conv_output = [ctx->device newBufferWithLength:12288*sizeof(float) options:MTLResourceStorageModeShared];
+ ctx->buf_delta_q = [ctx->device newBufferWithLength:cfg.linear_total_key*sizeof(float) options:MTLResourceStorageModeShared];
+ ctx->buf_delta_k = [ctx->device newBufferWithLength:cfg.linear_total_key*sizeof(float) options:MTLResourceStorageModeShared];
+ ctx->buf_delta_v = [ctx->device newBufferWithLength:cfg.linear_total_value*sizeof(float) options:MTLResourceStorageModeShared];
+ ctx->buf_delta_g_decay = [ctx->device newBufferWithLength:cfg.linear_num_v_heads*sizeof(float) options:MTLResourceStorageModeShared];
+ ctx->buf_delta_beta = [ctx->device newBufferWithLength:cfg.linear_num_v_heads*sizeof(float) options:MTLResourceStorageModeShared];
+ ctx->buf_delta_output = [ctx->device newBufferWithLength:cfg.linear_total_value*sizeof(float) options:MTLResourceStorageModeShared];
+ ctx->buf_conv_input = [ctx->device newBufferWithLength:cfg.linear_conv_dim*sizeof(float) options:MTLResourceStorageModeShared];
+ ctx->buf_conv_output = [ctx->device newBufferWithLength:cfg.linear_conv_dim*sizeof(float) options:MTLResourceStorageModeShared];
+ size_t state_bytes = (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float);
+ size_t conv_bytes = (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float);
printf("[metal] Delta-net GPU buffers: %d layers (%.1f MB state + %.1f MB scratch)\n",
- NUM_LINEAR_LAYERS,
- NUM_LINEAR_LAYERS * (64*128*128*4 + 3*12288*4) / 1e6,
- (2048+2048+8192+64+64+8192+12288+12288) * 4 / 1e6);
+ cfg.num_linear_layers,
+ cfg.num_linear_layers * (state_bytes + conv_bytes) / 1e6,
+ (cfg.linear_total_key*2+cfg.linear_total_value*2+cfg.linear_num_v_heads*2+cfg.linear_conv_dim*2) * sizeof(float) / 1e6);
}
// Create shared event for CPU-GPU async pipeline
@@ -1226,11 +1532,11 @@ static void cpu_conv1d_step(
// Reset delta-net and conv GPU state buffers (call at start of new generation)
static void reset_delta_net_state(void) {
if (!g_metal || !g_metal->delta_net_step) return;
- for (int i = 0; i < NUM_LINEAR_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_linear_layers; i++) {
if (g_metal->buf_delta_state[i])
- memset([g_metal->buf_delta_state[i] contents], 0, 64*128*128*sizeof(float));
+ memset([g_metal->buf_delta_state[i] contents], 0, (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float));
if (g_metal->buf_conv_state[i])
- memset([g_metal->buf_conv_state[i] contents], 0, 3*12288*sizeof(float));
+ memset([g_metal->buf_conv_state[i] contents], 0, (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float));
}
}
@@ -1508,21 +1814,21 @@ static void gpu_encode_expert_forward_slot(
NSUInteger up_w_off, up_s_off, up_b_off;
NSUInteger down_w_off, down_s_off, down_b_off;
if (g_use_2bit) {
- gate_w_off = GATE_W_OFF_2; gate_s_off = GATE_S_OFF_2; gate_b_off = GATE_B_OFF_2;
- up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2;
- down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2;
+ gate_w_off = cfg.gate_w_off_2; gate_s_off = cfg.gate_s_off_2; gate_b_off = cfg.gate_b_off_2;
+ up_w_off = cfg.up_w_off_2; up_s_off = cfg.up_s_off_2; up_b_off = cfg.up_b_off_2;
+ down_w_off = cfg.down_w_off_2; down_s_off = cfg.down_s_off_2; down_b_off = cfg.down_b_off_2;
} else {
- gate_w_off = 0; gate_s_off = 2097152; gate_b_off = 2228224;
- up_w_off = 2359296; up_s_off = 4456448; up_b_off = 4587520;
- down_w_off = 4718592; down_s_off = 6815744; down_b_off = 6946816;
+ gate_w_off = cfg.gate_w_off_4; gate_s_off = cfg.gate_s_off_4; gate_b_off = cfg.gate_b_off_4;
+ up_w_off = cfg.up_w_off_4; up_s_off = cfg.up_s_off_4; up_b_off = cfg.up_b_off_4;
+ down_w_off = cfg.down_w_off_4; down_s_off = cfg.down_s_off_4; down_b_off = cfg.down_b_off_4;
}
id expert_pipe = g_use_2bit ? ctx->matvec_2bit : ctx->matvec_v3;
- uint32_t gate_up_out = MOE_INTERMEDIATE;
- uint32_t gate_up_in = HIDDEN_DIM;
- uint32_t down_out = HIDDEN_DIM;
- uint32_t down_in = MOE_INTERMEDIATE;
- uint32_t gs = GROUP_SIZE;
+ uint32_t gate_up_out = cfg.moe_intermediate;
+ uint32_t gate_up_in = cfg.hidden_dim;
+ uint32_t down_out = cfg.hidden_dim;
+ uint32_t down_in = cfg.moe_intermediate;
+ uint32_t gs = cfg.group_size;
// gate_proj: data[k] -> gate[k]
{
@@ -1604,21 +1910,21 @@ static void gpu_encode_expert_forward_slot_buf(
NSUInteger up_w_off, up_s_off, up_b_off;
NSUInteger down_w_off, down_s_off, down_b_off;
if (g_use_2bit) {
- gate_w_off = GATE_W_OFF_2; gate_s_off = GATE_S_OFF_2; gate_b_off = GATE_B_OFF_2;
- up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2;
- down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2;
+ gate_w_off = cfg.gate_w_off_2; gate_s_off = cfg.gate_s_off_2; gate_b_off = cfg.gate_b_off_2;
+ up_w_off = cfg.up_w_off_2; up_s_off = cfg.up_s_off_2; up_b_off = cfg.up_b_off_2;
+ down_w_off = cfg.down_w_off_2; down_s_off = cfg.down_s_off_2; down_b_off = cfg.down_b_off_2;
} else {
- gate_w_off = 0; gate_s_off = 2097152; gate_b_off = 2228224;
- up_w_off = 2359296; up_s_off = 4456448; up_b_off = 4587520;
- down_w_off = 4718592; down_s_off = 6815744; down_b_off = 6946816;
+ gate_w_off = cfg.gate_w_off_4; gate_s_off = cfg.gate_s_off_4; gate_b_off = cfg.gate_b_off_4;
+ up_w_off = cfg.up_w_off_4; up_s_off = cfg.up_s_off_4; up_b_off = cfg.up_b_off_4;
+ down_w_off = cfg.down_w_off_4; down_s_off = cfg.down_s_off_4; down_b_off = cfg.down_b_off_4;
}
id expert_pipe = g_use_2bit ? ctx->matvec_2bit : ctx->matvec_v3;
- uint32_t gate_up_out = MOE_INTERMEDIATE;
- uint32_t gate_up_in = HIDDEN_DIM;
- uint32_t down_out = HIDDEN_DIM;
- uint32_t down_in = MOE_INTERMEDIATE;
- uint32_t gs = GROUP_SIZE;
+ uint32_t gate_up_out = cfg.moe_intermediate;
+ uint32_t gate_up_in = cfg.hidden_dim;
+ uint32_t down_out = cfg.hidden_dim;
+ uint32_t down_in = cfg.moe_intermediate;
+ uint32_t gs = cfg.group_size;
// gate_proj
{
@@ -1704,21 +2010,21 @@ static void gpu_encode_experts_batched(
NSUInteger up_w_off, up_s_off, up_b_off;
NSUInteger down_w_off, down_s_off, down_b_off;
if (g_use_2bit) {
- gate_w_off = GATE_W_OFF_2; gate_s_off = GATE_S_OFF_2; gate_b_off = GATE_B_OFF_2;
- up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2;
- down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2;
+ gate_w_off = cfg.gate_w_off_2; gate_s_off = cfg.gate_s_off_2; gate_b_off = cfg.gate_b_off_2;
+ up_w_off = cfg.up_w_off_2; up_s_off = cfg.up_s_off_2; up_b_off = cfg.up_b_off_2;
+ down_w_off = cfg.down_w_off_2; down_s_off = cfg.down_s_off_2; down_b_off = cfg.down_b_off_2;
} else {
- gate_w_off = 0; gate_s_off = 2097152; gate_b_off = 2228224;
- up_w_off = 2359296; up_s_off = 4456448; up_b_off = 4587520;
- down_w_off = 4718592; down_s_off = 6815744; down_b_off = 6946816;
+ gate_w_off = cfg.gate_w_off_4; gate_s_off = cfg.gate_s_off_4; gate_b_off = cfg.gate_b_off_4;
+ up_w_off = cfg.up_w_off_4; up_s_off = cfg.up_s_off_4; up_b_off = cfg.up_b_off_4;
+ down_w_off = cfg.down_w_off_4; down_s_off = cfg.down_s_off_4; down_b_off = cfg.down_b_off_4;
}
id expert_pipe = g_use_2bit ? ctx->matvec_2bit : ctx->matvec_v3;
- uint32_t gate_up_out = MOE_INTERMEDIATE;
- uint32_t gate_up_in = HIDDEN_DIM;
- uint32_t down_out = HIDDEN_DIM;
- uint32_t down_in = MOE_INTERMEDIATE;
- uint32_t gs = GROUP_SIZE;
+ uint32_t gate_up_out = cfg.moe_intermediate;
+ uint32_t gate_up_in = cfg.hidden_dim;
+ uint32_t down_out = cfg.hidden_dim;
+ uint32_t down_in = cfg.moe_intermediate;
+ uint32_t gs = cfg.group_size;
// 2-bit: packed_cols = in_dim/16, threadgroups = out_dim/8
// 4-bit: packed_cols = in_dim/8, threadgroups = out_dim/8
// Threadgroup count is the same (based on out_dim), kernel handles packed_cols internally.
@@ -1793,21 +2099,21 @@ static void gpu_encode_expert_forward(
MetalCtx *ctx,
id cmdbuf
) {
- NSUInteger gate_w_off = 0;
- NSUInteger gate_s_off = 2097152;
- NSUInteger gate_b_off = 2228224;
- NSUInteger up_w_off = 2359296;
- NSUInteger up_s_off = 4456448;
- NSUInteger up_b_off = 4587520;
- NSUInteger down_w_off = 4718592;
- NSUInteger down_s_off = 6815744;
- NSUInteger down_b_off = 6946816;
-
- uint32_t gate_up_out = MOE_INTERMEDIATE;
- uint32_t gate_up_in = HIDDEN_DIM;
- uint32_t down_out = HIDDEN_DIM;
- uint32_t down_in = MOE_INTERMEDIATE;
- uint32_t gs = GROUP_SIZE;
+ NSUInteger gate_w_off = cfg.gate_w_off_4;
+ NSUInteger gate_s_off = cfg.gate_s_off_4;
+ NSUInteger gate_b_off = cfg.gate_b_off_4;
+ NSUInteger up_w_off = cfg.up_w_off_4;
+ NSUInteger up_s_off = cfg.up_s_off_4;
+ NSUInteger up_b_off = cfg.up_b_off_4;
+ NSUInteger down_w_off = cfg.down_w_off_4;
+ NSUInteger down_s_off = cfg.down_s_off_4;
+ NSUInteger down_b_off = cfg.down_b_off_4;
+
+ uint32_t gate_up_out = cfg.moe_intermediate;
+ uint32_t gate_up_in = cfg.hidden_dim;
+ uint32_t down_out = cfg.hidden_dim;
+ uint32_t down_in = cfg.moe_intermediate;
+ uint32_t gs = cfg.group_size;
// gate_proj
{
@@ -1903,9 +2209,9 @@ static void fast_batch_matvec(
__attribute__((unused))
static void gpu_expert_forward(
MetalCtx *ctx,
- const void *expert_data, // EXPERT_SIZE bytes (may be buf_expert_data contents)
- const float *h_post, // [HIDDEN_DIM] input
- float *expert_out, // [HIDDEN_DIM] output
+ const void *expert_data, // active_expert_size() bytes (may be buf_expert_data contents)
+ const float *h_post, // [cfg.hidden_dim] input
+ float *expert_out, // [cfg.hidden_dim] output
int expert_data_already_in_buffer
) {
// Expert layout offsets — select based on quantization mode
@@ -1913,13 +2219,13 @@ static void gpu_expert_forward(
NSUInteger up_w_off, up_s_off, up_b_off;
NSUInteger down_w_off, down_s_off, down_b_off;
if (g_use_2bit) {
- gate_w_off = GATE_W_OFF_2; gate_s_off = GATE_S_OFF_2; gate_b_off = GATE_B_OFF_2;
- up_w_off = UP_W_OFF_2; up_s_off = UP_S_OFF_2; up_b_off = UP_B_OFF_2;
- down_w_off = DOWN_W_OFF_2; down_s_off = DOWN_S_OFF_2; down_b_off = DOWN_B_OFF_2;
+ gate_w_off = cfg.gate_w_off_2; gate_s_off = cfg.gate_s_off_2; gate_b_off = cfg.gate_b_off_2;
+ up_w_off = cfg.up_w_off_2; up_s_off = cfg.up_s_off_2; up_b_off = cfg.up_b_off_2;
+ down_w_off = cfg.down_w_off_2; down_s_off = cfg.down_s_off_2; down_b_off = cfg.down_b_off_2;
} else {
- gate_w_off = 0; gate_s_off = 2097152; gate_b_off = 2228224;
- up_w_off = 2359296; up_s_off = 4456448; up_b_off = 4587520;
- down_w_off = 4718592; down_s_off = 6815744; down_b_off = 6946816;
+ gate_w_off = cfg.gate_w_off_4; gate_s_off = cfg.gate_s_off_4; gate_b_off = cfg.gate_b_off_4;
+ up_w_off = cfg.up_w_off_4; up_s_off = cfg.up_s_off_4; up_b_off = cfg.up_b_off_4;
+ down_w_off = cfg.down_w_off_4; down_s_off = cfg.down_s_off_4; down_b_off = cfg.down_b_off_4;
}
id expert_pipe = g_use_2bit ? ctx->matvec_2bit : ctx->matvec_v3;
@@ -1927,13 +2233,13 @@ static void gpu_expert_forward(
if (!expert_data_already_in_buffer) {
memcpy([ctx->buf_expert_data contents], expert_data, active_expert_size());
}
- memcpy([ctx->buf_expert_input contents], h_post, HIDDEN_DIM * sizeof(float));
+ memcpy([ctx->buf_expert_input contents], h_post, cfg.hidden_dim * sizeof(float));
- uint32_t gate_up_out = MOE_INTERMEDIATE; // 1024
- uint32_t gate_up_in = HIDDEN_DIM; // 4096
- uint32_t down_out = HIDDEN_DIM; // 4096
- uint32_t down_in = MOE_INTERMEDIATE; // 1024
- uint32_t gs = GROUP_SIZE; // 64
+ uint32_t gate_up_out = cfg.moe_intermediate; // e.g. 1024 for 397B
+ uint32_t gate_up_in = cfg.hidden_dim; // e.g. 4096
+ uint32_t down_out = cfg.hidden_dim; // e.g. 4096
+ uint32_t down_in = cfg.moe_intermediate; // e.g. 1024
+ uint32_t gs = cfg.group_size; // e.g. 64
// Build one command buffer with all 4 dispatches:
// 1. gate_proj matvec (h_post -> gate_out)
@@ -2015,7 +2321,7 @@ static void gpu_expert_forward(
[cmdbuf waitUntilCompleted];
// Copy result back to CPU
- memcpy(expert_out, [ctx->buf_expert_out contents], HIDDEN_DIM * sizeof(float));
+ memcpy(expert_out, [ctx->buf_expert_out contents], cfg.hidden_dim * sizeof(float));
}
// ============================================================================
@@ -2031,7 +2337,7 @@ static void apply_rotary_emb(float *q, float *k, int pos, int num_heads, int num
for (int h = 0; h < num_heads; h++) {
float *qh = q + h * head_dim;
for (int i = 0; i < half; i++) {
- float freq = 1.0f / powf(ROPE_THETA, (float)(2 * i) / rotary_dim);
+ float freq = 1.0f / powf(cfg.rope_theta, (float)(2 * i) / rotary_dim);
float angle = (float)pos * freq;
float cos_a = cosf(angle);
float sin_a = sinf(angle);
@@ -2045,7 +2351,7 @@ static void apply_rotary_emb(float *q, float *k, int pos, int num_heads, int num
for (int h = 0; h < num_kv_heads; h++) {
float *kh = k + h * head_dim;
for (int i = 0; i < half; i++) {
- float freq = 1.0f / powf(ROPE_THETA, (float)(2 * i) / rotary_dim);
+ float freq = 1.0f / powf(cfg.rope_theta, (float)(2 * i) / rotary_dim);
float angle = (float)pos * freq;
float cos_a = cosf(angle);
float sin_a = sinf(angle);
@@ -2070,8 +2376,8 @@ static void apply_rotary_emb(float *q, float *k, int pos, int num_heads, int num
static KVCache *kv_cache_new(void) {
KVCache *c = calloc(1, sizeof(KVCache));
- c->k_cache = calloc(MAX_SEQ_LEN * NUM_KV_HEADS * HEAD_DIM, sizeof(float));
- c->v_cache = calloc(MAX_SEQ_LEN * NUM_KV_HEADS * HEAD_DIM, sizeof(float));
+ c->k_cache = calloc(cfg.max_seq_len * cfg.num_kv_heads * cfg.head_dim, sizeof(float));
+ c->v_cache = calloc(cfg.max_seq_len * cfg.num_kv_heads * cfg.head_dim, sizeof(float));
c->len = 0;
return c;
}
@@ -2095,8 +2401,8 @@ static void kv_cache_free(KVCache *c) {
static LinearAttnState *linear_attn_state_new(void) {
LinearAttnState *s = calloc(1, sizeof(LinearAttnState));
- s->conv_state = calloc((CONV_KERNEL_SIZE - 1) * LINEAR_CONV_DIM, sizeof(float));
- s->ssm_state = calloc(LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM * LINEAR_KEY_DIM, sizeof(float));
+ s->conv_state = calloc((cfg.conv_kernel_size - 1) * cfg.linear_conv_dim, sizeof(float));
+ s->ssm_state = calloc(cfg.linear_num_v_heads * cfg.linear_value_dim * cfg.linear_key_dim, sizeof(float));
return s;
}
@@ -2124,7 +2430,7 @@ static float vec_rms(const float *v, int n) {
static void full_attention_forward(
WeightFile *wf,
int layer_idx,
- float *hidden, // [HIDDEN_DIM] in/out
+ float *hidden, // [cfg.hidden_dim] in/out
KVCache *kv,
int pos // position in sequence
) {
@@ -2132,32 +2438,32 @@ static void full_attention_forward(
int do_debug = 0; // set to (fa_debug_count <= N) to enable debug
char name[256];
- float *normed = malloc(HIDDEN_DIM * sizeof(float));
- float *residual = malloc(HIDDEN_DIM * sizeof(float));
- cpu_vec_copy(residual, hidden, HIDDEN_DIM);
+ float *normed = malloc(cfg.hidden_dim * sizeof(float));
+ float *residual = malloc(cfg.hidden_dim * sizeof(float));
+ cpu_vec_copy(residual, hidden, cfg.hidden_dim);
if (do_debug) {
fprintf(stderr, "[FA-DBG] layer=%d pos=%d hidden_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
- layer_idx, pos, vec_rms(hidden, HIDDEN_DIM),
+ layer_idx, pos, vec_rms(hidden, cfg.hidden_dim),
hidden[0], hidden[1], hidden[2], hidden[3], hidden[4]);
}
// ---- Input LayerNorm ----
snprintf(name, sizeof(name), "model.layers.%d.input_layernorm.weight", layer_idx);
uint16_t *norm_w = get_tensor_ptr(wf, name);
- cpu_rms_norm(hidden, norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS);
+ cpu_rms_norm(hidden, norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps);
if (do_debug) {
fprintf(stderr, "[FA-DBG] normed_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
- vec_rms(normed, HIDDEN_DIM), normed[0], normed[1], normed[2], normed[3], normed[4]);
+ vec_rms(normed, cfg.hidden_dim), normed[0], normed[1], normed[2], normed[3], normed[4]);
}
// ---- QKV Projection ----
// CRITICAL: Q projection outputs num_heads * head_dim * 2 = 16384
// The second half is a sigmoid gate applied after attention
- int q_proj_dim = NUM_ATTN_HEADS * HEAD_DIM * 2; // 32 * 256 * 2 = 16384
- int q_dim = NUM_ATTN_HEADS * HEAD_DIM; // 32 * 256 = 8192
- int kv_dim = NUM_KV_HEADS * HEAD_DIM; // 2 * 256 = 512
+ int q_proj_dim = cfg.num_attn_heads * cfg.head_dim * 2; // e.g. 32 * 256 * 2 = 16384
+ int q_dim = cfg.num_attn_heads * cfg.head_dim; // e.g. 32 * 256 = 8192
+ int kv_dim = cfg.num_kv_heads * cfg.head_dim; // e.g. 2 * 256 = 512
float *q_proj_out = calloc(q_proj_dim, sizeof(float));
float *k = calloc(kv_dim, sizeof(float));
@@ -2188,11 +2494,11 @@ static void full_attention_forward(
// Batch Q/K/V into one command buffer (3 dispatches, 1 commit)
if (qw && qs && qb && kw && ks && kb && vw && vs && vb) {
BatchMatvecSpec qkv_specs[3] = {
- { qw, qs, qb, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0 },
- { kw, ks, kb, k, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1 },
- { vw, vs, vb, v, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2 },
+ { qw, qs, qb, q_proj_out, (uint32_t)q_proj_dim, cfg.hidden_dim, cfg.group_size, 0 },
+ { kw, ks, kb, k, (uint32_t)kv_dim, cfg.hidden_dim, cfg.group_size, 1 },
+ { vw, vs, vb, v, (uint32_t)kv_dim, cfg.hidden_dim, cfg.group_size, 2 },
};
- fast_batch_matvec(normed, HIDDEN_DIM, qkv_specs, 3);
+ fast_batch_matvec(normed, cfg.hidden_dim, qkv_specs, 3);
}
if (do_debug) {
@@ -2203,10 +2509,10 @@ static void full_attention_forward(
// Split q_proj_out into queries and gate
float *q = calloc(q_dim, sizeof(float));
float *q_gate = calloc(q_dim, sizeof(float));
- for (int h = 0; h < NUM_ATTN_HEADS; h++) {
- float *src = q_proj_out + h * (2 * HEAD_DIM);
- memcpy(q + h * HEAD_DIM, src, HEAD_DIM * sizeof(float));
- memcpy(q_gate + h * HEAD_DIM, src + HEAD_DIM, HEAD_DIM * sizeof(float));
+ for (int h = 0; h < cfg.num_attn_heads; h++) {
+ float *src = q_proj_out + h * (2 * cfg.head_dim);
+ memcpy(q + h * cfg.head_dim, src, cfg.head_dim * sizeof(float));
+ memcpy(q_gate + h * cfg.head_dim, src + cfg.head_dim, cfg.head_dim * sizeof(float));
}
free(q_proj_out);
@@ -2230,24 +2536,24 @@ static void full_attention_forward(
// Apply per-head Q norm
if (qnorm_w) {
- for (int h = 0; h < NUM_ATTN_HEADS; h++) {
- float *qh = q + h * HEAD_DIM;
+ for (int h = 0; h < cfg.num_attn_heads; h++) {
+ float *qh = q + h * cfg.head_dim;
float sum_sq = 0.0f;
- for (int i = 0; i < HEAD_DIM; i++) sum_sq += qh[i] * qh[i];
- float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + RMS_NORM_EPS);
- for (int i = 0; i < HEAD_DIM; i++) {
+ for (int i = 0; i < cfg.head_dim; i++) sum_sq += qh[i] * qh[i];
+ float inv_rms = 1.0f / sqrtf(sum_sq / cfg.head_dim + cfg.rms_norm_eps);
+ for (int i = 0; i < cfg.head_dim; i++) {
qh[i] = qh[i] * inv_rms * bf16_to_f32(qnorm_w[i]);
}
}
}
// Apply per-head K norm
if (knorm_w) {
- for (int h = 0; h < NUM_KV_HEADS; h++) {
- float *kh = k + h * HEAD_DIM;
+ for (int h = 0; h < cfg.num_kv_heads; h++) {
+ float *kh = k + h * cfg.head_dim;
float sum_sq = 0.0f;
- for (int i = 0; i < HEAD_DIM; i++) sum_sq += kh[i] * kh[i];
- float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + RMS_NORM_EPS);
- for (int i = 0; i < HEAD_DIM; i++) {
+ for (int i = 0; i < cfg.head_dim; i++) sum_sq += kh[i] * kh[i];
+ float inv_rms = 1.0f / sqrtf(sum_sq / cfg.head_dim + cfg.rms_norm_eps);
+ for (int i = 0; i < cfg.head_dim; i++) {
kh[i] = kh[i] * inv_rms * bf16_to_f32(knorm_w[i]);
}
}
@@ -2255,7 +2561,7 @@ static void full_attention_forward(
// ---- RoPE ----
- apply_rotary_emb(q, k, pos, NUM_ATTN_HEADS, NUM_KV_HEADS, HEAD_DIM, ROTARY_DIM);
+ apply_rotary_emb(q, k, pos, cfg.num_attn_heads, cfg.num_kv_heads, cfg.head_dim, cfg.rotary_dim);
// ---- Update KV cache ----
int cache_pos = kv->len;
@@ -2264,23 +2570,23 @@ static void full_attention_forward(
kv->len++;
// ---- Scaled dot-product attention ----
- // GQA: NUM_ATTN_HEADS=32 heads, NUM_KV_HEADS=2 kv heads
+ // GQA: cfg.num_attn_heads query heads share cfg.num_kv_heads kv heads (e.g. 32 and 2)
// Each group of 16 query heads shares 1 kv head
- int heads_per_kv = NUM_ATTN_HEADS / NUM_KV_HEADS;
- float scale = 1.0f / sqrtf((float)HEAD_DIM);
+ int heads_per_kv = cfg.num_attn_heads / cfg.num_kv_heads;
+ float scale = 1.0f / sqrtf((float)cfg.head_dim);
float *attn_out = calloc(q_dim, sizeof(float));
- for (int h = 0; h < NUM_ATTN_HEADS; h++) {
+ for (int h = 0; h < cfg.num_attn_heads; h++) {
int kv_h = h / heads_per_kv;
- float *qh = q + h * HEAD_DIM;
+ float *qh = q + h * cfg.head_dim;
// Compute attention scores for all cached positions
float *scores = malloc(kv->len * sizeof(float));
for (int p = 0; p < kv->len; p++) {
- float *kp = kv->k_cache + p * kv_dim + kv_h * HEAD_DIM;
+ float *kp = kv->k_cache + p * kv_dim + kv_h * cfg.head_dim;
float dot = 0.0f;
- for (int d = 0; d < HEAD_DIM; d++) {
+ for (int d = 0; d < cfg.head_dim; d++) {
dot += qh[d] * kp[d];
}
scores[p] = dot * scale;
@@ -2290,10 +2596,10 @@ static void full_attention_forward(
cpu_softmax(scores, kv->len);
// Weighted sum of values
- float *oh = attn_out + h * HEAD_DIM;
+ float *oh = attn_out + h * cfg.head_dim;
for (int p = 0; p < kv->len; p++) {
- float *vp = kv->v_cache + p * kv_dim + kv_h * HEAD_DIM;
- for (int d = 0; d < HEAD_DIM; d++) {
+ float *vp = kv->v_cache + p * kv_dim + kv_h * cfg.head_dim;
+ for (int d = 0; d < cfg.head_dim; d++) {
oh[d] += scores[p] * vp[d];
}
}
@@ -2310,14 +2616,14 @@ static void full_attention_forward(
}
// ---- Output projection ----
- float *attn_projected = calloc(HIDDEN_DIM, sizeof(float));
+ float *attn_projected = calloc(cfg.hidden_dim, sizeof(float));
snprintf(name, sizeof(name), "model.layers.%d.self_attn.o_proj.weight", layer_idx);
uint32_t *ow = get_tensor_ptr(wf, name);
snprintf(name, sizeof(name), "model.layers.%d.self_attn.o_proj.scales", layer_idx);
uint16_t *os_ptr = get_tensor_ptr(wf, name);
snprintf(name, sizeof(name), "model.layers.%d.self_attn.o_proj.biases", layer_idx);
uint16_t *ob = get_tensor_ptr(wf, name);
- if (ow && os_ptr && ob) fast_dequant_matvec(ow, os_ptr, ob, attn_out, attn_projected, HIDDEN_DIM, q_dim, GROUP_SIZE);
+ if (ow && os_ptr && ob) fast_dequant_matvec(ow, os_ptr, ob, attn_out, attn_projected, cfg.hidden_dim, q_dim, cfg.group_size);
if (do_debug) {
fprintf(stderr, "[FA-DBG] attn_out_rms=%.6f o_proj first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
@@ -2326,13 +2632,13 @@ static void full_attention_forward(
}
// ---- Residual connection ----
- for (int i = 0; i < HIDDEN_DIM; i++) {
+ for (int i = 0; i < cfg.hidden_dim; i++) {
hidden[i] = residual[i] + attn_projected[i];
}
if (do_debug) {
fprintf(stderr, "[FA-DBG] AFTER layer=%d hidden_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
- layer_idx, vec_rms(hidden, HIDDEN_DIM),
+ layer_idx, vec_rms(hidden, cfg.hidden_dim),
hidden[0], hidden[1], hidden[2], hidden[3], hidden[4]);
}
@@ -2378,7 +2684,7 @@ static void cpu_rms_norm_gated(const float *x, const float *z, const uint16_t *w
static void linear_attention_forward(
WeightFile *wf,
int layer_idx,
- float *hidden, // [HIDDEN_DIM] in/out
+ float *hidden, // [cfg.hidden_dim] in/out
LinearAttnState *state
) {
// If bypass is enabled, just pass through (identity)
@@ -2393,27 +2699,27 @@ static void linear_attention_forward(
if (la_debug) {
fprintf(stderr, "[LA-DBG] layer=%d hidden_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
- layer_idx, vec_rms(hidden, HIDDEN_DIM),
+ layer_idx, vec_rms(hidden, cfg.hidden_dim),
hidden[0], hidden[1], hidden[2], hidden[3], hidden[4]);
}
char name[256];
- float *normed = malloc(HIDDEN_DIM * sizeof(float));
- float *residual = malloc(HIDDEN_DIM * sizeof(float));
- cpu_vec_copy(residual, hidden, HIDDEN_DIM);
+ float *normed = malloc(cfg.hidden_dim * sizeof(float));
+ float *residual = malloc(cfg.hidden_dim * sizeof(float));
+ cpu_vec_copy(residual, hidden, cfg.hidden_dim);
// ---- Input LayerNorm ----
snprintf(name, sizeof(name), "model.layers.%d.input_layernorm.weight", layer_idx);
uint16_t *norm_w = get_tensor_ptr(wf, name);
- cpu_rms_norm(hidden, norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS);
+ cpu_rms_norm(hidden, norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps);
// ---- Batch QKV + Z + B + A projections (4 matmuls, 1 command buffer) ----
- int qkv_dim = LINEAR_CONV_DIM; // 12288
+ int qkv_dim = cfg.linear_conv_dim; // e.g. 12288
float *qkv = calloc(qkv_dim, sizeof(float));
- int z_dim = LINEAR_TOTAL_VALUE; // 8192
+ int z_dim = cfg.linear_total_value; // e.g. 8192
float *z = calloc(z_dim, sizeof(float));
- float *beta = calloc(LINEAR_NUM_V_HEADS, sizeof(float));
- float *alpha = calloc(LINEAR_NUM_V_HEADS, sizeof(float));
+ float *beta = calloc(cfg.linear_num_v_heads, sizeof(float));
+ float *alpha = calloc(cfg.linear_num_v_heads, sizeof(float));
snprintf(name, sizeof(name), "model.layers.%d.linear_attn.in_proj_qkv.weight", layer_idx);
uint32_t *qkv_w = get_tensor_ptr(wf, name);
@@ -2446,12 +2752,12 @@ static void linear_attention_forward(
if (qkv_w && qkv_s && qkv_b && z_w && z_s && z_b &&
b_w && b_s && b_b && a_w && a_s && a_b) {
BatchMatvecSpec la_specs[4] = {
- { qkv_w, qkv_s, qkv_b, qkv, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0 },
- { z_w, z_s, z_b, z, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1 },
- { b_w, b_s, b_b, beta, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2 },
- { a_w, a_s, a_b, alpha, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3 },
+ { qkv_w, qkv_s, qkv_b, qkv, (uint32_t)qkv_dim, cfg.hidden_dim, cfg.group_size, 0 },
+ { z_w, z_s, z_b, z, (uint32_t)z_dim, cfg.hidden_dim, cfg.group_size, 1 },
+ { b_w, b_s, b_b, beta, (uint32_t)cfg.linear_num_v_heads, cfg.hidden_dim, cfg.group_size, 2 },
+ { a_w, a_s, a_b, alpha, (uint32_t)cfg.linear_num_v_heads, cfg.hidden_dim, cfg.group_size, 3 },
};
- fast_batch_matvec(normed, HIDDEN_DIM, la_specs, 4);
+ fast_batch_matvec(normed, cfg.hidden_dim, la_specs, 4);
}
// ---- Conv1d step ----
@@ -2462,22 +2768,22 @@ static void linear_attention_forward(
float *conv_out = calloc(qkv_dim, sizeof(float));
if (conv_w) {
cpu_conv1d_step(state->conv_state, qkv, conv_w, conv_out,
- qkv_dim, CONV_KERNEL_SIZE);
+ qkv_dim, cfg.conv_kernel_size);
}
// Update conv state: shift left, append new input
memmove(state->conv_state, state->conv_state + qkv_dim,
- (CONV_KERNEL_SIZE - 2) * qkv_dim * sizeof(float));
- memcpy(state->conv_state + (CONV_KERNEL_SIZE - 2) * qkv_dim, qkv,
+ (cfg.conv_kernel_size - 2) * qkv_dim * sizeof(float));
+ memcpy(state->conv_state + (cfg.conv_kernel_size - 2) * qkv_dim, qkv,
qkv_dim * sizeof(float));
// ---- Split conv_out into q, k, v ----
// q: [num_k_heads * head_k_dim] = [2048]
// k: [num_k_heads * head_k_dim] = [2048]
// v: [num_v_heads * head_v_dim] = [8192]
- float *lin_q = conv_out; // first LINEAR_TOTAL_KEY elements
- float *lin_k = conv_out + LINEAR_TOTAL_KEY; // next LINEAR_TOTAL_KEY
- float *lin_v = conv_out + 2 * LINEAR_TOTAL_KEY; // rest = LINEAR_TOTAL_VALUE
+ float *lin_q = conv_out; // first cfg.linear_total_key elements
+ float *lin_k = conv_out + cfg.linear_total_key; // next cfg.linear_total_key
+ float *lin_v = conv_out + 2 * cfg.linear_total_key; // rest = cfg.linear_total_value
// ---- RMS normalize q and k (bare, no weights) ----
// q: scale = key_dim^(-0.5), normalize per head then scale by key_dim^(-1.0)
@@ -2485,18 +2791,18 @@ static void linear_attention_forward(
// inv_scale = k.shape[-1] ** -0.5 = head_k_dim^(-0.5) = 128^(-0.5)
// q = (inv_scale**2) * rms_norm(q) = (1/128) * rms_norm(q)
// k = inv_scale * rms_norm(k) = (1/sqrt(128)) * rms_norm(k)
- float inv_scale = 1.0f / sqrtf((float)LINEAR_KEY_DIM);
+ float inv_scale = 1.0f / sqrtf((float)cfg.linear_key_dim);
- for (int h = 0; h < LINEAR_NUM_K_HEADS; h++) {
- float *qh = lin_q + h * LINEAR_KEY_DIM;
- cpu_rms_norm_bare(qh, qh, LINEAR_KEY_DIM, 1e-6f);
+ for (int h = 0; h < cfg.linear_num_k_heads; h++) {
+ float *qh = lin_q + h * cfg.linear_key_dim;
+ cpu_rms_norm_bare(qh, qh, cfg.linear_key_dim, 1e-6f);
float q_scale = inv_scale * inv_scale; // inv_scale^2 = 1/head_k_dim
- for (int d = 0; d < LINEAR_KEY_DIM; d++) qh[d] *= q_scale;
+ for (int d = 0; d < cfg.linear_key_dim; d++) qh[d] *= q_scale;
}
- for (int h = 0; h < LINEAR_NUM_K_HEADS; h++) {
- float *kh = lin_k + h * LINEAR_KEY_DIM;
- cpu_rms_norm_bare(kh, kh, LINEAR_KEY_DIM, 1e-6f);
- for (int d = 0; d < LINEAR_KEY_DIM; d++) kh[d] *= inv_scale;
+ for (int h = 0; h < cfg.linear_num_k_heads; h++) {
+ float *kh = lin_k + h * cfg.linear_key_dim;
+ cpu_rms_norm_bare(kh, kh, cfg.linear_key_dim, 1e-6f);
+ for (int d = 0; d < cfg.linear_key_dim; d++) kh[d] *= inv_scale;
}
// ---- Gated delta net recurrence ----
@@ -2516,14 +2822,14 @@ static void linear_attention_forward(
snprintf(name, sizeof(name), "model.layers.%d.linear_attn.dt_bias", layer_idx);
uint16_t *dt_bias_bf16 = get_tensor_ptr(wf, name);
- float *out_values = calloc(LINEAR_TOTAL_VALUE, sizeof(float)); // [num_v_heads * head_v_dim]
+ float *out_values = calloc(cfg.linear_total_value, sizeof(float)); // [num_v_heads * head_v_dim]
- int k_heads_per_v = LINEAR_NUM_V_HEADS / LINEAR_NUM_K_HEADS; // 64/16 = 4
+ int k_heads_per_v = cfg.linear_num_v_heads / cfg.linear_num_k_heads; // e.g. 64/16 = 4
// Precompute per-head decay (g) and beta
- float g_decay[LINEAR_NUM_V_HEADS];
- float beta_gate[LINEAR_NUM_V_HEADS];
- for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) {
+ float g_decay[cfg.linear_num_v_heads];
+ float beta_gate[cfg.linear_num_v_heads];
+ for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) {
// g = exp(-exp(A_log) * softplus(a + dt_bias))
float a_val = alpha[vh];
float dt_b = dt_bias_bf16 ? bf16_to_f32(dt_bias_bf16[vh]) : 0.0f;
@@ -2535,45 +2841,45 @@ static void linear_attention_forward(
beta_gate[vh] = cpu_sigmoid(beta[vh]);
}
- for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) {
+ for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) {
int kh = vh / k_heads_per_v; // which k head this v head maps to
float g = g_decay[vh];
float b_gate = beta_gate[vh];
// state is [head_v_dim, head_k_dim]
- float *S = state->ssm_state + vh * LINEAR_VALUE_DIM * LINEAR_KEY_DIM;
- float *v_h = lin_v + vh * LINEAR_VALUE_DIM;
- float *k_h = lin_k + kh * LINEAR_KEY_DIM;
+ float *S = state->ssm_state + vh * cfg.linear_value_dim * cfg.linear_key_dim;
+ float *v_h = lin_v + vh * cfg.linear_value_dim;
+ float *k_h = lin_k + kh * cfg.linear_key_dim;
// Step 1: Decay state
- for (int vi = 0; vi < LINEAR_VALUE_DIM; vi++) {
- for (int ki = 0; ki < LINEAR_KEY_DIM; ki++) {
- S[vi * LINEAR_KEY_DIM + ki] *= g;
+ for (int vi = 0; vi < cfg.linear_value_dim; vi++) {
+ for (int ki = 0; ki < cfg.linear_key_dim; ki++) {
+ S[vi * cfg.linear_key_dim + ki] *= g;
}
}
// Step 2: Compute kv_mem[vi] = sum_ki(S[vi,ki] * k[ki])
// Then delta[vi] = (v[vi] - kv_mem[vi]) * beta
// Then state[vi,ki] += k[ki] * delta[vi]
- for (int vi = 0; vi < LINEAR_VALUE_DIM; vi++) {
+ for (int vi = 0; vi < cfg.linear_value_dim; vi++) {
float kv_mem = 0.0f;
- for (int ki = 0; ki < LINEAR_KEY_DIM; ki++) {
- kv_mem += S[vi * LINEAR_KEY_DIM + ki] * k_h[ki];
+ for (int ki = 0; ki < cfg.linear_key_dim; ki++) {
+ kv_mem += S[vi * cfg.linear_key_dim + ki] * k_h[ki];
}
float delta = (v_h[vi] - kv_mem) * b_gate;
- for (int ki = 0; ki < LINEAR_KEY_DIM; ki++) {
- S[vi * LINEAR_KEY_DIM + ki] += k_h[ki] * delta;
+ for (int ki = 0; ki < cfg.linear_key_dim; ki++) {
+ S[vi * cfg.linear_key_dim + ki] += k_h[ki] * delta;
}
}
// Step 3: Output: y[vi] = sum_ki(S[vi,ki] * q[ki])
- float *q_h = lin_q + kh * LINEAR_KEY_DIM;
- float *o_h = out_values + vh * LINEAR_VALUE_DIM;
- for (int vi = 0; vi < LINEAR_VALUE_DIM; vi++) {
+ float *q_h = lin_q + kh * cfg.linear_key_dim;
+ float *o_h = out_values + vh * cfg.linear_value_dim;
+ for (int vi = 0; vi < cfg.linear_value_dim; vi++) {
float sum = 0.0f;
- for (int ki = 0; ki < LINEAR_KEY_DIM; ki++) {
- sum += S[vi * LINEAR_KEY_DIM + ki] * q_h[ki];
+ for (int ki = 0; ki < cfg.linear_key_dim; ki++) {
+ sum += S[vi * cfg.linear_key_dim + ki] * q_h[ki];
}
o_h[vi] = sum;
}
@@ -2583,20 +2889,20 @@ static void linear_attention_forward(
snprintf(name, sizeof(name), "model.layers.%d.linear_attn.norm.weight", layer_idx);
uint16_t *gated_norm_w = get_tensor_ptr(wf, name);
- float *gated_out = calloc(LINEAR_TOTAL_VALUE, sizeof(float));
- for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) {
- float *oh = out_values + vh * LINEAR_VALUE_DIM;
- float *zh = z + vh * LINEAR_VALUE_DIM;
- float *gh = gated_out + vh * LINEAR_VALUE_DIM;
+ float *gated_out = calloc(cfg.linear_total_value, sizeof(float));
+ for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) {
+ float *oh = out_values + vh * cfg.linear_value_dim;
+ float *zh = z + vh * cfg.linear_value_dim;
+ float *gh = gated_out + vh * cfg.linear_value_dim;
if (gated_norm_w) {
- cpu_rms_norm_gated(oh, zh, gated_norm_w, gh, LINEAR_VALUE_DIM, RMS_NORM_EPS);
+ cpu_rms_norm_gated(oh, zh, gated_norm_w, gh, cfg.linear_value_dim, cfg.rms_norm_eps);
} else {
- memcpy(gh, oh, LINEAR_VALUE_DIM * sizeof(float));
+ memcpy(gh, oh, cfg.linear_value_dim * sizeof(float));
}
}
// ---- Output projection: [value_dim=8192] -> [hidden_dim=4096] ----
- float *attn_out = calloc(HIDDEN_DIM, sizeof(float));
+ float *attn_out = calloc(cfg.hidden_dim, sizeof(float));
snprintf(name, sizeof(name), "model.layers.%d.linear_attn.out_proj.weight", layer_idx);
uint32_t *out_w = get_tensor_ptr(wf, name);
snprintf(name, sizeof(name), "model.layers.%d.linear_attn.out_proj.scales", layer_idx);
@@ -2604,20 +2910,20 @@ static void linear_attention_forward(
snprintf(name, sizeof(name), "model.layers.%d.linear_attn.out_proj.biases", layer_idx);
uint16_t *out_b = get_tensor_ptr(wf, name);
if (out_w && out_s && out_b) {
- fast_dequant_matvec(out_w, out_s, out_b, gated_out, attn_out, HIDDEN_DIM,
- LINEAR_TOTAL_VALUE, GROUP_SIZE);
+ fast_dequant_matvec(out_w, out_s, out_b, gated_out, attn_out, cfg.hidden_dim,
+ cfg.linear_total_value, cfg.group_size);
}
// ---- Residual ----
- for (int i = 0; i < HIDDEN_DIM; i++) {
+ for (int i = 0; i < cfg.hidden_dim; i++) {
hidden[i] = residual[i] + attn_out[i];
}
if (la_debug) {
fprintf(stderr, "[LA-DBG] AFTER layer=%d out_proj_rms=%.6f gated_rms=%.6f hidden_rms=%.6f\n",
- layer_idx, vec_rms(attn_out, HIDDEN_DIM),
- vec_rms(gated_out, LINEAR_TOTAL_VALUE),
- vec_rms(hidden, HIDDEN_DIM));
+ layer_idx, vec_rms(attn_out, cfg.hidden_dim),
+ vec_rms(gated_out, cfg.linear_total_value),
+ vec_rms(hidden, cfg.hidden_dim));
}
free(normed);
@@ -2642,7 +2948,7 @@ static void linear_attention_forward(
static void moe_forward(
WeightFile *wf,
int layer_idx,
- float *hidden, // [HIDDEN_DIM] in/out
+ float *hidden, // [cfg.hidden_dim] in/out
const char *model_path __attribute__((unused)),
int K, // number of active experts (e.g. 4)
int packed_fd // fd for this layer's packed expert file (-1 if not available)
@@ -2652,19 +2958,19 @@ static void moe_forward(
int moe_dump = 0;
char name[256];
- float *h_post = malloc(HIDDEN_DIM * sizeof(float));
- float *h_mid = malloc(HIDDEN_DIM * sizeof(float));
- cpu_vec_copy(h_mid, hidden, HIDDEN_DIM);
+ float *h_post = malloc(cfg.hidden_dim * sizeof(float));
+ float *h_mid = malloc(cfg.hidden_dim * sizeof(float));
+ cpu_vec_copy(h_mid, hidden, cfg.hidden_dim);
// ---- Post-attention LayerNorm ----
snprintf(name, sizeof(name), "model.layers.%d.post_attention_layernorm.weight", layer_idx);
uint16_t *norm_w = get_tensor_ptr(wf, name);
- cpu_rms_norm(hidden, norm_w, h_post, HIDDEN_DIM, RMS_NORM_EPS);
+ cpu_rms_norm(hidden, norm_w, h_post, cfg.hidden_dim, cfg.rms_norm_eps);
// ---- Batch routing gate + shared expert gate/up + shared_expert_gate (4 matmuls, 1 commit) ----
- float *gate_scores = calloc(NUM_EXPERTS, sizeof(float));
- float *shared_gate = calloc(SHARED_INTERMEDIATE, sizeof(float));
- float *shared_up = calloc(SHARED_INTERMEDIATE, sizeof(float));
+ float *gate_scores = calloc(cfg.num_experts, sizeof(float));
+ float *shared_gate = calloc(cfg.shared_intermediate, sizeof(float));
+ float *shared_up = calloc(cfg.shared_intermediate, sizeof(float));
float shared_gate_score = 0.0f;
snprintf(name, sizeof(name), "model.layers.%d.mlp.gate.weight", layer_idx);
@@ -2699,21 +3005,21 @@ static void moe_forward(
if (gate_w && gate_s && gate_b && sgw && sgs && sgb &&
suw && sus && sub && seg_w && seg_s && seg_b) {
BatchMatvecSpec moe_specs[4] = {
- { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 },
- { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 },
- { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 },
- { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 },
+ { gate_w, gate_s, gate_b, gate_scores, (uint32_t)cfg.num_experts, cfg.hidden_dim, cfg.group_size, 0 },
+ { sgw, sgs, sgb, shared_gate, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 1 },
+ { suw, sus, sub, shared_up, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 2 },
+ { seg_w, seg_s, seg_b, &shared_gate_score, 1, cfg.hidden_dim, cfg.group_size, 3 },
};
- fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4);
+ fast_batch_matvec(h_post, cfg.hidden_dim, moe_specs, 4);
}
// Softmax routing scores
- cpu_softmax(gate_scores, NUM_EXPERTS);
+ cpu_softmax(gate_scores, cfg.num_experts);
// Top-K expert selection
int expert_indices[64];
float expert_weights[64];
- cpu_topk(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights);
+ cpu_topk(gate_scores, cfg.num_experts, K, expert_indices, expert_weights);
cpu_normalize_weights(expert_weights, K);
if (moe_dump) {
@@ -2723,10 +3029,10 @@ static void moe_forward(
}
// ---- Routed expert computation ----
- float *moe_out = calloc(HIDDEN_DIM, sizeof(float));
+ float *moe_out = calloc(cfg.hidden_dim, sizeof(float));
if (packed_fd >= 0) {
- float *expert_out = malloc(HIDDEN_DIM * sizeof(float));
+ float *expert_out = malloc(cfg.hidden_dim * sizeof(float));
size_t esz = active_expert_size();
for (int k = 0; k < K; k++) {
@@ -2756,26 +3062,26 @@ static void moe_forward(
}
uint32_t *gw = (uint32_t *)expert_data;
- uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_S_OFF_2 : 2097152));
- uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_B_OFF_2 : 2228224));
- uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? UP_W_OFF_2 : 2359296));
- uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_S_OFF_2 : 4456448));
- uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_B_OFF_2 : 4587520));
- uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? DOWN_W_OFF_2 : 4718592));
- uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_S_OFF_2 : 6815744));
- uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_B_OFF_2 : 6946816));
-
- float *gate_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float));
- float *up_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float));
- float *act_out = malloc(MOE_INTERMEDIATE * sizeof(float));
+ uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.gate_s_off_2 : cfg.gate_s_off_4));
+ uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.gate_b_off_2 : cfg.gate_b_off_4));
+ uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? cfg.up_w_off_2 : cfg.up_w_off_4));
+ uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.up_s_off_2 : cfg.up_s_off_4));
+ uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.up_b_off_2 : cfg.up_b_off_4));
+ uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? cfg.down_w_off_2 : cfg.down_w_off_4));
+ uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.down_s_off_2 : cfg.down_s_off_4));
+ uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.down_b_off_2 : cfg.down_b_off_4));
+
+ float *gate_proj_out = malloc(cfg.moe_intermediate * sizeof(float));
+ float *up_proj_out = malloc(cfg.moe_intermediate * sizeof(float));
+ float *act_out = malloc(cfg.moe_intermediate * sizeof(float));
cpu_dequant_matvec(gw, gs_p, gb_p, h_post, gate_proj_out,
- MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE);
+ cfg.moe_intermediate, cfg.hidden_dim, cfg.group_size);
cpu_dequant_matvec(uw, us_p, ub_p, h_post, up_proj_out,
- MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE);
- cpu_swiglu(gate_proj_out, up_proj_out, act_out, MOE_INTERMEDIATE);
+ cfg.moe_intermediate, cfg.hidden_dim, cfg.group_size);
+ cpu_swiglu(gate_proj_out, up_proj_out, act_out, cfg.moe_intermediate);
cpu_dequant_matvec(dw, ds_p, db_p, act_out, expert_out,
- HIDDEN_DIM, MOE_INTERMEDIATE, GROUP_SIZE);
+ cfg.hidden_dim, cfg.moe_intermediate, cfg.group_size);
free(gate_proj_out);
free(up_proj_out);
@@ -2786,31 +3092,31 @@ static void moe_forward(
// Accumulate weighted
if (moe_dump) {
fprintf(stderr, "[MOE-DUMP] expert[%d] out_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
- eidx, vec_rms(expert_out, HIDDEN_DIM),
+ eidx, vec_rms(expert_out, cfg.hidden_dim),
expert_out[0], expert_out[1], expert_out[2], expert_out[3], expert_out[4]);
}
- cpu_vec_madd(moe_out, expert_out, expert_weights[k], HIDDEN_DIM);
+ cpu_vec_madd(moe_out, expert_out, expert_weights[k], cfg.hidden_dim);
}
free(expert_out);
}
// ---- Shared expert SwiGLU (gate_proj + up_proj already computed above) ----
- float *shared_out = calloc(HIDDEN_DIM, sizeof(float));
- float *shared_act = calloc(SHARED_INTERMEDIATE, sizeof(float));
- cpu_swiglu(shared_gate, shared_up, shared_act, SHARED_INTERMEDIATE);
+ float *shared_out = calloc(cfg.hidden_dim, sizeof(float));
+ float *shared_act = calloc(cfg.shared_intermediate, sizeof(float));
+ cpu_swiglu(shared_gate, shared_up, shared_act, cfg.shared_intermediate);
if (moe_dump) {
fprintf(stderr, "[MOE-DUMP] layer=%d h_post_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
- layer_idx, vec_rms(h_post, HIDDEN_DIM), h_post[0], h_post[1], h_post[2], h_post[3], h_post[4]);
+ layer_idx, vec_rms(h_post, cfg.hidden_dim), h_post[0], h_post[1], h_post[2], h_post[3], h_post[4]);
fprintf(stderr, "[MOE-DUMP] gate_proj_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
- vec_rms(shared_gate, SHARED_INTERMEDIATE),
+ vec_rms(shared_gate, cfg.shared_intermediate),
shared_gate[0], shared_gate[1], shared_gate[2], shared_gate[3], shared_gate[4]);
fprintf(stderr, "[MOE-DUMP] up_proj_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
- vec_rms(shared_up, SHARED_INTERMEDIATE),
+ vec_rms(shared_up, cfg.shared_intermediate),
shared_up[0], shared_up[1], shared_up[2], shared_up[3], shared_up[4]);
fprintf(stderr, "[MOE-DUMP] swiglu_rms=%.6f first5=[%.6f,%.6f,%.6f,%.6f,%.6f]\n",
- vec_rms(shared_act, SHARED_INTERMEDIATE),
+ vec_rms(shared_act, cfg.shared_intermediate),
shared_act[0], shared_act[1], shared_act[2], shared_act[3], shared_act[4]);
}
@@ -2822,28 +3128,28 @@ static void moe_forward(
snprintf(name, sizeof(name), "model.layers.%d.mlp.shared_expert.down_proj.biases", layer_idx);
uint16_t *sdb = get_tensor_ptr(wf, name);
if (sdw && sds && sdb) {
- fast_dequant_matvec(sdw, sds, sdb, shared_act, shared_out, HIDDEN_DIM,
- SHARED_INTERMEDIATE, GROUP_SIZE);
+ fast_dequant_matvec(sdw, sds, sdb, shared_act, shared_out, cfg.hidden_dim,
+ cfg.shared_intermediate, cfg.group_size);
}
// ---- Shared expert gate (sigmoid) -- already computed above ----
float shared_weight = cpu_sigmoid(shared_gate_score);
// Scale shared expert output
- for (int i = 0; i < HIDDEN_DIM; i++) {
+ for (int i = 0; i < cfg.hidden_dim; i++) {
shared_out[i] *= shared_weight;
}
// ---- Combine: hidden = h_mid + moe_out + shared_out ----
- for (int i = 0; i < HIDDEN_DIM; i++) {
+ for (int i = 0; i < cfg.hidden_dim; i++) {
hidden[i] = h_mid[i] + moe_out[i] + shared_out[i];
}
if (moe_debug) {
fprintf(stderr, "[MOE-DBG] layer=%d h_mid_rms=%.4f moe_rms=%.4f shared_rms=%.4f shared_gate=%.4f hidden_rms=%.4f\n",
- layer_idx, vec_rms(h_mid, HIDDEN_DIM), vec_rms(moe_out, HIDDEN_DIM),
- vec_rms(shared_out, HIDDEN_DIM), shared_weight,
- vec_rms(hidden, HIDDEN_DIM));
+ layer_idx, vec_rms(h_mid, cfg.hidden_dim), vec_rms(moe_out, cfg.hidden_dim),
+ vec_rms(shared_out, cfg.hidden_dim), shared_weight,
+ vec_rms(hidden, cfg.hidden_dim));
}
free(h_post);
@@ -2872,7 +3178,7 @@ static void embed_lookup(WeightFile *wf, int token_id, float *out) {
if (!w_info || !s_info || !b_info) {
fprintf(stderr, "ERROR: embedding tensors not found\n");
- memset(out, 0, HIDDEN_DIM * sizeof(float));
+ memset(out, 0, cfg.hidden_dim * sizeof(float));
return;
}
@@ -2888,7 +3194,7 @@ static void embed_lookup(WeightFile *wf, int token_id, float *out) {
const uint16_t *s_row = S + (size_t)token_id * num_groups;
const uint16_t *b_row = B + (size_t)token_id * num_groups;
- int group_size = HIDDEN_DIM / num_groups; // 4096/64 = 64
+    int group_size = cfg.hidden_dim / num_groups;  // e.g. 4096/64 = 64 for the 397B config
int packed_per_group = group_size / 8; // 8
for (int g = 0; g < num_groups; g++) {
@@ -2930,14 +3236,14 @@ static void lm_head_forward(WeightFile *wf, const float *hidden, float *logits)
uint16_t *B = (uint16_t *)((char *)wf->data + b_info->offset);
// Full matmul — use GPU if available (248320 output rows!)
- fast_dequant_matvec(W, S, B, hidden, logits, VOCAB_SIZE, HIDDEN_DIM, GROUP_SIZE);
+ fast_dequant_matvec(W, S, B, hidden, logits, cfg.vocab_size, cfg.hidden_dim, cfg.group_size);
}
// ============================================================================
// Parallel I/O infrastructure for expert pread (from proven main.m pattern)
// ============================================================================
-#define NUM_IO_THREADS 4 // 4 threads for K=4 experts (one per expert)
+#define NUM_IO_THREADS 8 // 8 threads for K=8 experts (one per expert)
typedef struct {
int fd;
@@ -3206,7 +3512,7 @@ static int parallel_pread_experts_into(
typedef struct {
int layer_idx;
int expert_idx;
-    id<MTLBuffer> buffer;    // Metal buffer holding EXPERT_SIZE bytes
+    id<MTLBuffer> buffer;    // Metal buffer holding cfg.expert_size_4bit bytes
uint64_t last_used; // monotonic counter for LRU ordering
} ExpertCacheEntry;
@@ -3215,7 +3521,7 @@ static int parallel_pread_experts_into(
int max_entries;
int num_entries;
int used_entries;
- int entry_idx[NUM_LAYERS][NUM_EXPERTS];
+ int *entry_idx; // flattened [num_layers * num_experts], -1 = not cached
uint64_t access_counter; // monotonic, incremented on every access
    id<MTLDevice> device;    // for allocating new Metal buffers
// Stats
@@ -3238,14 +3544,13 @@ static int parallel_pread_experts_into(
// - Loads into scratch buffers (no cache pollution)
// - Uses CMD1_wait idle time (no additional CPU cost)
// - Only sync-preads misses (not all K experts)
-static int g_pred_experts[60][MAX_K]; // previous token's expert indices per layer
-static int g_pred_count[60]; // how many experts stored per layer
static int g_pred_valid = 0; // 1 after first token completes (predictions available)
// g_pred_enabled, g_pred_hits, g_pred_misses, g_pred_layers declared near timing (line ~163)
static ExpertLRUCache *expert_cache_new(id<MTLDevice> device, int max_entries) {
ExpertLRUCache *cache = calloc(1, sizeof(ExpertLRUCache));
cache->entries = calloc(max_entries, sizeof(ExpertCacheEntry));
+ cache->entry_idx = malloc(cfg.num_layers * cfg.num_experts * sizeof(int));
cache->max_entries = max_entries;
cache->num_entries = 0;
cache->used_entries = 0;
@@ -3253,9 +3558,9 @@ static int parallel_pread_experts_into(
cache->device = device;
cache->hits = 0;
cache->misses = 0;
- for (int l = 0; l < NUM_LAYERS; l++) {
- for (int e = 0; e < NUM_EXPERTS; e++) {
- cache->entry_idx[l][e] = -1;
+ for (int l = 0; l < cfg.num_layers; l++) {
+ for (int e = 0; e < cfg.num_experts; e++) {
+ cache->entry_idx[(l) * cfg.num_experts + (e)] = -1;
}
}
// Pre-allocate ALL Metal buffers at startup (avoids allocation overhead at runtime)
@@ -3294,7 +3599,7 @@ static void expert_cache_free(ExpertLRUCache *cache) {
// Lookup: returns the cached Metal buffer if found, otherwise NULL.
// On hit, updates the LRU timestamp.
static id<MTLBuffer> expert_cache_lookup(ExpertLRUCache *cache, int layer_idx, int expert_idx) {
- int idx = cache->entry_idx[layer_idx][expert_idx];
+ int idx = cache->entry_idx[(layer_idx) * cfg.num_experts + (expert_idx)];
if (idx >= 0) {
cache->entries[idx].last_used = ++cache->access_counter;
cache->hits++;
@@ -3311,7 +3616,7 @@ static void expert_cache_free(ExpertLRUCache *cache) {
static id<MTLBuffer> expert_cache_insert(ExpertLRUCache *cache, int layer_idx, int expert_idx) {
    id<MTLBuffer> buf = nil;
- int existing = cache->entry_idx[layer_idx][expert_idx];
+ int existing = cache->entry_idx[(layer_idx) * cfg.num_experts + (expert_idx)];
if (existing >= 0) {
cache->entries[existing].last_used = ++cache->access_counter;
return cache->entries[existing].buffer;
@@ -3328,7 +3633,7 @@ static void expert_cache_free(ExpertLRUCache *cache) {
cache->entries[target].layer_idx = layer_idx;
cache->entries[target].expert_idx = expert_idx;
cache->entries[target].last_used = ++cache->access_counter;
- cache->entry_idx[layer_idx][expert_idx] = target;
+ cache->entry_idx[(layer_idx) * cfg.num_experts + (expert_idx)] = target;
return buf;
}
@@ -3347,13 +3652,13 @@ static void expert_cache_free(ExpertLRUCache *cache) {
int old_expert = cache->entries[lru_idx].expert_idx;
cache_telemetry_evict(old_layer, old_expert);
if (old_layer >= 0 && old_expert >= 0) {
- cache->entry_idx[old_layer][old_expert] = -1;
+ cache->entry_idx[(old_layer) * cfg.num_experts + (old_expert)] = -1;
}
buf = cache->entries[lru_idx].buffer;
cache->entries[lru_idx].layer_idx = layer_idx;
cache->entries[lru_idx].expert_idx = expert_idx;
cache->entries[lru_idx].last_used = ++cache->access_counter;
- cache->entry_idx[layer_idx][expert_idx] = lru_idx;
+ cache->entry_idx[(layer_idx) * cfg.num_experts + (expert_idx)] = lru_idx;
return buf;
}
@@ -3365,7 +3670,7 @@ static void expert_cache_free(ExpertLRUCache *cache) {
// ============================================================================
typedef struct {
- void **data; // [max_entries] page-aligned malloc'd EXPERT_SIZE buffers
+ void **data; // [max_entries] page-aligned malloc'd cfg.expert_size_4bit buffers
    id<MTLBuffer> __strong *metal_bufs;  // [max_entries] zero-copy Metal buffer wrappers
int *layer_idx; // [max_entries] layer index for each entry
int *expert_idx; // [max_entries] expert index for each entry
@@ -3373,7 +3678,7 @@ static void expert_cache_free(ExpertLRUCache *cache) {
int max_entries;
int num_entries;
int used_entries;
- int entry_idx[NUM_LAYERS][NUM_EXPERTS];
+ int *entry_idx; // flattened [num_layers * num_experts], -1 = not cached
uint64_t access_counter;
uint64_t hits;
uint64_t misses;
@@ -3388,15 +3693,16 @@ static void expert_cache_free(ExpertLRUCache *cache) {
cache->layer_idx = calloc(max_entries, sizeof(int));
cache->expert_idx = calloc(max_entries, sizeof(int));
cache->last_used = calloc(max_entries, sizeof(uint64_t));
+ cache->entry_idx = malloc(cfg.num_layers * cfg.num_experts * sizeof(int));
cache->max_entries = max_entries;
cache->num_entries = 0;
cache->used_entries = 0;
cache->access_counter = 0;
cache->hits = 0;
cache->misses = 0;
- for (int l = 0; l < NUM_LAYERS; l++) {
- for (int e = 0; e < NUM_EXPERTS; e++) {
- cache->entry_idx[l][e] = -1;
+ for (int l = 0; l < cfg.num_layers; l++) {
+ for (int e = 0; e < cfg.num_experts; e++) {
+ cache->entry_idx[(l) * cfg.num_experts + (e)] = -1;
}
}
@@ -3440,7 +3746,7 @@ static void expert_cache_free(ExpertLRUCache *cache) {
// Lookup: returns Metal buffer wrapping cached data, or nil. Zero-copy dispatch.
static id<MTLBuffer> malloc_cache_lookup(MallocExpertCache *cache, int layer, int expert) {
- int idx = cache->entry_idx[layer][expert];
+ int idx = cache->entry_idx[(layer) * cfg.num_experts + (expert)];
if (idx >= 0) {
cache->last_used[idx] = ++cache->access_counter;
cache->hits++;
@@ -3455,7 +3761,7 @@ static void expert_cache_free(ExpertLRUCache *cache) {
// Insert: evict LRU if needed, return entry index for pread target.
// Returns the Metal buffer for this entry (caller should pread into cache->data[idx]).
static id<MTLBuffer> malloc_cache_insert(MallocExpertCache *cache, int layer, int expert, int *out_idx) {
- int existing = cache->entry_idx[layer][expert];
+ int existing = cache->entry_idx[(layer) * cfg.num_experts + (expert)];
if (existing >= 0) {
cache->last_used[existing] = ++cache->access_counter;
if (out_idx) *out_idx = existing;
@@ -3480,14 +3786,14 @@ static void expert_cache_free(ExpertLRUCache *cache) {
}
cache_telemetry_evict(cache->layer_idx[target], cache->expert_idx[target]);
if (cache->layer_idx[target] >= 0 && cache->expert_idx[target] >= 0) {
- cache->entry_idx[cache->layer_idx[target]][cache->expert_idx[target]] = -1;
+ cache->entry_idx[(cache->layer_idx[target]) * cfg.num_experts + (cache->expert_idx[target])] = -1;
}
}
cache->layer_idx[target] = layer;
cache->expert_idx[target] = expert;
cache->last_used[target] = ++cache->access_counter;
- cache->entry_idx[layer][expert] = target;
+ cache->entry_idx[(layer) * cfg.num_experts + (expert)] = target;
if (out_idx) *out_idx = target;
return cache->metal_bufs[target];
}
@@ -3643,7 +3949,7 @@ static void infer_prefetch_shutdown(void) {
// ============================================================================
// Per-layer weight pointer cache — built once, eliminates 40+ snprintf+lookup
-// per layer per token. With 60 layers and 15 tokens = 36,000 lookups saved.
+// per layer per token. With 40 layers and 15 tokens = 24,000 lookups saved.
// ============================================================================
typedef struct {
@@ -3677,16 +3983,33 @@ static void infer_prefetch_shutdown(void) {
uint32_t *seg_w; uint16_t *seg_s, *seg_b; // shared_expert_gate
} LayerWeightCache;
-static LayerWeightCache layer_cache[NUM_LAYERS];
+static LayerWeightCache *layer_cache = NULL;
static int layer_cache_built = 0;
+// Allocate all dynamic tracking arrays (must be called after load_model_config)
+static void alloc_tracking_arrays(void) {
+ int nl = cfg.num_layers;
+ int ne = cfg.num_experts;
+ int seen_bytes_per_layer = (ne + 7) / 8;
+
+ g_expert_freq = calloc(nl * ne, sizeof(int));
+ g_expert_seen = calloc(nl * seen_bytes_per_layer, sizeof(uint8_t));
+ g_lz4_index = calloc(nl, sizeof(void *));
+ g_cache_seen = calloc(nl * ne, sizeof(uint8_t));
+ g_cache_last_touch_token = calloc(nl * ne, sizeof(uint64_t));
+ g_cache_last_evict_token = calloc(nl * ne, sizeof(uint64_t));
+ g_pred_experts = calloc(nl * MAX_K, sizeof(int));
+ g_pred_count = calloc(nl, sizeof(int));
+ layer_cache = calloc(nl, sizeof(LayerWeightCache));
+}
+
static void build_layer_cache(WeightFile *wf) {
if (layer_cache_built) return;
char name[256];
- for (int i = 0; i < NUM_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_layers; i++) {
LayerWeightCache *lc = &layer_cache[i];
- int is_full = ((i + 1) % FULL_ATTN_INTERVAL == 0);
+ int is_full = cfg.is_full_attn[i];
// Norms
snprintf(name, sizeof(name), "model.layers.%d.input_layernorm.weight", i);
@@ -3800,7 +4123,7 @@ static void build_layer_cache(WeightFile *wf) {
}
layer_cache_built = 1;
- printf("[cache] Pre-computed weight pointers for %d layers\n", NUM_LAYERS);
+ printf("[cache] Pre-computed weight pointers for %d layers\n", cfg.num_layers);
}
// ============================================================================
@@ -3819,13 +4142,13 @@ static void build_layer_cache(WeightFile *wf) {
float expert_weights[MAX_K]; // routing weights for weighted accumulation
int valid[MAX_K]; // which experts loaded successfully
int actual_K; // number of experts
- float h_mid[HIDDEN_DIM]; // saved h_mid for final combine
+ float *h_mid; // [hidden_dim] saved h_mid for final combine
float shared_gate_score; // saved shared expert gate score
float *hidden; // pointer to hidden state (for writing final result)
int layer_idx; // which layer produced this deferred state
} DeferredExpertState;
-static DeferredExpertState g_deferred = { .active = 0 };
+static DeferredExpertState g_deferred = { .active = 0, .h_mid = NULL };
// Wait for the deferred GPU expert command buffer to complete.
// Split from finalize so timing can be measured independently.
@@ -3846,30 +4169,30 @@ static void finalize_deferred_experts(void) {
// buf_input already has the normalized input for the next layer's CMD1.
// Just read back hidden (needed for the residual connection in future layers).
memcpy(g_deferred.hidden, [g_metal->buf_moe_hidden contents],
- HIDDEN_DIM * sizeof(float));
+ cfg.hidden_dim * sizeof(float));
} else {
// CPU-side combine (original path)
// Read back and accumulate routed expert outputs
-        float moe_out[HIDDEN_DIM];
+        float moe_out[cfg.hidden_dim];
+ float moe_out[cfg.hidden_dim];
memset(moe_out, 0, sizeof(moe_out));
for (int k = 0; k < g_deferred.actual_K; k++) {
if (!g_deferred.valid[k]) continue;
float *expert_result = (float *)[g_metal->buf_multi_expert_out[k] contents];
- cpu_vec_madd(moe_out, expert_result, g_deferred.expert_weights[k], HIDDEN_DIM);
+ cpu_vec_madd(moe_out, expert_result, g_deferred.expert_weights[k], cfg.hidden_dim);
}
// Read shared expert result
- float shared_out[HIDDEN_DIM];
- memcpy(shared_out, [g_metal->buf_shared_out contents], HIDDEN_DIM * sizeof(float));
+ float shared_out[cfg.hidden_dim];
+ memcpy(shared_out, [g_metal->buf_shared_out contents], cfg.hidden_dim * sizeof(float));
// Apply shared expert gate
float shared_weight = cpu_sigmoid(g_deferred.shared_gate_score);
- for (int i = 0; i < HIDDEN_DIM; i++) {
+ for (int i = 0; i < cfg.hidden_dim; i++) {
shared_out[i] *= shared_weight;
}
// Final combine: hidden = h_mid + moe_out + shared_out
- for (int i = 0; i < HIDDEN_DIM; i++) {
+ for (int i = 0; i < cfg.hidden_dim; i++) {
g_deferred.hidden[i] = g_deferred.h_mid[i] + moe_out[i] + shared_out[i];
}
}
@@ -3936,69 +4259,69 @@ static void discard_deferred_experts(void) {
// 4. GPU-side combine in CMD3 (eliminates CPU deferred_wait + combine + norm)
// ============================================================================
-// Static scratch buffers — allocated once, reused across all 60 layers per token.
+// Static scratch buffers — allocated once, reused across all 40 layers per token.
-// Eliminates ~20 malloc/free per layer = ~1200 alloc/free per token.
+// Eliminates ~20 malloc/free per layer = ~800 alloc/free per token.
-static float *s_normed = NULL; // [HIDDEN_DIM]
-static float *s_residual = NULL; // [HIDDEN_DIM]
-static float *s_attn_proj = NULL; // [HIDDEN_DIM]
-static float *s_h_post = NULL; // [HIDDEN_DIM]
-static float *s_h_mid = NULL; // [HIDDEN_DIM]
-static float *s_gate_scores = NULL; // [NUM_EXPERTS]
-static float *s_spec_gate_scores = NULL; // [NUM_EXPERTS] speculative routing scratch
+static float *s_normed = NULL; // [cfg.hidden_dim]
+static float *s_residual = NULL; // [cfg.hidden_dim]
+static float *s_attn_proj = NULL; // [cfg.hidden_dim]
+static float *s_h_post = NULL; // [cfg.hidden_dim]
+static float *s_h_mid = NULL; // [cfg.hidden_dim]
+static float *s_gate_scores = NULL; // [cfg.num_experts]
+static float *s_spec_gate_scores = NULL; // [cfg.num_experts] speculative routing scratch
static int s_spec_indices[MAX_K]; // speculative routing predicted expert indices
static int s_spec_count = 0; // number of speculative predictions this layer
-static float *s_shared_gate = NULL; // [SHARED_INTERMEDIATE]
-static float *s_shared_up = NULL; // [SHARED_INTERMEDIATE]
-static float *s_moe_out = NULL; // [HIDDEN_DIM]
-static float *s_shared_out = NULL; // [HIDDEN_DIM]
+static float *s_shared_gate = NULL; // [cfg.shared_intermediate]
+static float *s_shared_up = NULL; // [cfg.shared_intermediate]
+static float *s_moe_out = NULL; // [cfg.hidden_dim]
+static float *s_shared_out = NULL; // [cfg.hidden_dim]
// Full attention scratch
-static float *s_q_proj_out = NULL; // [NUM_ATTN_HEADS * HEAD_DIM * 2]
-static float *s_k_proj_out = NULL; // [NUM_KV_HEADS * HEAD_DIM]
-static float *s_v_proj_out = NULL; // [NUM_KV_HEADS * HEAD_DIM]
-static float *s_q = NULL; // [NUM_ATTN_HEADS * HEAD_DIM]
-static float *s_q_gate = NULL; // [NUM_ATTN_HEADS * HEAD_DIM]
-static float *s_attn_out = NULL; // [NUM_ATTN_HEADS * HEAD_DIM]
+static float *s_q_proj_out = NULL; // [cfg.num_attn_heads * cfg.head_dim * 2]
+static float *s_k_proj_out = NULL; // [cfg.num_kv_heads * cfg.head_dim]
+static float *s_v_proj_out = NULL; // [cfg.num_kv_heads * cfg.head_dim]
+static float *s_q = NULL; // [cfg.num_attn_heads * cfg.head_dim]
+static float *s_q_gate = NULL; // [cfg.num_attn_heads * cfg.head_dim]
+static float *s_attn_out = NULL; // [cfg.num_attn_heads * cfg.head_dim]
// Linear attention scratch
-static float *s_qkv_proj_out = NULL; // [LINEAR_CONV_DIM]
-static float *s_z_proj_out = NULL; // [LINEAR_TOTAL_VALUE]
-static float *s_beta_proj_out = NULL; // [LINEAR_NUM_V_HEADS]
-static float *s_alpha_proj_out = NULL; // [LINEAR_NUM_V_HEADS]
-static float *s_conv_out = NULL; // [LINEAR_CONV_DIM]
-static float *s_out_vals = NULL; // [LINEAR_TOTAL_VALUE]
-static float *s_gated_out = NULL; // [LINEAR_TOTAL_VALUE]
+static float *s_qkv_proj_out = NULL; // [cfg.linear_conv_dim]
+static float *s_z_proj_out = NULL; // [cfg.linear_total_value]
+static float *s_beta_proj_out = NULL; // [cfg.linear_num_v_heads]
+static float *s_alpha_proj_out = NULL; // [cfg.linear_num_v_heads]
+static float *s_conv_out = NULL; // [cfg.linear_conv_dim]
+static float *s_out_vals = NULL; // [cfg.linear_total_value]
+static float *s_gated_out = NULL; // [cfg.linear_total_value]
static void init_layer_scratch(void) {
if (s_normed) return; // already initialized
- s_normed = calloc(HIDDEN_DIM, sizeof(float));
- s_residual = calloc(HIDDEN_DIM, sizeof(float));
- s_attn_proj = calloc(HIDDEN_DIM, sizeof(float));
- s_h_post = calloc(HIDDEN_DIM, sizeof(float));
- s_h_mid = calloc(HIDDEN_DIM, sizeof(float));
- s_gate_scores = calloc(NUM_EXPERTS, sizeof(float));
- s_spec_gate_scores = calloc(NUM_EXPERTS, sizeof(float));
- s_shared_gate = calloc(SHARED_INTERMEDIATE, sizeof(float));
- s_shared_up = calloc(SHARED_INTERMEDIATE, sizeof(float));
- s_moe_out = calloc(HIDDEN_DIM, sizeof(float));
- s_shared_out = calloc(HIDDEN_DIM, sizeof(float));
- s_q_proj_out = calloc(NUM_ATTN_HEADS * HEAD_DIM * 2, sizeof(float));
- s_k_proj_out = calloc(NUM_KV_HEADS * HEAD_DIM, sizeof(float));
- s_v_proj_out = calloc(NUM_KV_HEADS * HEAD_DIM, sizeof(float));
- s_q = calloc(NUM_ATTN_HEADS * HEAD_DIM, sizeof(float));
- s_q_gate = calloc(NUM_ATTN_HEADS * HEAD_DIM, sizeof(float));
- s_attn_out = calloc(NUM_ATTN_HEADS * HEAD_DIM, sizeof(float));
- s_qkv_proj_out = calloc(LINEAR_CONV_DIM, sizeof(float));
- s_z_proj_out = calloc(LINEAR_TOTAL_VALUE, sizeof(float));
- s_beta_proj_out = calloc(LINEAR_NUM_V_HEADS, sizeof(float));
- s_alpha_proj_out = calloc(LINEAR_NUM_V_HEADS, sizeof(float));
- s_conv_out = calloc(LINEAR_CONV_DIM, sizeof(float));
- s_out_vals = calloc(LINEAR_TOTAL_VALUE, sizeof(float));
- s_gated_out = calloc(LINEAR_TOTAL_VALUE, sizeof(float));
+ s_normed = calloc(cfg.hidden_dim, sizeof(float));
+ s_residual = calloc(cfg.hidden_dim, sizeof(float));
+ s_attn_proj = calloc(cfg.hidden_dim, sizeof(float));
+ s_h_post = calloc(cfg.hidden_dim, sizeof(float));
+ s_h_mid = calloc(cfg.hidden_dim, sizeof(float));
+ s_gate_scores = calloc(cfg.num_experts, sizeof(float));
+ s_spec_gate_scores = calloc(cfg.num_experts, sizeof(float));
+ s_shared_gate = calloc(cfg.shared_intermediate, sizeof(float));
+ s_shared_up = calloc(cfg.shared_intermediate, sizeof(float));
+ s_moe_out = calloc(cfg.hidden_dim, sizeof(float));
+ s_shared_out = calloc(cfg.hidden_dim, sizeof(float));
+ s_q_proj_out = calloc(cfg.num_attn_heads * cfg.head_dim * 2, sizeof(float));
+ s_k_proj_out = calloc(cfg.num_kv_heads * cfg.head_dim, sizeof(float));
+ s_v_proj_out = calloc(cfg.num_kv_heads * cfg.head_dim, sizeof(float));
+ s_q = calloc(cfg.num_attn_heads * cfg.head_dim, sizeof(float));
+ s_q_gate = calloc(cfg.num_attn_heads * cfg.head_dim, sizeof(float));
+ s_attn_out = calloc(cfg.num_attn_heads * cfg.head_dim, sizeof(float));
+ s_qkv_proj_out = calloc(cfg.linear_conv_dim, sizeof(float));
+ s_z_proj_out = calloc(cfg.linear_total_value, sizeof(float));
+ s_beta_proj_out = calloc(cfg.linear_num_v_heads, sizeof(float));
+ s_alpha_proj_out = calloc(cfg.linear_num_v_heads, sizeof(float));
+ s_conv_out = calloc(cfg.linear_conv_dim, sizeof(float));
+ s_out_vals = calloc(cfg.linear_total_value, sizeof(float));
+ s_gated_out = calloc(cfg.linear_total_value, sizeof(float));
}
static void fused_layer_forward(
WeightFile *wf,
int layer_idx,
- float *hidden, // [HIDDEN_DIM] in/out
+ float *hidden, // [cfg.hidden_dim] in/out
KVCache *kv, // non-NULL for full attention layers
LinearAttnState *la_state, // non-NULL for linear attention layers
int pos, // position for RoPE
@@ -4026,8 +4349,8 @@ static void fused_layer_forward(
float *qkv_out = NULL, *z_out = NULL, *beta_out = NULL, *alpha_out = NULL;
if (is_full) {
- int q_proj_dim = NUM_ATTN_HEADS * HEAD_DIM * 2;
- int kv_dim = NUM_KV_HEADS * HEAD_DIM;
+ int q_proj_dim = cfg.num_attn_heads * cfg.head_dim * 2;
+ int kv_dim = cfg.num_kv_heads * cfg.head_dim;
q_proj_out = s_q_proj_out;
k_out = s_k_proj_out;
@@ -4035,14 +4358,14 @@ static void fused_layer_forward(
if (lc->q_w && lc->q_s && lc->q_b && lc->k_w && lc->k_s && lc->k_b &&
lc->v_w && lc->v_s && lc->v_b) {
- attn_specs[0] = (BatchMatvecSpec){ lc->q_w, lc->q_s, lc->q_b, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0 };
- attn_specs[1] = (BatchMatvecSpec){ lc->k_w, lc->k_s, lc->k_b, k_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1 };
- attn_specs[2] = (BatchMatvecSpec){ lc->v_w, lc->v_s, lc->v_b, v_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2 };
+ attn_specs[0] = (BatchMatvecSpec){ lc->q_w, lc->q_s, lc->q_b, q_proj_out, (uint32_t)q_proj_dim, cfg.hidden_dim, cfg.group_size, 0 };
+ attn_specs[1] = (BatchMatvecSpec){ lc->k_w, lc->k_s, lc->k_b, k_out, (uint32_t)kv_dim, cfg.hidden_dim, cfg.group_size, 1 };
+ attn_specs[2] = (BatchMatvecSpec){ lc->v_w, lc->v_s, lc->v_b, v_out, (uint32_t)kv_dim, cfg.hidden_dim, cfg.group_size, 2 };
num_attn_specs = 3;
}
} else {
- int qkv_dim = LINEAR_CONV_DIM;
- int z_dim = LINEAR_TOTAL_VALUE;
+ int qkv_dim = cfg.linear_conv_dim;
+ int z_dim = cfg.linear_total_value;
qkv_out = s_qkv_proj_out;
z_out = s_z_proj_out;
@@ -4051,10 +4374,10 @@ static void fused_layer_forward(
if (lc->qkv_w && lc->qkv_s && lc->qkv_b && lc->z_w && lc->z_s && lc->z_b &&
lc->b_w && lc->b_s && lc->b_b && lc->a_w && lc->a_s && lc->a_b) {
- attn_specs[0] = (BatchMatvecSpec){ lc->qkv_w, lc->qkv_s, lc->qkv_b, qkv_out, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0 };
- attn_specs[1] = (BatchMatvecSpec){ lc->z_w, lc->z_s, lc->z_b, z_out, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1 };
- attn_specs[2] = (BatchMatvecSpec){ lc->b_w, lc->b_s, lc->b_b, beta_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2 };
- attn_specs[3] = (BatchMatvecSpec){ lc->a_w, lc->a_s, lc->a_b, alpha_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3 };
+ attn_specs[0] = (BatchMatvecSpec){ lc->qkv_w, lc->qkv_s, lc->qkv_b, qkv_out, (uint32_t)qkv_dim, cfg.hidden_dim, cfg.group_size, 0 };
+ attn_specs[1] = (BatchMatvecSpec){ lc->z_w, lc->z_s, lc->z_b, z_out, (uint32_t)z_dim, cfg.hidden_dim, cfg.group_size, 1 };
+ attn_specs[2] = (BatchMatvecSpec){ lc->b_w, lc->b_s, lc->b_b, beta_out, (uint32_t)cfg.linear_num_v_heads, cfg.hidden_dim, cfg.group_size, 2 };
+ attn_specs[3] = (BatchMatvecSpec){ lc->a_w, lc->a_s, lc->a_b, alpha_out, (uint32_t)cfg.linear_num_v_heads, cfg.hidden_dim, cfg.group_size, 3 };
num_attn_specs = 4;
}
}
@@ -4068,7 +4391,7 @@ static void fused_layer_forward(
// Pre-compute linear_layer_idx for GPU linear attention encoding in CMD1
int linear_layer_idx = -1;
if (!is_full) {
- linear_layer_idx = layer_idx - (layer_idx + 1) / FULL_ATTN_INTERVAL;
+ linear_layer_idx = cfg.linear_index[layer_idx];
}
// Can we run the full linear attention pipeline on GPU in CMD1?
int can_gpu_linear = (gpu_linear_attn_enabled &&
@@ -4076,7 +4399,7 @@ static void fused_layer_forward(
g_metal->conv1d_step && g_metal->rms_norm_qk &&
g_metal->compute_decay_beta && g_metal->gated_rms_norm &&
g_metal->wf_buf &&
- linear_layer_idx >= 0 && linear_layer_idx < NUM_LINEAR_LAYERS &&
+ linear_layer_idx >= 0 && linear_layer_idx < cfg.num_linear_layers &&
lc->conv1d_w && lc->A_log && lc->dt_bias && lc->gated_norm_w &&
!linear_attn_bypass);
@@ -4097,7 +4420,7 @@ static void fused_layer_forward(
// GPU linear attention: encode conv1d + normalize + decay/beta + delta-net + gated_norm into CMD1
if (can_gpu_linear && num_attn_specs == 4) {
// batch_out[0]=qkv(12288), [1]=z(8192), [2]=beta(64), [3]=alpha(64)
- uint32_t conv_dim = LINEAR_CONV_DIM;
+ uint32_t conv_dim = cfg.linear_conv_dim;
NSUInteger conv_w_off = (NSUInteger)((const char *)lc->conv1d_w - (const char *)[g_metal->wf_buf contents]);
// Enc L1: conv1d_step — input=batch_out[0], weights=conv1d_w, state=buf_conv_state, output=buf_conv_output
@@ -4117,16 +4440,16 @@ static void fused_layer_forward(
// Enc L2: rms_norm_qk — normalize q and k in conv_output in-place
{
- uint32_t key_dim = LINEAR_KEY_DIM; // 128
- float inv_scale = 1.0f / sqrtf((float)LINEAR_KEY_DIM);
+ uint32_t key_dim = cfg.linear_key_dim; // linear-attn key head dim (128 on the 397B model)
+ float inv_scale = 1.0f / sqrtf((float)cfg.linear_key_dim);
id enc = [cmd1 computeCommandEncoder];
[enc setComputePipelineState:g_metal->rms_norm_qk];
[enc setBuffer:g_metal->buf_conv_output offset:0 atIndex:0]; // q at offset 0
- [enc setBuffer:g_metal->buf_conv_output offset:LINEAR_TOTAL_KEY * sizeof(float) atIndex:1]; // k at offset 2048 floats
+ [enc setBuffer:g_metal->buf_conv_output offset:cfg.linear_total_key * sizeof(float) atIndex:1]; // k follows q at offset linear_total_key floats
[enc setBytes:&key_dim length:4 atIndex:2];
[enc setBytes:&inv_scale length:4 atIndex:3];
- [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_K_HEADS, 1, 1)
- threadsPerThreadgroup:MTLSizeMake(LINEAR_KEY_DIM, 1, 1)];
+ [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_k_heads, 1, 1)
+ threadsPerThreadgroup:MTLSizeMake(cfg.linear_key_dim, 1, 1)];
[enc endEncoding];
}
@@ -4143,24 +4466,24 @@ static void fused_layer_forward(
[enc setBuffer:g_metal->buf_delta_g_decay offset:0 atIndex:4]; // g_decay output
[enc setBuffer:g_metal->buf_delta_beta offset:0 atIndex:5]; // beta_gate output
[enc dispatchThreadgroups:MTLSizeMake(1, 1, 1)
- threadsPerThreadgroup:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1)];
+ threadsPerThreadgroup:MTLSizeMake(cfg.linear_num_v_heads, 1, 1)];
[enc endEncoding];
}
// Enc L4: gated_delta_net_step — the main recurrence
{
- uint32_t khpv = LINEAR_NUM_V_HEADS / LINEAR_NUM_K_HEADS; // 4
+ uint32_t khpv = cfg.linear_num_v_heads / cfg.linear_num_k_heads; // v-heads per k-head (4 on the 397B model)
id enc = [cmd1 computeCommandEncoder];
[enc setComputePipelineState:g_metal->delta_net_step];
[enc setBuffer:g_metal->buf_delta_state[linear_layer_idx] offset:0 atIndex:0]; // persistent state
[enc setBuffer:g_metal->buf_conv_output offset:0 atIndex:1]; // q (first 2048 floats)
- [enc setBuffer:g_metal->buf_conv_output offset:LINEAR_TOTAL_KEY * sizeof(float) atIndex:2]; // k (next 2048)
- [enc setBuffer:g_metal->buf_conv_output offset:2 * LINEAR_TOTAL_KEY * sizeof(float) atIndex:3]; // v (next 8192)
+ [enc setBuffer:g_metal->buf_conv_output offset:cfg.linear_total_key * sizeof(float) atIndex:2]; // k (next linear_total_key floats)
+ [enc setBuffer:g_metal->buf_conv_output offset:2 * cfg.linear_total_key * sizeof(float) atIndex:3]; // v (remaining linear_total_value floats)
[enc setBuffer:g_metal->buf_delta_g_decay offset:0 atIndex:4];
[enc setBuffer:g_metal->buf_delta_beta offset:0 atIndex:5];
[enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:6]; // output [8192]
[enc setBytes:&khpv length:sizeof(khpv) atIndex:7];
- [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1)
+ [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1)
threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
[enc endEncoding];
}
@@ -4168,8 +4491,8 @@ static void fused_layer_forward(
// Enc L5: gated_rms_norm — normalize+gate delta-net output -> batch_out[6] for CMD2 o_proj
{
NSUInteger gnorm_w_off = (NSUInteger)((const char *)lc->gated_norm_w - (const char *)[g_metal->wf_buf contents]);
- uint32_t value_dim = LINEAR_VALUE_DIM; // 128
- float eps = RMS_NORM_EPS;
+ uint32_t value_dim = cfg.linear_value_dim; // linear-attn value head dim (128 on the 397B model)
+ float eps = cfg.rms_norm_eps;
id enc = [cmd1 computeCommandEncoder];
[enc setComputePipelineState:g_metal->gated_rms_norm];
[enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:0]; // values [8192]
@@ -4178,8 +4501,8 @@ static void fused_layer_forward(
[enc setBuffer:g_metal->batch_out[6] offset:0 atIndex:3]; // output -> batch_out[6] for CMD2
[enc setBytes:&value_dim length:4 atIndex:4];
[enc setBytes:&eps length:4 atIndex:5];
- [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1)
- threadsPerThreadgroup:MTLSizeMake(LINEAR_VALUE_DIM, 1, 1)];
+ [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1)
+ threadsPerThreadgroup:MTLSizeMake(cfg.linear_value_dim, 1, 1)];
[enc endEncoding];
}
@@ -4207,14 +4530,14 @@ static void fused_layer_forward(
// Predictions overlap with CPU attn + CMD2 + routing (~0.6ms head start).
// Predicted experts that hit page cache (same as previous token) complete in ~0.1ms.
if (g_pred_enabled && g_pred_generating && g_pred_valid && packed_fd >= 0 &&
- g_metal->buf_multi_expert_data_B[0] && g_pred_count[layer_idx] > 0) {
- async_pread_start(packed_fd, g_pred_experts[layer_idx],
- g_pred_count[layer_idx],
+ g_metal->buf_multi_expert_data_B[0] && PRED_COUNT(layer_idx) > 0) {
+ async_pread_start(packed_fd, &PRED_EXPERT(layer_idx, 0),
+ PRED_COUNT(layer_idx),
g_metal->buf_multi_expert_data_B, mmap_base);
pred_started = 1;
}
// Set up residual for CMD2 (residual = hidden before this layer's attention)
- cpu_vec_copy(residual, hidden, HIDDEN_DIM);
+ cpu_vec_copy(residual, hidden, cfg.hidden_dim);
if (g_timing_enabled) { t1 = now_ms(); g_timing.deferred_cpu += t1 - t0; }
// No input_norm needed — CMD3 already computed it into buf_input.
@@ -4233,20 +4556,20 @@ static void fused_layer_forward(
// Input norm
if (g_timing_enabled) { t0 = now_ms(); }
- cpu_vec_copy(residual, hidden, HIDDEN_DIM);
- cpu_rms_norm(hidden, lc->input_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS);
+ cpu_vec_copy(residual, hidden, cfg.hidden_dim);
+ cpu_rms_norm(hidden, lc->input_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps);
if (g_timing_enabled) { t1 = now_ms(); g_timing.input_norm += t1 - t0; }
// Submit CMD1: attention projections
if (g_timing_enabled) { t0 = now_ms(); }
if (g_metal && g_metal->wf_buf && num_attn_specs > 0) {
- memcpy([g_metal->buf_input contents], normed, HIDDEN_DIM * sizeof(float));
+ memcpy([g_metal->buf_input contents], normed, cfg.hidden_dim * sizeof(float));
cmd1 = [g_metal->queue commandBuffer];
gpu_encode_batch_matvec(g_metal, cmd1, attn_specs, num_attn_specs);
// GPU linear attention: encode conv1d + normalize + decay/beta + delta-net + gated_norm into CMD1
if (can_gpu_linear && num_attn_specs == 4) {
- uint32_t conv_dim = LINEAR_CONV_DIM;
+ uint32_t conv_dim = cfg.linear_conv_dim;
NSUInteger conv_w_off = (NSUInteger)((const char *)lc->conv1d_w - (const char *)[g_metal->wf_buf contents]);
// Enc L1: conv1d_step
@@ -4266,16 +4589,16 @@ static void fused_layer_forward(
// Enc L2: rms_norm_qk
{
- uint32_t key_dim = LINEAR_KEY_DIM;
- float inv_scale = 1.0f / sqrtf((float)LINEAR_KEY_DIM);
+ uint32_t key_dim = cfg.linear_key_dim;
+ float inv_scale = 1.0f / sqrtf((float)cfg.linear_key_dim);
id enc = [cmd1 computeCommandEncoder];
[enc setComputePipelineState:g_metal->rms_norm_qk];
[enc setBuffer:g_metal->buf_conv_output offset:0 atIndex:0];
- [enc setBuffer:g_metal->buf_conv_output offset:LINEAR_TOTAL_KEY * sizeof(float) atIndex:1];
+ [enc setBuffer:g_metal->buf_conv_output offset:cfg.linear_total_key * sizeof(float) atIndex:1];
[enc setBytes:&key_dim length:4 atIndex:2];
[enc setBytes:&inv_scale length:4 atIndex:3];
- [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_K_HEADS, 1, 1)
- threadsPerThreadgroup:MTLSizeMake(LINEAR_KEY_DIM, 1, 1)];
+ [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_k_heads, 1, 1)
+ threadsPerThreadgroup:MTLSizeMake(cfg.linear_key_dim, 1, 1)];
[enc endEncoding];
}
@@ -4292,24 +4615,24 @@ static void fused_layer_forward(
[enc setBuffer:g_metal->buf_delta_g_decay offset:0 atIndex:4];
[enc setBuffer:g_metal->buf_delta_beta offset:0 atIndex:5];
[enc dispatchThreadgroups:MTLSizeMake(1, 1, 1)
- threadsPerThreadgroup:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1)];
+ threadsPerThreadgroup:MTLSizeMake(cfg.linear_num_v_heads, 1, 1)];
[enc endEncoding];
}
// Enc L4: gated_delta_net_step
{
- uint32_t khpv = LINEAR_NUM_V_HEADS / LINEAR_NUM_K_HEADS;
+ uint32_t khpv = cfg.linear_num_v_heads / cfg.linear_num_k_heads;
id enc = [cmd1 computeCommandEncoder];
[enc setComputePipelineState:g_metal->delta_net_step];
[enc setBuffer:g_metal->buf_delta_state[linear_layer_idx] offset:0 atIndex:0];
[enc setBuffer:g_metal->buf_conv_output offset:0 atIndex:1];
- [enc setBuffer:g_metal->buf_conv_output offset:LINEAR_TOTAL_KEY * sizeof(float) atIndex:2];
- [enc setBuffer:g_metal->buf_conv_output offset:2 * LINEAR_TOTAL_KEY * sizeof(float) atIndex:3];
+ [enc setBuffer:g_metal->buf_conv_output offset:cfg.linear_total_key * sizeof(float) atIndex:2];
+ [enc setBuffer:g_metal->buf_conv_output offset:2 * cfg.linear_total_key * sizeof(float) atIndex:3];
[enc setBuffer:g_metal->buf_delta_g_decay offset:0 atIndex:4];
[enc setBuffer:g_metal->buf_delta_beta offset:0 atIndex:5];
[enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:6];
[enc setBytes:&khpv length:sizeof(khpv) atIndex:7];
- [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1)
+ [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1)
threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
[enc endEncoding];
}
@@ -4317,8 +4640,8 @@ static void fused_layer_forward(
// Enc L5: gated_rms_norm -> batch_out[6]
{
NSUInteger gnorm_w_off = (NSUInteger)((const char *)lc->gated_norm_w - (const char *)[g_metal->wf_buf contents]);
- uint32_t value_dim = LINEAR_VALUE_DIM;
- float eps = RMS_NORM_EPS;
+ uint32_t value_dim = cfg.linear_value_dim;
+ float eps = cfg.rms_norm_eps;
id enc = [cmd1 computeCommandEncoder];
[enc setComputePipelineState:g_metal->gated_rms_norm];
[enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:0];
@@ -4327,8 +4650,8 @@ static void fused_layer_forward(
[enc setBuffer:g_metal->batch_out[6] offset:0 atIndex:3];
[enc setBytes:&value_dim length:4 atIndex:4];
[enc setBytes:&eps length:4 atIndex:5];
- [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1)
- threadsPerThreadgroup:MTLSizeMake(LINEAR_VALUE_DIM, 1, 1)];
+ [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1)
+ threadsPerThreadgroup:MTLSizeMake(cfg.linear_value_dim, 1, 1)];
[enc endEncoding];
}
@@ -4375,17 +4698,17 @@ static void fused_layer_forward(
if (spec_routing_enabled && (g_expert_cache || g_malloc_cache) && packed_fd >= 0 && lc->gate_w) {
float *spec_scores = s_spec_gate_scores;
- memset(spec_scores, 0, NUM_EXPERTS * sizeof(float));
+ memset(spec_scores, 0, cfg.num_experts * sizeof(float));
// Gate projection matvec on pre-attention normed input (CPU, ~0.1ms for 512x4096)
cpu_dequant_matvec(lc->gate_w, lc->gate_s, lc->gate_b,
normed, spec_scores,
- NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE);
- cpu_softmax(spec_scores, NUM_EXPERTS);
+ cfg.num_experts, cfg.hidden_dim, cfg.group_size);
+ cpu_softmax(spec_scores, cfg.num_experts);
int spec_K = (K > MAX_K) ? MAX_K : K;
float spec_weights[MAX_K];
- cpu_topk(spec_scores, NUM_EXPERTS, spec_K, s_spec_indices, spec_weights);
+ cpu_topk(spec_scores, cfg.num_experts, spec_K, s_spec_indices, spec_weights);
s_spec_count = spec_K;
g_spec_route_attempts += spec_K;
@@ -4450,7 +4773,7 @@ static void fused_layer_forward(
if (g_timing_enabled) { t0 = now_ms(); }
float *attn_projected = s_attn_proj;
- memset(attn_projected, 0, HIDDEN_DIM * sizeof(float));
+ memset(attn_projected, 0, cfg.hidden_dim * sizeof(float));
// Pre-lookup o_proj / out_proj weights (used after attention compute)
// These are looked up NOW to avoid repeated snprintf later.
@@ -4460,10 +4783,10 @@ static void fused_layer_forward(
if (is_full) {
oproj_w = lc->o_w; oproj_s = lc->o_s; oproj_b = lc->o_b;
- oproj_in_dim = NUM_ATTN_HEADS * HEAD_DIM;
+ oproj_in_dim = cfg.num_attn_heads * cfg.head_dim;
} else if (!linear_attn_bypass) {
oproj_w = lc->out_proj_w; oproj_s = lc->out_proj_s; oproj_b = lc->out_proj_b;
- oproj_in_dim = LINEAR_TOTAL_VALUE;
+ oproj_in_dim = cfg.linear_total_value;
}
// All MoE weight pointers from cache (zero snprintf overhead)
@@ -4478,51 +4801,51 @@ static void fused_layer_forward(
if (is_full) {
// ---- Full attention CPU compute ----
- int q_proj_dim = NUM_ATTN_HEADS * HEAD_DIM * 2;
- int q_dim = NUM_ATTN_HEADS * HEAD_DIM;
- int kv_dim = NUM_KV_HEADS * HEAD_DIM;
+ int q_proj_dim = cfg.num_attn_heads * cfg.head_dim * 2;
+ int q_dim = cfg.num_attn_heads * cfg.head_dim;
+ int kv_dim = cfg.num_kv_heads * cfg.head_dim;
(void)q_proj_dim;
float *q = s_q;
float *q_gate = s_q_gate;
- for (int h = 0; h < NUM_ATTN_HEADS; h++) {
- float *src = q_proj_out + h * (2 * HEAD_DIM);
- memcpy(q + h * HEAD_DIM, src, HEAD_DIM * sizeof(float));
- memcpy(q_gate + h * HEAD_DIM, src + HEAD_DIM, HEAD_DIM * sizeof(float));
+ for (int h = 0; h < cfg.num_attn_heads; h++) {
+ float *src = q_proj_out + h * (2 * cfg.head_dim);
+ memcpy(q + h * cfg.head_dim, src, cfg.head_dim * sizeof(float));
+ memcpy(q_gate + h * cfg.head_dim, src + cfg.head_dim, cfg.head_dim * sizeof(float));
}
// Q/K RMSNorm
uint16_t *qnorm_w = lc->q_norm_w;
uint16_t *knorm_w = lc->k_norm_w;
if (qnorm_w) {
- for (int h = 0; h < NUM_ATTN_HEADS; h++) {
- float *qh = q + h * HEAD_DIM;
+ for (int h = 0; h < cfg.num_attn_heads; h++) {
+ float *qh = q + h * cfg.head_dim;
float sum_sq = 0.0f;
- for (int i = 0; i < HEAD_DIM; i++) sum_sq += qh[i] * qh[i];
- float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + RMS_NORM_EPS);
- for (int i = 0; i < HEAD_DIM; i++) qh[i] = qh[i] * inv_rms * bf16_to_f32(qnorm_w[i]);
+ for (int i = 0; i < cfg.head_dim; i++) sum_sq += qh[i] * qh[i];
+ float inv_rms = 1.0f / sqrtf(sum_sq / cfg.head_dim + cfg.rms_norm_eps);
+ for (int i = 0; i < cfg.head_dim; i++) qh[i] = qh[i] * inv_rms * bf16_to_f32(qnorm_w[i]);
}
}
if (knorm_w) {
- for (int h = 0; h < NUM_KV_HEADS; h++) {
- float *kh = k_out + h * HEAD_DIM;
+ for (int h = 0; h < cfg.num_kv_heads; h++) {
+ float *kh = k_out + h * cfg.head_dim;
float sum_sq = 0.0f;
- for (int i = 0; i < HEAD_DIM; i++) sum_sq += kh[i] * kh[i];
- float inv_rms = 1.0f / sqrtf(sum_sq / HEAD_DIM + RMS_NORM_EPS);
- for (int i = 0; i < HEAD_DIM; i++) kh[i] = kh[i] * inv_rms * bf16_to_f32(knorm_w[i]);
+ for (int i = 0; i < cfg.head_dim; i++) sum_sq += kh[i] * kh[i];
+ float inv_rms = 1.0f / sqrtf(sum_sq / cfg.head_dim + cfg.rms_norm_eps);
+ for (int i = 0; i < cfg.head_dim; i++) kh[i] = kh[i] * inv_rms * bf16_to_f32(knorm_w[i]);
}
}
// RoPE
- apply_rotary_emb(q, k_out, pos, NUM_ATTN_HEADS, NUM_KV_HEADS, HEAD_DIM, ROTARY_DIM);
+ apply_rotary_emb(q, k_out, pos, cfg.num_attn_heads, cfg.num_kv_heads, cfg.head_dim, cfg.rotary_dim);
// Update KV cache (CPU + GPU mirror)
int cache_pos = kv->len;
memcpy(kv->k_cache + cache_pos * kv_dim, k_out, kv_dim * sizeof(float));
memcpy(kv->v_cache + cache_pos * kv_dim, v_out, kv_dim * sizeof(float));
- int fa_idx = (layer_idx + 1) / FULL_ATTN_INTERVAL - 1;
- if (g_metal && g_metal->attn_scores_pipe && fa_idx >= 0 && fa_idx < NUM_FULL_ATTN_LAYERS) {
+ int fa_idx = cfg.full_attn_index[layer_idx];
+ if (g_metal && g_metal->attn_scores_pipe && fa_idx >= 0 && fa_idx < cfg.num_full_attn_layers) {
memcpy((float *)[g_metal->buf_kv_k[fa_idx] contents] + cache_pos * kv_dim,
k_out, kv_dim * sizeof(float));
memcpy((float *)[g_metal->buf_kv_v[fa_idx] contents] + cache_pos * kv_dim,
@@ -4531,15 +4854,15 @@ static void fused_layer_forward(
kv->len++;
// Scaled dot-product attention (GQA) — GPU or CPU
- int heads_per_kv = NUM_ATTN_HEADS / NUM_KV_HEADS;
- float scale = 1.0f / sqrtf((float)HEAD_DIM);
+ int heads_per_kv = cfg.num_attn_heads / cfg.num_kv_heads;
+ float scale = 1.0f / sqrtf((float)cfg.head_dim);
float *attn_out = s_attn_out;
memset(attn_out, 0, q_dim * sizeof(float));
// GPU attention: defer dispatches to CMD2 (fused into single cmd buffer).
// Only enabled when seq_len >= 32 (below that, CPU is faster).
int gpu_attn_ready = (g_metal && g_metal->attn_scores_pipe &&
- fa_idx >= 0 && fa_idx < NUM_FULL_ATTN_LAYERS &&
+ fa_idx >= 0 && fa_idx < cfg.num_full_attn_layers &&
kv->len >= 32 && kv->len < GPU_KV_SEQ);
if (gpu_attn_ready) {
@@ -4549,21 +4872,21 @@ static void fused_layer_forward(
// attn_out_for_oproj will be set to NULL below — CMD2 reads buf_attn_out
} else {
// CPU fallback
- for (int h = 0; h < NUM_ATTN_HEADS; h++) {
+ for (int h = 0; h < cfg.num_attn_heads; h++) {
int kv_h = h / heads_per_kv;
- float *qh = q + h * HEAD_DIM;
+ float *qh = q + h * cfg.head_dim;
float *scores = malloc(kv->len * sizeof(float));
for (int p = 0; p < kv->len; p++) {
- float *kp = kv->k_cache + p * kv_dim + kv_h * HEAD_DIM;
+ float *kp = kv->k_cache + p * kv_dim + kv_h * cfg.head_dim;
float dot = 0.0f;
- for (int d = 0; d < HEAD_DIM; d++) dot += qh[d] * kp[d];
+ for (int d = 0; d < cfg.head_dim; d++) dot += qh[d] * kp[d];
scores[p] = dot * scale;
}
cpu_softmax(scores, kv->len);
- float *oh = attn_out + h * HEAD_DIM;
+ float *oh = attn_out + h * cfg.head_dim;
for (int p = 0; p < kv->len; p++) {
- float *vp = kv->v_cache + p * kv_dim + kv_h * HEAD_DIM;
- for (int d = 0; d < HEAD_DIM; d++) oh[d] += scores[p] * vp[d];
+ float *vp = kv->v_cache + p * kv_dim + kv_h * cfg.head_dim;
+ for (int d = 0; d < cfg.head_dim; d++) oh[d] += scores[p] * vp[d];
}
free(scores);
}
@@ -4588,7 +4911,7 @@ static void fused_layer_forward(
} else {
// ---- Linear attention CPU compute ----
if (!linear_attn_bypass) {
- int qkv_dim = LINEAR_CONV_DIM;
+ int qkv_dim = cfg.linear_conv_dim;
// Conv1d step
uint16_t *conv_w = lc->conv1d_w;
@@ -4596,31 +4919,31 @@ static void fused_layer_forward(
memset(conv_out, 0, qkv_dim * sizeof(float));
if (conv_w) {
cpu_conv1d_step(la_state->conv_state, qkv_out, conv_w, conv_out,
- qkv_dim, CONV_KERNEL_SIZE);
+ qkv_dim, cfg.conv_kernel_size);
}
// Update conv state
memmove(la_state->conv_state, la_state->conv_state + qkv_dim,
- (CONV_KERNEL_SIZE - 2) * qkv_dim * sizeof(float));
- memcpy(la_state->conv_state + (CONV_KERNEL_SIZE - 2) * qkv_dim, qkv_out,
+ (cfg.conv_kernel_size - 2) * qkv_dim * sizeof(float));
+ memcpy(la_state->conv_state + (cfg.conv_kernel_size - 2) * qkv_dim, qkv_out,
qkv_dim * sizeof(float));
// Split into q, k, v
float *lin_q = conv_out;
- float *lin_k = conv_out + LINEAR_TOTAL_KEY;
- float *lin_v = conv_out + 2 * LINEAR_TOTAL_KEY;
+ float *lin_k = conv_out + cfg.linear_total_key;
+ float *lin_v = conv_out + 2 * cfg.linear_total_key;
// RMS normalize q and k
- float inv_scale = 1.0f / sqrtf((float)LINEAR_KEY_DIM);
- for (int h = 0; h < LINEAR_NUM_K_HEADS; h++) {
- float *qh = lin_q + h * LINEAR_KEY_DIM;
- cpu_rms_norm_bare(qh, qh, LINEAR_KEY_DIM, 1e-6f);
+ float inv_scale = 1.0f / sqrtf((float)cfg.linear_key_dim);
+ for (int h = 0; h < cfg.linear_num_k_heads; h++) {
+ float *qh = lin_q + h * cfg.linear_key_dim;
+ cpu_rms_norm_bare(qh, qh, cfg.linear_key_dim, 1e-6f); // note: literal 1e-6 eps here, not cfg.rms_norm_eps
float q_scale = inv_scale * inv_scale;
- for (int d = 0; d < LINEAR_KEY_DIM; d++) qh[d] *= q_scale;
+ for (int d = 0; d < cfg.linear_key_dim; d++) qh[d] *= q_scale;
}
- for (int h = 0; h < LINEAR_NUM_K_HEADS; h++) {
- float *kh = lin_k + h * LINEAR_KEY_DIM;
- cpu_rms_norm_bare(kh, kh, LINEAR_KEY_DIM, 1e-6f);
- for (int d = 0; d < LINEAR_KEY_DIM; d++) kh[d] *= inv_scale;
+ for (int h = 0; h < cfg.linear_num_k_heads; h++) {
+ float *kh = lin_k + h * cfg.linear_key_dim;
+ cpu_rms_norm_bare(kh, kh, cfg.linear_key_dim, 1e-6f);
+ for (int d = 0; d < cfg.linear_key_dim; d++) kh[d] *= inv_scale;
}
// Gated delta net recurrence
@@ -4628,12 +4951,12 @@ static void fused_layer_forward(
uint16_t *dt_bias_bf16 = lc->dt_bias;
float *out_values = s_out_vals;
- memset(out_values, 0, LINEAR_TOTAL_VALUE * sizeof(float));
- int k_heads_per_v = LINEAR_NUM_V_HEADS / LINEAR_NUM_K_HEADS;
+ memset(out_values, 0, cfg.linear_total_value * sizeof(float));
+ int k_heads_per_v = cfg.linear_num_v_heads / cfg.linear_num_k_heads;
- float g_decay[LINEAR_NUM_V_HEADS];
- float beta_gate_arr[LINEAR_NUM_V_HEADS];
- for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) {
+ float g_decay[cfg.linear_num_v_heads];
+ float beta_gate_arr[cfg.linear_num_v_heads];
+ for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) {
float a_val = alpha_out[vh];
float dt_b = dt_bias_bf16 ? bf16_to_f32(dt_bias_bf16[vh]) : 0.0f;
float A_val = A_log ? expf(A_log[vh]) : 1.0f;
@@ -4645,18 +4968,18 @@ static void fused_layer_forward(
// Compute linear_layer_idx: count of non-full-attention layers before this one.
// Full attention at (layer_idx+1) % 4 == 0, i.e. layers 3,7,11,...
// linear_layer_idx = layer_idx - number_of_full_layers_at_or_before
- // = layer_idx - (layer_idx + 1) / FULL_ATTN_INTERVAL
- int linear_layer_idx = layer_idx - (layer_idx + 1) / FULL_ATTN_INTERVAL;
+ // = cfg.linear_index[layer_idx]
+ int linear_layer_idx = cfg.linear_index[layer_idx];
// GPU delta-net path (falls back to CPU if pipeline unavailable)
if (g_metal && g_metal->delta_net_step &&
- linear_layer_idx >= 0 && linear_layer_idx < NUM_LINEAR_LAYERS) {
+ linear_layer_idx >= 0 && linear_layer_idx < cfg.num_linear_layers) {
// Upload CPU-computed data to GPU scratch buffers
- memcpy([g_metal->buf_delta_q contents], lin_q, LINEAR_TOTAL_KEY * sizeof(float));
- memcpy([g_metal->buf_delta_k contents], lin_k, LINEAR_TOTAL_KEY * sizeof(float));
- memcpy([g_metal->buf_delta_v contents], lin_v, LINEAR_TOTAL_VALUE * sizeof(float));
- memcpy([g_metal->buf_delta_g_decay contents], g_decay, LINEAR_NUM_V_HEADS * sizeof(float));
- memcpy([g_metal->buf_delta_beta contents], beta_gate_arr, LINEAR_NUM_V_HEADS * sizeof(float));
+ memcpy([g_metal->buf_delta_q contents], lin_q, cfg.linear_total_key * sizeof(float));
+ memcpy([g_metal->buf_delta_k contents], lin_k, cfg.linear_total_key * sizeof(float));
+ memcpy([g_metal->buf_delta_v contents], lin_v, cfg.linear_total_value * sizeof(float));
+ memcpy([g_metal->buf_delta_g_decay contents], g_decay, cfg.linear_num_v_heads * sizeof(float));
+ memcpy([g_metal->buf_delta_beta contents], beta_gate_arr, cfg.linear_num_v_heads * sizeof(float));
id cmd_dn = [g_metal->queue commandBuffer];
id enc = [cmd_dn computeCommandEncoder];
@@ -4670,53 +4993,53 @@ static void fused_layer_forward(
[enc setBuffer:g_metal->buf_delta_output offset:0 atIndex:6];
uint32_t khpv = (uint32_t)k_heads_per_v;
[enc setBytes:&khpv length:sizeof(khpv) atIndex:7];
- [enc dispatchThreadgroups:MTLSizeMake(LINEAR_NUM_V_HEADS, 1, 1)
+ [enc dispatchThreadgroups:MTLSizeMake(cfg.linear_num_v_heads, 1, 1)
threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
[enc endEncoding];
[cmd_dn commit];
[cmd_dn waitUntilCompleted];
// Read back GPU result
- memcpy(out_values, [g_metal->buf_delta_output contents], LINEAR_TOTAL_VALUE * sizeof(float));
+ memcpy(out_values, [g_metal->buf_delta_output contents], cfg.linear_total_value * sizeof(float));
} else {
// CPU delta-net with Accelerate BLAS
- for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) {
+ for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) {
int kh = vh / k_heads_per_v;
float g = g_decay[vh];
float b_gate = beta_gate_arr[vh];
- float *S = la_state->ssm_state + vh * LINEAR_VALUE_DIM * LINEAR_KEY_DIM;
- float *v_h = lin_v + vh * LINEAR_VALUE_DIM;
- float *k_h = lin_k + kh * LINEAR_KEY_DIM;
+ float *S = la_state->ssm_state + vh * cfg.linear_value_dim * cfg.linear_key_dim;
+ float *v_h = lin_v + vh * cfg.linear_value_dim;
+ float *k_h = lin_k + kh * cfg.linear_key_dim;
// Step 1: Decay S *= g (BLAS sscal on entire state matrix)
- cblas_sscal(LINEAR_VALUE_DIM * LINEAR_KEY_DIM, g, S, 1);
+ cblas_sscal(cfg.linear_value_dim * cfg.linear_key_dim, g, S, 1);
// Step 2: kv_mem = S @ k (each row dot k)
// S is [VALUE_DIM x KEY_DIM] row-major, k is [KEY_DIM]
// kv_mem[vi] = sum_ki(S[vi,ki] * k[ki]) = matrix-vector: S @ k
- float kv_mem_vec[LINEAR_VALUE_DIM];
+ float kv_mem_vec[cfg.linear_value_dim];
cblas_sgemv(CblasRowMajor, CblasNoTrans,
- LINEAR_VALUE_DIM, LINEAR_KEY_DIM,
- 1.0f, S, LINEAR_KEY_DIM, k_h, 1,
+ cfg.linear_value_dim, cfg.linear_key_dim,
+ 1.0f, S, cfg.linear_key_dim, k_h, 1,
0.0f, kv_mem_vec, 1);
// Step 3: delta = (v - kv_mem) * beta, then rank-1 update S += k * delta^T
// delta[vi] = (v[vi] - kv_mem[vi]) * beta
- float delta_vec[LINEAR_VALUE_DIM];
- for (int vi = 0; vi < LINEAR_VALUE_DIM; vi++) {
+ float delta_vec[cfg.linear_value_dim];
+ for (int vi = 0; vi < cfg.linear_value_dim; vi++) {
delta_vec[vi] = (v_h[vi] - kv_mem_vec[vi]) * b_gate;
}
// S += delta @ k^T (rank-1 update: sger)
// S[vi,ki] += delta[vi] * k[ki]
- cblas_sger(CblasRowMajor, LINEAR_VALUE_DIM, LINEAR_KEY_DIM,
- 1.0f, delta_vec, 1, k_h, 1, S, LINEAR_KEY_DIM);
+ cblas_sger(CblasRowMajor, cfg.linear_value_dim, cfg.linear_key_dim,
+ 1.0f, delta_vec, 1, k_h, 1, S, cfg.linear_key_dim);
// Step 4: output = S @ q (matrix-vector multiply)
- float *q_h = lin_q + kh * LINEAR_KEY_DIM;
- float *o_h = out_values + vh * LINEAR_VALUE_DIM;
+ float *q_h = lin_q + kh * cfg.linear_key_dim;
+ float *o_h = out_values + vh * cfg.linear_value_dim;
cblas_sgemv(CblasRowMajor, CblasNoTrans,
- LINEAR_VALUE_DIM, LINEAR_KEY_DIM,
- 1.0f, S, LINEAR_KEY_DIM, q_h, 1,
+ cfg.linear_value_dim, cfg.linear_key_dim,
+ 1.0f, S, cfg.linear_key_dim, q_h, 1,
0.0f, o_h, 1);
}
}
@@ -4724,15 +5047,15 @@ static void fused_layer_forward(
// RMSNormGated
uint16_t *gated_norm_w = lc->gated_norm_w;
float *gated_out = s_gated_out;
- memset(gated_out, 0, LINEAR_TOTAL_VALUE * sizeof(float));
- for (int vh = 0; vh < LINEAR_NUM_V_HEADS; vh++) {
- float *oh = out_values + vh * LINEAR_VALUE_DIM;
- float *zh = z_out + vh * LINEAR_VALUE_DIM;
- float *gh = gated_out + vh * LINEAR_VALUE_DIM;
+ memset(gated_out, 0, cfg.linear_total_value * sizeof(float));
+ for (int vh = 0; vh < cfg.linear_num_v_heads; vh++) {
+ float *oh = out_values + vh * cfg.linear_value_dim;
+ float *zh = z_out + vh * cfg.linear_value_dim;
+ float *gh = gated_out + vh * cfg.linear_value_dim;
if (gated_norm_w) {
- cpu_rms_norm_gated(oh, zh, gated_norm_w, gh, LINEAR_VALUE_DIM, RMS_NORM_EPS);
+ cpu_rms_norm_gated(oh, zh, gated_norm_w, gh, cfg.linear_value_dim, cfg.rms_norm_eps);
} else {
- memcpy(gh, oh, LINEAR_VALUE_DIM * sizeof(float));
+ memcpy(gh, oh, cfg.linear_value_dim * sizeof(float));
}
}
@@ -4766,11 +5089,11 @@ static void fused_layer_forward(
float *h_post = s_h_post;
float *h_mid = s_h_mid;
float *gate_scores = s_gate_scores;
- memset(gate_scores, 0, NUM_EXPERTS * sizeof(float));
+ memset(gate_scores, 0, cfg.num_experts * sizeof(float));
float *shared_gate = s_shared_gate;
- memset(shared_gate, 0, SHARED_INTERMEDIATE * sizeof(float));
+ memset(shared_gate, 0, cfg.shared_intermediate * sizeof(float));
float *shared_up = s_shared_up;
- memset(shared_up, 0, SHARED_INTERMEDIATE * sizeof(float));
+ memset(shared_up, 0, cfg.shared_intermediate * sizeof(float));
float shared_gate_score = 0.0f;
int have_moe_weights = (gate_w && gate_s && gate_b && sgw && sgs && sgb &&
@@ -4809,7 +5132,7 @@ static void fused_layer_forward(
}
// gpu_linear_attn: batch_out[6] already has the result from CMD1 gated_rms_norm
// Copy residual into GPU buffer for residual_add kernel
- memcpy([g_metal->buf_residual contents], residual, HIDDEN_DIM * sizeof(float));
+ memcpy([g_metal->buf_residual contents], residual, cfg.hidden_dim * sizeof(float));
attn_out_for_oproj = NULL;
@@ -4817,11 +5140,11 @@ static void fused_layer_forward(
// ---- GPU attention dispatches (only for full-attn layers with GPU path) ----
if (gpu_attn_fuse) {
- int fa_idx = (layer_idx + 1) / FULL_ATTN_INTERVAL - 1;
- int kv_dim = NUM_KV_HEADS * HEAD_DIM;
- int heads_per_kv = NUM_ATTN_HEADS / NUM_KV_HEADS;
- float scale = 1.0f / sqrtf((float)HEAD_DIM);
- uint32_t hd = HEAD_DIM;
+ int fa_idx = cfg.full_attn_index[layer_idx];
+ int kv_dim = cfg.num_kv_heads * cfg.head_dim;
+ int heads_per_kv = cfg.num_attn_heads / cfg.num_kv_heads;
+ float scale = 1.0f / sqrtf((float)cfg.head_dim);
+ uint32_t hd = cfg.head_dim;
uint32_t kvd = (uint32_t)kv_dim;
uint32_t sl = (uint32_t)kv->len;
uint32_t seq_stride = GPU_KV_SEQ;
@@ -4841,7 +5164,7 @@ static void fused_layer_forward(
[enc setBytes:&scale length:4 atIndex:7];
[enc setBytes:&hpkv length:4 atIndex:8];
[enc setBytes:&sl length:4 atIndex:9];
- uint32_t total_tgs = sl * NUM_ATTN_HEADS;
+ uint32_t total_tgs = sl * cfg.num_attn_heads;
[enc dispatchThreadgroups:MTLSizeMake(total_tgs, 1, 1)
threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
[enc endEncoding];
@@ -4853,7 +5176,7 @@ static void fused_layer_forward(
[enc setBuffer:g_metal->buf_attn_scores offset:0 atIndex:0];
[enc setBytes:&sl length:4 atIndex:1];
[enc setBytes:&seq_stride length:4 atIndex:2];
- [enc dispatchThreadgroups:MTLSizeMake(NUM_ATTN_HEADS, 1, 1)
+ [enc dispatchThreadgroups:MTLSizeMake(cfg.num_attn_heads, 1, 1)
threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
[enc endEncoding];
}
@@ -4869,7 +5192,7 @@ static void fused_layer_forward(
[enc setBytes:&sl length:4 atIndex:5];
[enc setBytes:&seq_stride length:4 atIndex:6];
[enc setBytes:&hpkv length:4 atIndex:7];
- uint32_t total_threads = HEAD_DIM * NUM_ATTN_HEADS;
+ uint32_t total_threads = cfg.head_dim * cfg.num_attn_heads;
uint32_t tgs = (total_threads + 255) / 256;
[enc dispatchThreadgroups:MTLSizeMake(tgs, 1, 1)
threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
@@ -4877,7 +5200,7 @@ static void fused_layer_forward(
}
// Enc A4: sigmoid_gate
{
- uint32_t qdim = NUM_ATTN_HEADS * HEAD_DIM;
+ uint32_t qdim = cfg.num_attn_heads * cfg.head_dim;
id enc = [cmd_fused computeCommandEncoder];
[enc setComputePipelineState:g_metal->sigmoid_gate_pipe];
[enc setBuffer:g_metal->buf_attn_out offset:0 atIndex:0];
@@ -4901,9 +5224,9 @@ static void fused_layer_forward(
id oproj_input = gpu_attn_fuse ? g_metal->buf_attn_out : g_metal->batch_out[6];
id enc = [cmd_fused computeCommandEncoder];
- uint32_t o_out_dim = HIDDEN_DIM;
+ uint32_t o_out_dim = cfg.hidden_dim;
uint32_t o_in_dim = (uint32_t)oproj_in_dim;
- uint32_t o_gs = GROUP_SIZE;
+ uint32_t o_gs = cfg.group_size;
[enc setComputePipelineState:g_metal->matvec_fast];
[enc setBuffer:g_metal->wf_buf offset:w_off atIndex:0];
[enc setBuffer:g_metal->wf_buf offset:s_off atIndex:1];
@@ -4921,7 +5244,7 @@ static void fused_layer_forward(
// ---- Enc 2: residual_add (buf_output + buf_residual -> buf_h_mid) ----
{
id enc = [cmd_fused computeCommandEncoder];
- uint32_t dim = HIDDEN_DIM;
+ uint32_t dim = cfg.hidden_dim;
[enc setComputePipelineState:g_metal->residual_add];
[enc setBuffer:g_metal->buf_residual offset:0 atIndex:0]; // a = residual
[enc setBuffer:g_metal->buf_output offset:0 atIndex:1]; // b = o_proj result
@@ -4936,7 +5259,7 @@ static void fused_layer_forward(
// ---- Enc 3: rms_norm_sum_sq (buf_h_mid -> buf_sum_sq) ----
{
id enc = [cmd_fused computeCommandEncoder];
- uint32_t dim = HIDDEN_DIM;
+ uint32_t dim = cfg.hidden_dim;
[enc setComputePipelineState:g_metal->rms_norm_sum];
[enc setBuffer:g_metal->buf_h_mid offset:0 atIndex:0];
[enc setBuffer:g_metal->buf_sum_sq offset:0 atIndex:1];
@@ -4951,8 +5274,8 @@ static void fused_layer_forward(
NSUInteger norm_off = (NSUInteger)((const char *)lc->post_attn_norm_w -
(const char *)[g_metal->wf_buf contents]);
id enc = [cmd_fused computeCommandEncoder];
- uint32_t dim = HIDDEN_DIM;
- float eps = RMS_NORM_EPS;
+ uint32_t dim = cfg.hidden_dim;
+ float eps = cfg.rms_norm_eps;
[enc setComputePipelineState:g_metal->rms_norm_apply_bf16];
[enc setBuffer:g_metal->buf_h_mid offset:0 atIndex:0]; // x
[enc setBuffer:g_metal->wf_buf offset:norm_off atIndex:1]; // weight (bf16)
@@ -4968,10 +5291,10 @@ static void fused_layer_forward(
// ---- Enc 5-8: routing + shared expert projections (read buf_input) ----
BatchMatvecSpec moe_specs[4] = {
- { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 },
- { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 },
- { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 },
- { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 },
+ { gate_w, gate_s, gate_b, gate_scores, (uint32_t)cfg.num_experts, cfg.hidden_dim, cfg.group_size, 0 },
+ { sgw, sgs, sgb, shared_gate, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 1 },
+ { suw, sus, sub, shared_up, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 2 },
+ { seg_w, seg_s, seg_b, &shared_gate_score, 1, cfg.hidden_dim, cfg.group_size, 3 },
};
// buf_input already contains h_post from Enc 4 output -- no memcpy needed
gpu_encode_batch_matvec(g_metal, cmd_fused, moe_specs, 4);
@@ -4986,11 +5309,11 @@ static void fused_layer_forward(
// Read back results
gpu_flush_batch_results(g_metal, moe_specs, 4);
// Read h_mid from GPU buffer (needed for final combine)
- memcpy(h_mid, [g_metal->buf_h_mid contents], HIDDEN_DIM * sizeof(float));
+ memcpy(h_mid, [g_metal->buf_h_mid contents], cfg.hidden_dim * sizeof(float));
// Read h_post from buf_input (needed for expert input)
- memcpy(h_post, [g_metal->buf_input contents], HIDDEN_DIM * sizeof(float));
+ memcpy(h_post, [g_metal->buf_input contents], cfg.hidden_dim * sizeof(float));
// Update hidden state to h_mid (= residual + o_proj)
- memcpy(hidden, h_mid, HIDDEN_DIM * sizeof(float));
+ memcpy(hidden, h_mid, cfg.hidden_dim * sizeof(float));
if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd2_wait += t1 - t0; }
} else {
@@ -4998,45 +5321,45 @@ static void fused_layer_forward(
// O projection
if (attn_out_for_oproj && oproj_w && oproj_s && oproj_b) {
fast_dequant_matvec(oproj_w, oproj_s, oproj_b, attn_out_for_oproj,
- attn_projected, HIDDEN_DIM, oproj_in_dim, GROUP_SIZE);
+ attn_projected, cfg.hidden_dim, oproj_in_dim, cfg.group_size);
}
// attn_out_for_oproj is static — no free needed
attn_out_for_oproj = NULL;
// Residual connection
- for (int i = 0; i < HIDDEN_DIM; i++) {
+ for (int i = 0; i < cfg.hidden_dim; i++) {
hidden[i] = residual[i] + attn_projected[i];
}
// attn_projected, normed, residual are static — no free needed
- cpu_vec_copy(h_mid, hidden, HIDDEN_DIM);
+ cpu_vec_copy(h_mid, hidden, cfg.hidden_dim);
// Post-attention norm
- cpu_rms_norm(hidden, lc->post_attn_norm_w, h_post, HIDDEN_DIM, RMS_NORM_EPS);
+ cpu_rms_norm(hidden, lc->post_attn_norm_w, h_post, cfg.hidden_dim, cfg.rms_norm_eps);
// Routing + shared expert batch
if (have_moe_weights) {
BatchMatvecSpec moe_specs[4] = {
- { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 },
- { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 },
- { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 },
- { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 },
+ { gate_w, gate_s, gate_b, gate_scores, (uint32_t)cfg.num_experts, cfg.hidden_dim, cfg.group_size, 0 },
+ { sgw, sgs, sgb, shared_gate, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 1 },
+ { suw, sus, sub, shared_up, (uint32_t)cfg.shared_intermediate, cfg.hidden_dim, cfg.group_size, 2 },
+ { seg_w, seg_s, seg_b, &shared_gate_score, 1, cfg.hidden_dim, cfg.group_size, 3 },
};
- fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4);
+ fast_batch_matvec(h_post, cfg.hidden_dim, moe_specs, 4);
}
if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd2_encode += t1 - t0; }
}
// ---- Softmax + top-K (CPU) ----
if (g_timing_enabled) { t0 = now_ms(); }
- cpu_softmax(gate_scores, NUM_EXPERTS);
+ cpu_softmax(gate_scores, cfg.num_experts);
int expert_indices[64];
float expert_weights[64];
- cpu_topk(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights);
+ cpu_topk(gate_scores, cfg.num_experts, K, expert_indices, expert_weights);
cpu_normalize_weights(expert_weights, K);
if (g_freq_tracking) {
for (int k = 0; k < K; k++) {
- g_expert_freq[layer_idx][expert_indices[k]]++;
+ FREQ(layer_idx, expert_indices[k])++;
}
if (layer_idx == 0) g_freq_total_tokens++;
}
@@ -5062,7 +5385,7 @@ static void fused_layer_forward(
int32_t ki = (K > MAX_K) ? MAX_K : K;
fwrite(&li, sizeof(int32_t), 1, g_routing_log);
fwrite(&ki, sizeof(int32_t), 1, g_routing_log);
- fwrite(hidden, sizeof(float), HIDDEN_DIM, g_routing_log);
+ fwrite(hidden, sizeof(float), cfg.hidden_dim, g_routing_log);
fwrite(expert_indices, sizeof(int32_t), ki, g_routing_log);
g_routing_log_samples++;
}
@@ -5070,9 +5393,9 @@ static void fused_layer_forward(
// ---- Parallel pread + GPU experts ----
if (g_timing_enabled) { t0 = now_ms(); }
float *moe_out = s_moe_out;
- memset(moe_out, 0, HIDDEN_DIM * sizeof(float));
+ memset(moe_out, 0, cfg.hidden_dim * sizeof(float));
float *shared_out = s_shared_out;
- memset(shared_out, 0, HIDDEN_DIM * sizeof(float));
+ memset(shared_out, 0, cfg.hidden_dim * sizeof(float));
int actual_K = (K > MAX_K) ? MAX_K : K;
@@ -5207,8 +5530,8 @@ static void fused_layer_forward(
for (int k = 0; k < actual_K; k++) {
int found = 0;
- for (int p = 0; p < g_pred_count[layer_idx]; p++) {
- if (expert_indices[k] == g_pred_experts[layer_idx][p] &&
+ for (int p = 0; p < PRED_COUNT(layer_idx); p++) {
+ if (expert_indices[k] == PRED_EXPERT(layer_idx, p) &&
g_async_pread.valid[p]) {
// Hit! This expert was pre-loaded into buf_B[p]
expert_bufs[k] = g_metal->buf_multi_expert_data_B[p];
@@ -5276,11 +5599,11 @@ static void fused_layer_forward(
}
// Shared expert prep (doesn't need expert data — can overlap with async pread)
- memcpy([g_metal->buf_multi_expert_input contents], h_post, HIDDEN_DIM * sizeof(float));
+ memcpy([g_metal->buf_multi_expert_input contents], h_post, cfg.hidden_dim * sizeof(float));
memcpy([g_metal->buf_shared_gate contents], shared_gate,
- SHARED_INTERMEDIATE * sizeof(float));
+ cfg.shared_intermediate * sizeof(float));
memcpy([g_metal->buf_shared_up contents], shared_up,
- SHARED_INTERMEDIATE * sizeof(float));
+ cfg.shared_intermediate * sizeof(float));
// Wait for non-prediction async pread to complete
if (!pred_started && g_async_pread.active) {
@@ -5296,10 +5619,10 @@ static void fused_layer_forward(
// MUST happen AFTER the prediction hit check above (which reads g_pred_experts).
if (g_pred_enabled && g_pred_generating) {
for (int k = 0; k < actual_K; k++) {
- g_pred_experts[layer_idx][k] = expert_indices[k];
+ PRED_EXPERT(layer_idx, k) = expert_indices[k];
}
- g_pred_count[layer_idx] = actual_K;
- if (layer_idx == NUM_LAYERS - 1) {
+ PRED_COUNT(layer_idx) = actual_K;
+ if (layer_idx == cfg.num_layers - 1) {
g_pred_valid = 1;
}
}
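The prediction bookkeeping above records this token's routed experts so the next token can prefetch them and check for hits. A self-contained sketch of that record/check cycle (fixed sizes and names are illustrative, not the engine's `PRED_*` macros):

```c
#include <assert.h>

#define P_LAYERS 4  /* illustrative layer count */
#define P_K 8       /* illustrative max experts per token */

static int pred_experts[P_LAYERS][P_K];
static int pred_count[P_LAYERS];

/* After routing a token, remember which experts each layer used. */
static void pred_record(int layer, const int *idx, int k) {
    for (int i = 0; i < k; i++) pred_experts[layer][i] = idx[i];
    pred_count[layer] = k;
}

/* Next token: count how many of this layer's routed experts were predicted
   (and thus could have been prefetched from SSD ahead of time). */
static int pred_hits(int layer, const int *idx, int k) {
    int hits = 0;
    for (int i = 0; i < k; i++)
        for (int p = 0; p < pred_count[layer]; p++)
            if (idx[i] == pred_experts[layer][p]) { hits++; break; }
    return hits;
}
```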
@@ -5323,7 +5646,7 @@ static void fused_layer_forward(
[enc setBuffer:g_metal->buf_shared_gate offset:0 atIndex:0];
[enc setBuffer:g_metal->buf_shared_up offset:0 atIndex:1];
[enc setBuffer:g_metal->buf_shared_act offset:0 atIndex:2];
- uint32_t dim = SHARED_INTERMEDIATE;
+ uint32_t dim = cfg.shared_intermediate;
[enc setBytes:&dim length:4 atIndex:3];
uint32_t swiglu_tgs = (dim + 255) / 256;
[enc dispatchThreadgroups:MTLSizeMake(swiglu_tgs, 1, 1)
@@ -5336,7 +5659,7 @@ static void fused_layer_forward(
gpu_encode_dequant_matvec_with_io_bufs(
g_metal, cmd_experts, sdw, sds, sdb,
g_metal->buf_shared_act, g_metal->buf_shared_out,
- HIDDEN_DIM, SHARED_INTERMEDIATE, GROUP_SIZE);
+ cfg.hidden_dim, cfg.shared_intermediate, cfg.group_size);
}
// Step 4: GPU-side combine + residual + norm (if not last layer)
@@ -5355,7 +5678,7 @@ static void fused_layer_forward(
g_metal->rms_norm_sum &&
g_metal->rms_norm_apply_bf16 &&
g_metal->wf_buf &&
- layer_idx < NUM_LAYERS - 1 &&
+ layer_idx < cfg.num_layers - 1 &&
layer_cache[layer_idx + 1].input_norm_w != NULL);
if (gpu_combine) {
@@ -5385,7 +5708,7 @@ static void fused_layer_forward(
[enc setBuffer:g_metal->buf_multi_expert_out[k] offset:0 atIndex:(3 + k)];
}
[enc setBuffer:g_metal->buf_combine_params offset:0 atIndex:11]; // params
- uint32_t dim = HIDDEN_DIM;
+ uint32_t dim = cfg.hidden_dim;
uint32_t k_val = (uint32_t)actual_K;
[enc setBytes:&dim length:4 atIndex:12];
[enc setBytes:&k_val length:4 atIndex:13];
@@ -5398,7 +5721,7 @@ static void fused_layer_forward(
// Enc C2: rms_norm_sum_sq (buf_moe_hidden -> buf_cmd3_sum_sq)
{
id enc = [cmd_experts computeCommandEncoder];
- uint32_t dim = HIDDEN_DIM;
+ uint32_t dim = cfg.hidden_dim;
[enc setComputePipelineState:g_metal->rms_norm_sum];
[enc setBuffer:g_metal->buf_moe_hidden offset:0 atIndex:0];
[enc setBuffer:g_metal->buf_cmd3_sum_sq offset:0 atIndex:1];
@@ -5414,8 +5737,8 @@ static void fused_layer_forward(
NSUInteger norm_off = (NSUInteger)((const char *)next_norm_w -
(const char *)[g_metal->wf_buf contents]);
id enc = [cmd_experts computeCommandEncoder];
- uint32_t dim = HIDDEN_DIM;
- float eps = RMS_NORM_EPS;
+ uint32_t dim = cfg.hidden_dim;
+ float eps = cfg.rms_norm_eps;
[enc setComputePipelineState:g_metal->rms_norm_apply_bf16];
[enc setBuffer:g_metal->buf_moe_hidden offset:0 atIndex:0]; // x
[enc setBuffer:g_metal->wf_buf offset:norm_off atIndex:1]; // weight (bf16)
@@ -5449,7 +5772,7 @@ static void fused_layer_forward(
g_deferred.layer_idx = layer_idx;
if (!gpu_combine) {
// Only need to save h_mid for CPU-side combine path
- memcpy(g_deferred.h_mid, h_mid, HIDDEN_DIM * sizeof(float));
+ memcpy(g_deferred.h_mid, h_mid, cfg.hidden_dim * sizeof(float));
}
for (int k = 0; k < actual_K; k++) {
g_deferred.expert_weights[k] = expert_weights[k];
@@ -5464,7 +5787,7 @@ static void fused_layer_forward(
} else if (packed_fd >= 0) {
// CPU fallback for experts
size_t esz = active_expert_size();
- float *expert_out_cpu = malloc(HIDDEN_DIM * sizeof(float));
+ float *expert_out_cpu = malloc(cfg.hidden_dim * sizeof(float));
for (int k = 0; k < K; k++) {
int eidx = expert_indices[k];
off_t expert_offset = (off_t)eidx * esz;
@@ -5479,63 +5802,63 @@ static void fused_layer_forward(
// CPU fallback offsets — use 4-bit layout (2-bit CPU path not yet implemented)
uint32_t *gw = (uint32_t *)expert_data;
- uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_S_OFF_2 : 2097152));
- uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? GATE_B_OFF_2 : 2228224));
- uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? UP_W_OFF_2 : 2359296));
- uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_S_OFF_2 : 4456448));
- uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? UP_B_OFF_2 : 4587520));
- uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? DOWN_W_OFF_2 : 4718592));
- uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_S_OFF_2 : 6815744));
- uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? DOWN_B_OFF_2 : 6946816));
-
- float *gate_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float));
- float *up_proj_out = malloc(MOE_INTERMEDIATE * sizeof(float));
- float *act_out = malloc(MOE_INTERMEDIATE * sizeof(float));
+ uint16_t *gs_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.gate_s_off_2 : cfg.gate_s_off_4));
+ uint16_t *gb_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.gate_b_off_2 : cfg.gate_b_off_4));
+ uint32_t *uw = (uint32_t *)((char *)expert_data + (g_use_2bit ? cfg.up_w_off_2 : cfg.up_w_off_4));
+ uint16_t *us_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.up_s_off_2 : cfg.up_s_off_4));
+ uint16_t *ub_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.up_b_off_2 : cfg.up_b_off_4));
+ uint32_t *dw = (uint32_t *)((char *)expert_data + (g_use_2bit ? cfg.down_w_off_2 : cfg.down_w_off_4));
+ uint16_t *ds_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.down_s_off_2 : cfg.down_s_off_4));
+ uint16_t *db_p = (uint16_t *)((char *)expert_data + (g_use_2bit ? cfg.down_b_off_2 : cfg.down_b_off_4));
+
+ float *gate_proj_out = malloc(cfg.moe_intermediate * sizeof(float));
+ float *up_proj_out = malloc(cfg.moe_intermediate * sizeof(float));
+ float *act_out = malloc(cfg.moe_intermediate * sizeof(float));
cpu_dequant_matvec(gw, gs_p, gb_p, h_post, gate_proj_out,
- MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE);
+ cfg.moe_intermediate, cfg.hidden_dim, cfg.group_size);
cpu_dequant_matvec(uw, us_p, ub_p, h_post, up_proj_out,
- MOE_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE);
- cpu_swiglu(gate_proj_out, up_proj_out, act_out, MOE_INTERMEDIATE);
+ cfg.moe_intermediate, cfg.hidden_dim, cfg.group_size);
+ cpu_swiglu(gate_proj_out, up_proj_out, act_out, cfg.moe_intermediate);
cpu_dequant_matvec(dw, ds_p, db_p, act_out, expert_out_cpu,
- HIDDEN_DIM, MOE_INTERMEDIATE, GROUP_SIZE);
+ cfg.hidden_dim, cfg.moe_intermediate, cfg.group_size);
free(gate_proj_out);
free(up_proj_out);
free(act_out);
free(expert_data);
- cpu_vec_madd(moe_out, expert_out_cpu, expert_weights[k], HIDDEN_DIM);
+ cpu_vec_madd(moe_out, expert_out_cpu, expert_weights[k], cfg.hidden_dim);
}
free(expert_out_cpu);
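The CPU fallback above runs group-quantized matvecs over the expert weights. A sketch of the 4-bit affine scheme this implies: each group of `gs` weights shares a scale and bias, and a 4-bit code decodes as `w = scale * q + bias`. Scales and biases are plain floats here, where the engine stores bf16, and the packing order is an assumption for illustration:

```c
#include <assert.h>

/* Dequantize-on-the-fly matvec: two 4-bit codes per byte, low nibble first,
   one (scale, bias) pair per group of gs weights. */
static void dequant_matvec4(const unsigned char *qw, const float *scales,
                            const float *biases, const float *x, float *out,
                            int rows, int cols, int gs) {
    for (int r = 0; r < rows; r++) {
        float acc = 0.0f;
        for (int c = 0; c < cols; c++) {
            int flat = r * cols + c;
            int gi = flat / gs;                      /* group index */
            unsigned char byte = qw[flat / 2];
            int q = (flat & 1) ? (byte >> 4) : (byte & 0x0F);
            acc += (scales[gi] * (float)q + biases[gi]) * x[c];
        }
        out[r] = acc;
    }
}
```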
// CPU shared expert
- float *shared_act = calloc(SHARED_INTERMEDIATE, sizeof(float));
- cpu_swiglu(shared_gate, shared_up, shared_act, SHARED_INTERMEDIATE);
+ float *shared_act = calloc(cfg.shared_intermediate, sizeof(float));
+ cpu_swiglu(shared_gate, shared_up, shared_act, cfg.shared_intermediate);
if (sdw && sds && sdb) {
cpu_dequant_matvec(sdw, sds, sdb, shared_act, shared_out,
- HIDDEN_DIM, SHARED_INTERMEDIATE, GROUP_SIZE);
+ cfg.hidden_dim, cfg.shared_intermediate, cfg.group_size);
}
free(shared_act);
} else {
// No experts available -- still need shared expert
- float *shared_act = calloc(SHARED_INTERMEDIATE, sizeof(float));
- cpu_swiglu(shared_gate, shared_up, shared_act, SHARED_INTERMEDIATE);
+ float *shared_act = calloc(cfg.shared_intermediate, sizeof(float));
+ cpu_swiglu(shared_gate, shared_up, shared_act, cfg.shared_intermediate);
if (sdw && sds && sdb) {
fast_dequant_matvec(sdw, sds, sdb, shared_act, shared_out,
- HIDDEN_DIM, SHARED_INTERMEDIATE, GROUP_SIZE);
+ cfg.hidden_dim, cfg.shared_intermediate, cfg.group_size);
}
free(shared_act);
}
// ---- Shared expert gate ----
float shared_weight = cpu_sigmoid(shared_gate_score);
- for (int i = 0; i < HIDDEN_DIM; i++) {
+ for (int i = 0; i < cfg.hidden_dim; i++) {
shared_out[i] *= shared_weight;
}
// ---- Final combine: hidden = h_mid + moe_out + shared_out ----
- for (int i = 0; i < HIDDEN_DIM; i++) {
+ for (int i = 0; i < cfg.hidden_dim; i++) {
hidden[i] = h_mid[i] + moe_out[i] + shared_out[i];
}
@@ -5574,14 +5897,14 @@ static void freq_print_analysis(int K) {
// Per-layer analysis
int experts_for_80_total = 0; // sum across layers for overall estimate
- for (int l = 0; l < NUM_LAYERS; l++) {
+ for (int l = 0; l < cfg.num_layers; l++) {
// Count unique experts and sort frequencies descending
- int sorted[NUM_EXPERTS];
- memcpy(sorted, g_expert_freq[l], NUM_EXPERTS * sizeof(int));
- qsort(sorted, NUM_EXPERTS, sizeof(int), freq_cmp_desc);
+ int sorted[cfg.num_experts];
+ memcpy(sorted, &FREQ(l, 0), cfg.num_experts * sizeof(int));
+ qsort(sorted, cfg.num_experts, sizeof(int), freq_cmp_desc);
int unique = 0;
- for (int e = 0; e < NUM_EXPERTS; e++) {
+ for (int e = 0; e < cfg.num_experts; e++) {
if (sorted[e] > 0) unique++;
}
@@ -5589,7 +5912,7 @@ static void freq_print_analysis(int K) {
int cum = 0;
int top10_cov = 0, top30_cov = 0, top60_cov = 0;
int n_for_50 = 0, n_for_80 = 0, n_for_90 = 0;
- for (int e = 0; e < NUM_EXPERTS; e++) {
+ for (int e = 0; e < cfg.num_experts; e++) {
cum += sorted[e];
if (e == 9) top10_cov = cum;
if (e == 29) top30_cov = cum;
@@ -5615,10 +5938,10 @@ static void freq_print_analysis(int K) {
}
// Overall summary: average experts needed for 80% across all layers
- double avg_experts_80 = (double)experts_for_80_total / NUM_LAYERS;
+ double avg_experts_80 = (double)experts_for_80_total / cfg.num_layers;
// Expert size in GB: each expert is active_expert_size() bytes
double expert_gb = (double)active_expert_size() / (1024.0 * 1024.0 * 1024.0);
- double total_pin_gb = avg_experts_80 * NUM_LAYERS * expert_gb;
+ double total_pin_gb = avg_experts_80 * cfg.num_layers * expert_gb;
fprintf(stderr, "\n--- Overall Summary ---\n");
fprintf(stderr, "To achieve 80%% hit rate across all layers, need %d experts pinned "
@@ -5626,7 +5949,7 @@ static void freq_print_analysis(int K) {
experts_for_80_total, avg_experts_80, total_pin_gb);
fprintf(stderr, "Expert size: %zu bytes (%.3f MB), %d layers x %d experts = %d total\n",
active_expert_size(), (double)active_expert_size() / (1024.0 * 1024.0),
- NUM_LAYERS, NUM_EXPERTS, NUM_LAYERS * NUM_EXPERTS);
+ cfg.num_layers, cfg.num_experts, cfg.num_layers * cfg.num_experts);
}
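`freq_print_analysis` computes how many experts are needed to cover a given fraction of routing hits from descending-sorted frequencies. The cumulative-coverage calculation can be isolated as a sketch (function name is mine):

```c
#include <assert.h>

/* Given hit counts sorted descending, return the smallest number of experts
   whose cumulative hits reach `frac` of the total. */
static int experts_for_coverage(const int *sorted, int n, double frac) {
    long total = 0;
    for (int i = 0; i < n; i++) total += sorted[i];
    long cum = 0;
    for (int i = 0; i < n; i++) {
        cum += sorted[i];
        if ((double)cum >= frac * (double)total) return i + 1;
    }
    return n;
}
```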
#ifndef CHAT_MODE
@@ -5946,17 +6269,17 @@ static void sse_send_done(int fd, const char *request_id) {
static void sync_cpu_to_gpu_delta_state_serve(void **layer_states) {
if (!g_metal || !g_metal->delta_net_step || !layer_states) return;
int li = 0;
- for (int i = 0; i < NUM_LAYERS; i++) {
- if ((i + 1) % FULL_ATTN_INTERVAL == 0) continue;
+ for (int i = 0; i < cfg.num_layers; i++) {
+ if (cfg.is_full_attn[i]) continue;
if (!layer_states[i]) { li++; continue; }
LinearAttnState *la = (LinearAttnState *)layer_states[i];
- if (li < NUM_LINEAR_LAYERS) {
+ if (li < cfg.num_linear_layers) {
if (g_metal->buf_delta_state[li] && la->ssm_state)
memcpy([g_metal->buf_delta_state[li] contents], la->ssm_state,
- LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM * LINEAR_KEY_DIM * sizeof(float));
+ cfg.linear_num_v_heads * cfg.linear_value_dim * cfg.linear_key_dim * sizeof(float));
if (g_metal->buf_conv_state[li] && la->conv_state)
memcpy([g_metal->buf_conv_state[li] contents], la->conv_state,
- (CONV_KERNEL_SIZE - 1) * LINEAR_CONV_DIM * sizeof(float));
+ (cfg.conv_kernel_size - 1) * cfg.linear_conv_dim * sizeof(float));
}
li++;
}
@@ -5998,7 +6321,7 @@ static void serve_loop(
static uint64_t req_counter = 0;
// ---- System prompt cache: prefill system prompt once at startup ----
- // Tokenize the system prompt and run it through all 60 layers.
+ // Tokenize the system prompt and run it through all layers.
// Save the resulting KV cache + linear attention state as a snapshot.
// On each request, restore the snapshot instead of re-prefilling.
fprintf(stderr, "[serve] Pre-caching system prompt...\n");
@@ -6008,22 +6331,22 @@ static void serve_loop(
// Pre-embed all system prompt tokens
float *sys_embed_batch = NULL;
if (sys_pt->count > 1) {
- sys_embed_batch = malloc((size_t)sys_pt->count * HIDDEN_DIM * sizeof(float));
+ sys_embed_batch = malloc((size_t)sys_pt->count * cfg.hidden_dim * sizeof(float));
for (int i = 0; i < sys_pt->count; i++) {
- embed_lookup(wf, sys_pt->ids[i], sys_embed_batch + (size_t)i * HIDDEN_DIM);
+ embed_lookup(wf, sys_pt->ids[i], sys_embed_batch + (size_t)i * cfg.hidden_dim);
}
}
// Intermediate system prompt tokens: discard last-layer expert output
for (int i = 0; i < sys_pt->count - 1; i++) {
cache_telemetry_note_token();
if (sys_embed_batch) {
- memcpy(hidden, sys_embed_batch + (size_t)i * HIDDEN_DIM,
- HIDDEN_DIM * sizeof(float));
+ memcpy(hidden, sys_embed_batch + (size_t)i * cfg.hidden_dim,
+ cfg.hidden_dim * sizeof(float));
} else {
embed_lookup(wf, sys_pt->ids[i], hidden);
}
- for (int layer = 0; layer < NUM_LAYERS; layer++) {
- int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0);
+ for (int layer = 0; layer < cfg.num_layers; layer++) {
+ int is_full = cfg.is_full_attn[layer];
fused_layer_forward(wf, layer, hidden,
is_full ? kv_caches[layer] : NULL,
is_full ? NULL : layer_states[layer],
@@ -6038,13 +6361,13 @@ static void serve_loop(
{
cache_telemetry_note_token();
if (sys_embed_batch) {
- memcpy(hidden, sys_embed_batch + (size_t)(sys_pt->count - 1) * HIDDEN_DIM,
- HIDDEN_DIM * sizeof(float));
+ memcpy(hidden, sys_embed_batch + (size_t)(sys_pt->count - 1) * cfg.hidden_dim,
+ cfg.hidden_dim * sizeof(float));
} else {
embed_lookup(wf, sys_pt->ids[0], hidden);
}
- for (int layer = 0; layer < NUM_LAYERS; layer++) {
- int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0);
+ for (int layer = 0; layer < cfg.num_layers; layer++) {
+ int is_full = cfg.is_full_attn[layer];
fused_layer_forward(wf, layer, hidden,
is_full ? kv_caches[layer] : NULL,
is_full ? NULL : layer_states[layer],
@@ -6069,20 +6392,20 @@ static void serve_loop(
float *v_snapshot;
int len;
} KVSnapshot;
- KVSnapshot kv_snapshots[NUM_LAYERS];
+ KVSnapshot kv_snapshots[cfg.num_layers];
memset(kv_snapshots, 0, sizeof(kv_snapshots));
// Linear attention snapshots
- float *la_conv_snapshots[NUM_LAYERS];
- float *la_ssm_snapshots[NUM_LAYERS];
+ float *la_conv_snapshots[cfg.num_layers];
+ float *la_ssm_snapshots[cfg.num_layers];
memset(la_conv_snapshots, 0, sizeof(la_conv_snapshots));
memset(la_ssm_snapshots, 0, sizeof(la_ssm_snapshots));
- size_t kv_dim = NUM_KV_HEADS * HEAD_DIM;
- size_t conv_state_size = (CONV_KERNEL_SIZE - 1) * LINEAR_CONV_DIM * sizeof(float);
- size_t ssm_state_size = LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM * LINEAR_KEY_DIM * sizeof(float);
+ size_t kv_dim = cfg.num_kv_heads * cfg.head_dim;
+ size_t conv_state_size = (cfg.conv_kernel_size - 1) * cfg.linear_conv_dim * sizeof(float);
+ size_t ssm_state_size = cfg.linear_num_v_heads * cfg.linear_value_dim * cfg.linear_key_dim * sizeof(float);
- for (int i = 0; i < NUM_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_layers; i++) {
if (kv_caches[i]) {
size_t sz = sys_pos * kv_dim * sizeof(float);
kv_snapshots[i].k_snapshot = malloc(sz);
@@ -6100,19 +6423,18 @@ static void serve_loop(
}
}
// Also snapshot GPU delta-net state
- void *gpu_delta_snapshots[NUM_LINEAR_LAYERS];
- void *gpu_conv_snapshots[NUM_LINEAR_LAYERS];
- memset(gpu_delta_snapshots, 0, sizeof(gpu_delta_snapshots));
- memset(gpu_conv_snapshots, 0, sizeof(gpu_conv_snapshots));
+ void **gpu_delta_snapshots = calloc(cfg.num_linear_layers, sizeof(void *));
+ void **gpu_conv_snapshots = calloc(cfg.num_linear_layers, sizeof(void *));
+ // both snapshot arrays start zeroed by calloc
if (g_metal && g_metal->delta_net_step) {
- for (int i = 0; i < NUM_LINEAR_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_linear_layers; i++) {
if (g_metal->buf_delta_state[i]) {
- size_t sz = 64*128*128*sizeof(float);
+ size_t sz = (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float);
gpu_delta_snapshots[i] = malloc(sz);
memcpy(gpu_delta_snapshots[i], [g_metal->buf_delta_state[i] contents], sz);
}
if (g_metal->buf_conv_state[i]) {
- size_t sz = 3*12288*sizeof(float);
+ size_t sz = (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float);
gpu_conv_snapshots[i] = malloc(sz);
memcpy(gpu_conv_snapshots[i], [g_metal->buf_conv_state[i] contents], sz);
}
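The system-prompt cache above snapshots every per-layer state buffer once at startup, then memcpys it back at the start of each request instead of re-prefilling. The core take/restore pattern, reduced to a sketch (the struct and function names are mine, not the engine's):

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

/* One snapshotted state buffer: copy out once, copy back per request. */
typedef struct {
    float *data;
    size_t count;
} StateSnapshot;

static StateSnapshot snapshot_take(const float *state, size_t count) {
    StateSnapshot s;
    s.data = malloc(count * sizeof(float));
    s.count = count;
    memcpy(s.data, state, count * sizeof(float));
    return s;
}

static void snapshot_restore(float *state, const StateSnapshot *s) {
    memcpy(state, s->data, s->count * sizeof(float));
}
```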
@@ -6156,7 +6479,7 @@ static void serve_loop(
"Access-Control-Allow-Origin: *\r\n"
"Connection: close\r\n"
"\r\n"
- "{\"status\":\"ok\",\"model\":\"qwen3.5-397b-a17b\"}\n";
+ "{\"status\":\"ok\",\"model\":\"qwen3.5-35b-a3b\"}\n";
http_write_str(client_fd, resp);
free(reqbuf); close(client_fd);
continue;
@@ -6170,7 +6493,7 @@ static void serve_loop(
"Access-Control-Allow-Origin: *\r\n"
"Connection: close\r\n"
"\r\n"
- "{\"object\":\"list\",\"data\":[{\"id\":\"qwen3.5-397b-a17b\","
+ "{\"object\":\"list\",\"data\":[{\"id\":\"qwen3.5-35b-a3b\","
"\"object\":\"model\",\"owned_by\":\"local\"}]}\n";
http_write_str(client_fd, resp);
free(reqbuf); close(client_fd);
@@ -6247,7 +6570,7 @@ static void serve_loop(
// ---- Restore state from system prompt snapshot ----
// Instead of resetting to zero, restore to the cached system prompt state.
// This skips re-prefilling the system prompt tokens (~20 tokens, ~6s saved).
- for (int i = 0; i < NUM_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_layers; i++) {
if (kv_caches[i] && kv_snapshots[i].k_snapshot) {
size_t sz = sys_prompt_len * kv_dim * sizeof(float);
memcpy(kv_caches[i]->k_cache, kv_snapshots[i].k_snapshot, sz);
@@ -6255,8 +6578,8 @@ static void serve_loop(
kv_caches[i]->len = kv_snapshots[i].len;
// Also restore GPU KV mirror
if (g_metal) {
- int fa_idx = (i + 1) / FULL_ATTN_INTERVAL - 1;
- if (fa_idx >= 0 && fa_idx < NUM_FULL_ATTN_LAYERS) {
+ int fa_idx = cfg.full_attn_index[i];
+ if (fa_idx >= 0 && fa_idx < cfg.num_full_attn_layers) {
memcpy([g_metal->buf_kv_k[fa_idx] contents],
kv_snapshots[i].k_snapshot, sz);
memcpy([g_metal->buf_kv_v[fa_idx] contents],
@@ -6278,13 +6601,13 @@ static void serve_loop(
}
// Restore GPU delta-net state
if (g_metal && g_metal->delta_net_step) {
- for (int i = 0; i < NUM_LINEAR_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_linear_layers; i++) {
if (gpu_delta_snapshots[i] && g_metal->buf_delta_state[i])
memcpy([g_metal->buf_delta_state[i] contents],
- gpu_delta_snapshots[i], 64*128*128*sizeof(float));
+ gpu_delta_snapshots[i], (size_t)cfg.linear_num_v_heads*cfg.linear_value_dim*cfg.linear_key_dim*sizeof(float));
if (gpu_conv_snapshots[i] && g_metal->buf_conv_state[i])
memcpy([g_metal->buf_conv_state[i] contents],
- gpu_conv_snapshots[i], 3*12288*sizeof(float));
+ gpu_conv_snapshots[i], (cfg.conv_kernel_size-1)*(size_t)cfg.linear_conv_dim*sizeof(float));
}
} else {
reset_delta_net_state();
@@ -6308,22 +6631,22 @@ static void serve_loop(
// Pre-embed all request tokens
float *serve_embed_batch = NULL;
if (pt->count > 1) {
- serve_embed_batch = malloc((size_t)pt->count * HIDDEN_DIM * sizeof(float));
+ serve_embed_batch = malloc((size_t)pt->count * cfg.hidden_dim * sizeof(float));
for (int i = 0; i < pt->count; i++) {
- embed_lookup(wf, pt->ids[i], serve_embed_batch + (size_t)i * HIDDEN_DIM);
+ embed_lookup(wf, pt->ids[i], serve_embed_batch + (size_t)i * cfg.hidden_dim);
}
}
// Intermediate prefill tokens: discard last-layer expert output
for (int i = 0; i < pt->count - 1; i++) {
cache_telemetry_note_token();
if (serve_embed_batch) {
- memcpy(hidden, serve_embed_batch + (size_t)i * HIDDEN_DIM,
- HIDDEN_DIM * sizeof(float));
+ memcpy(hidden, serve_embed_batch + (size_t)i * cfg.hidden_dim,
+ cfg.hidden_dim * sizeof(float));
} else {
embed_lookup(wf, pt->ids[i], hidden);
}
- for (int layer = 0; layer < NUM_LAYERS; layer++) {
- int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0);
+ for (int layer = 0; layer < cfg.num_layers; layer++) {
+ int is_full = cfg.is_full_attn[layer];
fused_layer_forward(wf, layer, hidden,
is_full ? kv_caches[layer] : NULL,
is_full ? NULL : layer_states[layer],
@@ -6338,13 +6661,13 @@ static void serve_loop(
{
cache_telemetry_note_token();
if (serve_embed_batch) {
- memcpy(hidden, serve_embed_batch + (size_t)(pt->count - 1) * HIDDEN_DIM,
- HIDDEN_DIM * sizeof(float));
+ memcpy(hidden, serve_embed_batch + (size_t)(pt->count - 1) * cfg.hidden_dim,
+ cfg.hidden_dim * sizeof(float));
} else {
embed_lookup(wf, pt->ids[0], hidden);
}
- for (int layer = 0; layer < NUM_LAYERS; layer++) {
- int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0);
+ for (int layer = 0; layer < cfg.num_layers; layer++) {
+ int is_full = cfg.is_full_attn[layer];
fused_layer_forward(wf, layer, hidden,
is_full ? kv_caches[layer] : NULL,
is_full ? NULL : layer_states[layer],
@@ -6362,13 +6685,13 @@ static void serve_loop(
// ---- Final norm + LM head for first token ----
if (final_norm_w) {
- float *normed = malloc(HIDDEN_DIM * sizeof(float));
- cpu_rms_norm(hidden, final_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS);
- memcpy(hidden, normed, HIDDEN_DIM * sizeof(float));
+ float *normed = malloc(cfg.hidden_dim * sizeof(float));
+ cpu_rms_norm(hidden, final_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps);
+ memcpy(hidden, normed, cfg.hidden_dim * sizeof(float));
free(normed);
}
lm_head_forward(wf, hidden, logits);
- int next_token = cpu_argmax(logits, VOCAB_SIZE);
+ int next_token = cpu_argmax(logits, cfg.vocab_size);
// ---- Auto-regressive generation with SSE streaming ----
if (g_pred_enabled) {
@@ -6384,12 +6707,12 @@ static void serve_loop(
int gen_resp_len = 0;
for (int gen = 0; gen < max_gen; gen++) {
- if (next_token == EOS_TOKEN_1 || next_token == EOS_TOKEN_2) {
+ if (next_token == cfg.eos_token_ids[0] || next_token == cfg.eos_token_ids[1]) {
// Feed EOS through the model so session state includes it
cache_telemetry_note_token();
embed_lookup(wf, next_token, hidden);
- for (int layer = 0; layer < NUM_LAYERS; layer++) {
- int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0);
+ for (int layer = 0; layer < cfg.num_layers; layer++) {
+ int is_full = cfg.is_full_attn[layer];
fused_layer_forward(wf, layer, hidden,
is_full ? kv_caches[layer] : NULL,
is_full ? NULL : layer_states[layer],
@@ -6403,12 +6726,12 @@ static void serve_loop(
}
// Think budget enforcement
- if (next_token == THINK_START_TOKEN) in_think = 1;
- if (next_token == THINK_END_TOKEN) in_think = 0;
+ if (next_token == cfg.think_start_token) in_think = 1;
+ if (next_token == cfg.think_end_token) in_think = 0;
if (in_think) {
think_tokens++;
if (g_think_budget > 0 && think_tokens >= g_think_budget) {
- next_token = THINK_END_TOKEN; // force end thinking
+ next_token = cfg.think_end_token; // force end thinking
in_think = 0;
}
}
@@ -6430,8 +6753,8 @@ static void serve_loop(
// Generate next
cache_telemetry_note_token();
embed_lookup(wf, next_token, hidden);
- for (int layer = 0; layer < NUM_LAYERS; layer++) {
- int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0);
+ for (int layer = 0; layer < cfg.num_layers; layer++) {
+ int is_full = cfg.is_full_attn[layer];
fused_layer_forward(wf, layer, hidden,
is_full ? kv_caches[layer] : NULL,
is_full ? NULL : layer_states[layer],
@@ -6443,13 +6766,13 @@ static void serve_loop(
pos++;
if (final_norm_w) {
- float *normed = malloc(HIDDEN_DIM * sizeof(float));
- cpu_rms_norm(hidden, final_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS);
- memcpy(hidden, normed, HIDDEN_DIM * sizeof(float));
+ float *normed = malloc(cfg.hidden_dim * sizeof(float));
+ cpu_rms_norm(hidden, final_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps);
+ memcpy(hidden, normed, cfg.hidden_dim * sizeof(float));
free(normed);
}
lm_head_forward(wf, hidden, logits);
- next_token = cpu_argmax(logits, VOCAB_SIZE);
+ next_token = cpu_argmax(logits, cfg.vocab_size);
}
sse_send_done(client_fd, request_id);
@@ -6523,14 +6846,14 @@ static void print_usage(const char *prog) {
int main(int argc, char **argv) {
@autoreleasepool {
- const char *model_path = MODEL_PATH_DEFAULT;
+ const char *model_path = getenv("FLASH_MOE_MODEL");
const char *weights_path = NULL;
const char *manifest_path = NULL;
const char *vocab_path = NULL;
const char *prompt_tokens_path = NULL;
const char *prompt_text = NULL;
int max_tokens = 20;
- int K = 4;
+ int K = 8;
int cache_entries = 0; // default 0: trust OS page cache (38% faster than Metal LRU)
int malloc_cache_entries = 0; // 0 = disabled (override with --malloc-cache)
int serve_port = 0; // 0 = disabled, >0 = HTTP serve mode
@@ -6596,6 +6919,11 @@ int main(int argc, char **argv) {
}
}
+ // ---- Load model configuration from HF config.json ----
+ load_model_config(model_path ? model_path : "");
+ alloc_tracking_arrays();
+ g_deferred.h_mid = calloc(cfg.hidden_dim, sizeof(float));
+
// Build default paths
char default_weights[1024], default_manifest[1024], default_vocab[1024];
@@ -6648,7 +6976,8 @@ int main(int argc, char **argv) {
g_expert_cache = expert_cache_new(g_metal->device, cache_entries);
}
- printf("=== Qwen3.5-397B-A17B Metal Inference Engine ===\n");
+ printf("=== Flash-MoE Metal Inference Engine ===\n");
+ printf("Config: %s/config.json\n", cfg.model_path);
printf("Model: %s\n", model_path);
printf("Weights: %s\n", weights_path);
printf("Manifest: %s\n", manifest_path);
@@ -6741,16 +7070,16 @@ int main(int argc, char **argv) {
// Seen-expert bitset tracks which (layer, expert) pairs have been read before.
// First read goes through cold fd (no page cache pollution).
// Subsequent reads go through warm fd (page cache hit = 32 GB/s vs 5.5 GB/s).
- int layer_fds[NUM_LAYERS];
- int layer_fds_cold[NUM_LAYERS];
- void *layer_mmaps[NUM_LAYERS];
- size_t layer_mmap_sizes[NUM_LAYERS];
+    int *layer_fds = calloc(cfg.num_layers, sizeof(int));
+    int *layer_fds_cold = calloc(cfg.num_layers, sizeof(int));
+    void **layer_mmaps = calloc(cfg.num_layers, sizeof(void *));
+    size_t *layer_mmap_sizes = calloc(cfg.num_layers, sizeof(size_t));
+    for (int i = 0; i < cfg.num_layers; i++) layer_mmaps[i] = MAP_FAILED; // cleanup tests != MAP_FAILED, not NULL
int expert_layers_available = 0;
// Reset the global seen-expert bitset
- memset(g_expert_seen, 0, sizeof(g_expert_seen));
+ memset(g_expert_seen, 0, cfg.num_layers * ((cfg.num_experts + 7) / 8));
- for (int i = 0; i < NUM_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_layers; i++) {
char path[1024];
snprintf(path, sizeof(path), "%s/%s/layer_%02d.bin", model_path,
g_use_2bit ? "packed_experts_2bit" : "packed_experts", i);
@@ -6777,7 +7106,7 @@ int main(int argc, char **argv) {
}
}
}
- printf("[experts] %d/%d packed layer files available (mmap'd)\n", expert_layers_available, NUM_LAYERS);
+ printf("[experts] %d/%d packed layer files available (mmap'd)\n", expert_layers_available, cfg.num_layers);
// ---- LZ4 compressed experts: auto-detect and load ----
{
@@ -6785,16 +7114,16 @@ int main(int argc, char **argv) {
snprintf(lz4_probe, sizeof(lz4_probe), "%s/packed_experts_lz4/layer_00.bin", model_path);
if (!g_use_2bit && access(lz4_probe, R_OK) == 0) {
int lz4_layers = 0;
- for (int i = 0; i < NUM_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_layers; i++) {
char lz4_path[1024];
snprintf(lz4_path, sizeof(lz4_path), "%s/packed_experts_lz4/layer_%02d.bin", model_path, i);
int lz4_fd = open(lz4_path, O_RDONLY);
if (lz4_fd >= 0) {
- // Load index header (512 entries × 16 bytes = 8KB)
- g_lz4_index[i] = malloc(NUM_EXPERTS * sizeof(LZ4IndexEntry));
+ // Load index header (cfg.num_experts entries × 16 bytes)
+ g_lz4_index[i] = malloc(cfg.num_experts * sizeof(LZ4IndexEntry));
ssize_t nr = pread(lz4_fd, g_lz4_index[i],
- NUM_EXPERTS * sizeof(LZ4IndexEntry), 0);
- if (nr == NUM_EXPERTS * (ssize_t)sizeof(LZ4IndexEntry)) {
+ cfg.num_experts * sizeof(LZ4IndexEntry), 0);
+ if (nr == cfg.num_experts * (ssize_t)sizeof(LZ4IndexEntry)) {
// Replace the raw fd with the LZ4 fd
close(layer_fds[i]);
layer_fds[i] = lz4_fd;
@@ -6811,10 +7140,10 @@ int main(int argc, char **argv) {
g_use_lz4 = 1;
// Allocate compressed read buffers (one per expert slot)
for (int k = 0; k < MAX_K; k++) {
- g_lz4_comp_bufs[k] = malloc(EXPERT_SIZE + 4096);
+ g_lz4_comp_bufs[k] = malloc(cfg.expert_size_4bit + 4096);
}
printf("[lz4] %d/%d layers using LZ4 compressed experts\n",
- lz4_layers, NUM_LAYERS);
+ lz4_layers, cfg.num_layers);
}
}
}
@@ -6827,7 +7156,7 @@ int main(int argc, char **argv) {
// Warm page cache hint
if (expert_layers_available > 0) {
double t_warm = now_ms();
- for (int i = 0; i < NUM_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_layers; i++) {
if (layer_fds[i] >= 0) {
char dummy[4096];
pread(layer_fds[i], dummy, sizeof(dummy), 0);
@@ -6837,11 +7166,11 @@ int main(int argc, char **argv) {
}
// ---- Allocate per-layer state ----
- void **layer_states = calloc(NUM_LAYERS, sizeof(void *));
- KVCache **kv_caches = calloc(NUM_LAYERS, sizeof(KVCache *));
+ void **layer_states = calloc(cfg.num_layers, sizeof(void *));
+ KVCache **kv_caches = calloc(cfg.num_layers, sizeof(KVCache *));
- for (int i = 0; i < NUM_LAYERS; i++) {
- int is_full = ((i + 1) % FULL_ATTN_INTERVAL == 0);
+ for (int i = 0; i < cfg.num_layers; i++) {
+ int is_full = cfg.is_full_attn[i];
if (is_full) {
kv_caches[i] = kv_cache_new();
} else {
@@ -6853,8 +7182,8 @@ int main(int argc, char **argv) {
printf("[init] Setup: %.1f ms\n\n", t_init - t0);
// ---- Allocate working buffers ----
- float *hidden = calloc(HIDDEN_DIM, sizeof(float));
- float *logits = calloc(VOCAB_SIZE, sizeof(float));
+ float *hidden = calloc(cfg.hidden_dim, sizeof(float));
+ float *logits = calloc(cfg.vocab_size, sizeof(float));
uint16_t *final_norm_w = get_tensor_ptr(wf, "model.norm.weight");
// ---- Serve mode: enter HTTP server loop (never returns) ----
@@ -6880,10 +7209,10 @@ int main(int argc, char **argv) {
// embed_lookup with GPU work, and enables the optimized prefill loop below.
float *embed_batch = NULL;
if (pt->count > 1) {
- embed_batch = malloc((size_t)pt->count * HIDDEN_DIM * sizeof(float));
+ embed_batch = malloc((size_t)pt->count * cfg.hidden_dim * sizeof(float));
double t_embed = now_ms();
for (int i = 0; i < pt->count; i++) {
- embed_lookup(wf, pt->ids[i], embed_batch + (size_t)i * HIDDEN_DIM);
+ embed_lookup(wf, pt->ids[i], embed_batch + (size_t)i * cfg.hidden_dim);
}
double embed_ms = now_ms() - t_embed;
printf(" [prefill] batch embed %d tokens: %.1f ms\n", pt->count, embed_ms);
@@ -6905,12 +7234,12 @@ int main(int argc, char **argv) {
// Load pre-embedded token from batch buffer
cache_telemetry_note_token();
- memcpy(hidden, embed_batch + (size_t)token_idx * HIDDEN_DIM,
- HIDDEN_DIM * sizeof(float));
+ memcpy(hidden, embed_batch + (size_t)token_idx * cfg.hidden_dim,
+ cfg.hidden_dim * sizeof(float));
// Run through all 60 transformer layers
- for (int layer = 0; layer < NUM_LAYERS; layer++) {
- int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0);
+ for (int layer = 0; layer < cfg.num_layers; layer++) {
+ int is_full = cfg.is_full_attn[layer];
fused_layer_forward(wf, layer, hidden,
is_full ? kv_caches[layer] : NULL,
is_full ? NULL : layer_states[layer],
@@ -6941,14 +7270,14 @@ int main(int argc, char **argv) {
{
cache_telemetry_note_token();
if (embed_batch) {
- memcpy(hidden, embed_batch + (size_t)(pt->count - 1) * HIDDEN_DIM,
- HIDDEN_DIM * sizeof(float));
+ memcpy(hidden, embed_batch + (size_t)(pt->count - 1) * cfg.hidden_dim,
+ cfg.hidden_dim * sizeof(float));
} else {
embed_lookup(wf, pt->ids[0], hidden);
}
- for (int layer = 0; layer < NUM_LAYERS; layer++) {
- int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0);
+ for (int layer = 0; layer < cfg.num_layers; layer++) {
+ int is_full = cfg.is_full_attn[layer];
fused_layer_forward(wf, layer, hidden,
is_full ? kv_caches[layer] : NULL,
is_full ? NULL : layer_states[layer],
@@ -6965,9 +7294,9 @@ int main(int argc, char **argv) {
// ---- Final norm ----
if (final_norm_w) {
- float *normed = malloc(HIDDEN_DIM * sizeof(float));
- cpu_rms_norm(hidden, final_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS);
- memcpy(hidden, normed, HIDDEN_DIM * sizeof(float));
+ float *normed = malloc(cfg.hidden_dim * sizeof(float));
+ cpu_rms_norm(hidden, final_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps);
+ memcpy(hidden, normed, cfg.hidden_dim * sizeof(float));
free(normed);
}
@@ -6977,7 +7306,7 @@ int main(int argc, char **argv) {
double lm_ms = now_ms() - t_lm;
// ---- Sample first token ----
- int next_token = cpu_argmax(logits, VOCAB_SIZE);
+ int next_token = cpu_argmax(logits, cfg.vocab_size);
double ttft_ms = now_ms() - t0;
// Debug: show top-5 logits for first token
@@ -6985,7 +7314,7 @@ int main(int argc, char **argv) {
// Find top 5 manually
int top5[5] = {0,0,0,0,0};
float topv[5] = {-1e30f,-1e30f,-1e30f,-1e30f,-1e30f};
- for (int i = 0; i < VOCAB_SIZE; i++) {
+ for (int i = 0; i < cfg.vocab_size; i++) {
int min_k = 0;
for (int k = 1; k < 5; k++) if (topv[k] < topv[min_k]) min_k = k;
if (logits[i] > topv[min_k]) { topv[min_k] = logits[i]; top5[min_k] = i; }
@@ -6996,7 +7325,7 @@ int main(int argc, char **argv) {
top5[i], decode_token(vocab, top5[i]), topv[i]);
}
fprintf(stderr, "[debug] hidden rms after final_norm=%.4f, logits rms=%.4f\n",
- vec_rms(hidden, HIDDEN_DIM), vec_rms(logits, VOCAB_SIZE));
+ vec_rms(hidden, cfg.hidden_dim), vec_rms(logits, cfg.vocab_size));
}
printf("[ttft] %.0f ms (prefill %d tokens + lm_head %.0f ms)\n",
ttft_ms, pt->count, lm_ms);
@@ -7006,7 +7335,7 @@ int main(int argc, char **argv) {
fflush(stdout);
int total_generated = 1;
- int in_think = (next_token == THINK_START_TOKEN) ? 1 : 0;
+ int in_think = (next_token == cfg.think_start_token) ? 1 : 0;
int think_tokens = 0;
// ---- Auto-regressive generation ----
@@ -7019,23 +7348,23 @@ int main(int argc, char **argv) {
double t_gen_start = now_ms();
// Check EOS
- if (next_token == EOS_TOKEN_1 || next_token == EOS_TOKEN_2) {
+ if (next_token == cfg.eos_token_ids[0] || next_token == cfg.eos_token_ids[1]) {
fprintf(stderr, "\n[eos] Token %d at position %d\n", next_token, gen);
break;
}
// Think budget enforcement
- if (next_token == THINK_START_TOKEN) in_think = 1;
- if (next_token == THINK_END_TOKEN) in_think = 0;
+ if (next_token == cfg.think_start_token) in_think = 1;
+ if (next_token == cfg.think_end_token) in_think = 0;
if (in_think) think_tokens++;
// Embed the just-generated token (next iteration)
cache_telemetry_note_token();
embed_lookup(wf, next_token, hidden);
- // Run 60 layers (fused: 1+K cmd buffers per layer)
- for (int layer = 0; layer < NUM_LAYERS; layer++) {
- int is_full = ((layer + 1) % FULL_ATTN_INTERVAL == 0);
+        // Run all cfg.num_layers layers (fused: 1+K cmd buffers per layer)
+ for (int layer = 0; layer < cfg.num_layers; layer++) {
+ int is_full = cfg.is_full_attn[layer];
fused_layer_forward(wf, layer, hidden,
is_full ? kv_caches[layer] : NULL,
is_full ? NULL : layer_states[layer],
@@ -7049,9 +7378,9 @@ int main(int argc, char **argv) {
// Final norm
if (final_norm_w) {
- float *normed = malloc(HIDDEN_DIM * sizeof(float));
- cpu_rms_norm(hidden, final_norm_w, normed, HIDDEN_DIM, RMS_NORM_EPS);
- memcpy(hidden, normed, HIDDEN_DIM * sizeof(float));
+ float *normed = malloc(cfg.hidden_dim * sizeof(float));
+ cpu_rms_norm(hidden, final_norm_w, normed, cfg.hidden_dim, cfg.rms_norm_eps);
+ memcpy(hidden, normed, cfg.hidden_dim * sizeof(float));
free(normed);
}
@@ -7059,11 +7388,11 @@ int main(int argc, char **argv) {
lm_head_forward(wf, hidden, logits);
// Greedy sample
- next_token = cpu_argmax(logits, VOCAB_SIZE);
+ next_token = cpu_argmax(logits, cfg.vocab_size);
// Think budget: force end thinking if over budget
if (in_think && g_think_budget > 0 && think_tokens >= g_think_budget) {
- next_token = THINK_END_TOKEN;
+ next_token = cfg.think_end_token;
in_think = 0;
}
total_generated++;
@@ -7091,7 +7420,7 @@ int main(int argc, char **argv) {
printf("Generation: %.1f s (%.2f tok/s)\n",
gen_time / 1000.0, (total_generated - 1) * 1000.0 / gen_time);
}
- printf("Config: K=%d experts, %d layers\n", K, NUM_LAYERS);
+ printf("Config: K=%d experts, %d layers\n", K, cfg.num_layers);
if (g_expert_cache) {
uint64_t total = g_expert_cache->hits + g_expert_cache->misses;
printf("Expert cache: %llu hits, %llu misses (%.1f%% hit rate), %d/%d entries used\n",
@@ -7133,7 +7462,7 @@ int main(int argc, char **argv) {
expert_cache_free(g_expert_cache);
g_expert_cache = NULL;
}
- for (int i = 0; i < NUM_LAYERS; i++) {
+ for (int i = 0; i < cfg.num_layers; i++) {
if (kv_caches[i]) kv_cache_free(kv_caches[i]);
if (layer_states[i]) linear_attn_state_free(layer_states[i]);
if (layer_mmaps[i] != MAP_FAILED) munmap(layer_mmaps[i], layer_mmap_sizes[i]);
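The hunks above swap compile-time constants (`NUM_LAYERS`, `FULL_ATTN_INTERVAL`, `HIDDEN_DIM`, `VOCAB_SIZE`) for a runtime `cfg` struct that `load_model_config` fills from the model's `config.json`. The `cfg.is_full_attn[layer]` table it builds can be sketched in Python; the field names `layer_types` and `full_attention_interval` are assumptions about the config schema here, not taken from the actual loader:

```python
import json

def full_attn_pattern(config_path: str) -> list:
    """Recreate something like cfg.is_full_attn from a HF config.json.

    Field names are assumptions: a config may list per-layer types
    explicitly, or give an interval; this sketch handles both shapes.
    """
    with open(config_path) as f:
        tc = json.load(f).get("text_config", {})
    n = tc["num_hidden_layers"]
    if "layer_types" in tc:  # explicit per-layer list
        return [t == "full_attention" for t in tc["layer_types"]]
    interval = tc.get("full_attention_interval", 4)  # assumed default
    # Fallback: the old compile-time rule, (layer + 1) % interval == 0
    return [(i + 1) % interval == 0 for i in range(n)]
```

The fallback branch reproduces exactly what the removed `(layer + 1) % FULL_ATTN_INTERVAL == 0` expressions computed.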
diff --git a/model_manager.py b/model_manager.py
new file mode 100755
index 0000000..7a2382a
--- /dev/null
+++ b/model_manager.py
@@ -0,0 +1,398 @@
+#!/usr/bin/env python3
+"""Flash-MoE Model Manager — list, search, and download compatible models.
+
+Compatible models: Qwen3.5 MoE with MLX quantization (model_type: qwen3_5_moe).
+These include GatedDeltaNet linear attention + full attention layers with
+switch_mlp expert routing.
+
+Usage:
+ python model_manager.py # List local + search remote
+ python model_manager.py --local # List local models only
+ python model_manager.py --search # Search HuggingFace for compatible models
+ python model_manager.py --download # Download a specific model
+ python model_manager.py --check # Check if a local model is compatible
+"""
+
+import argparse
+import json
+import os
+import re
+import struct
+import subprocess
+import sys
+from pathlib import Path
+
+try:
+    import requests
+except ImportError:
+    requests = None
+
+HF_CACHE = Path(os.path.expanduser("~/.cache/huggingface/hub"))
+HF_API = "https://huggingface.co/api"
+
+# Known compatible model types
+COMPATIBLE_MODEL_TYPES = {"qwen3_5_moe"}
+
+# Search queries — MLX-quantized Qwen3.5 models
+SEARCH_QUERIES = [
+    "mlx-community Qwen3.5",
+    "mlx Qwen3.5 MoE",
+    "lmstudio-community Qwen3.5 MLX",
+]
+
+# MoE model name patterns: "35B-A3B", "122B-A10B", "397B-A17B" etc.
+# The "-A<n>B" suffix gives the active-parameter count, i.e. a MoE model
+MOE_PATTERN = re.compile(r'\d+B-A\d+B')
+
+
+def find_config_json(model_path: Path) -> Path | None:
+ """Find config.json in a model directory, handling HF cache layout."""
+ direct = model_path / "config.json"
+ if direct.exists():
+ return direct
+ snapshots = model_path / "snapshots"
+ if snapshots.exists():
+ for snap in sorted(snapshots.iterdir(), reverse=True):
+ candidate = snap / "config.json"
+ if candidate.exists():
+ return candidate
+ return None
+
+
+def check_compatibility(model_path: Path) -> dict:
+ """Check if a local model is compatible with Flash-MoE.
+
+ Returns a dict with:
+ compatible: bool
+ reason: str (if not compatible)
+ info: dict (model details if compatible)
+ """
+ config_path = find_config_json(model_path)
+ if not config_path:
+ return {"compatible": False, "reason": "No config.json found"}
+
+ with open(config_path) as f:
+ config = json.load(f)
+
+ model_type = config.get("model_type", "")
+ if model_type not in COMPATIBLE_MODEL_TYPES:
+ return {
+ "compatible": False,
+ "reason": f"Incompatible model_type: {model_type} (need: {', '.join(COMPATIBLE_MODEL_TYPES)})",
+ }
+
+ tc = config.get("text_config", {})
+ if not tc:
+ return {"compatible": False, "reason": "Missing text_config in config.json"}
+
+ # Check for required fields
+ required = [
+ "hidden_size", "num_hidden_layers", "num_experts",
+ "num_experts_per_tok", "moe_intermediate_size",
+ ]
+ missing = [k for k in required if k not in tc]
+ if missing:
+ return {"compatible": False, "reason": f"Missing fields: {', '.join(missing)}"}
+
+ # Check quantization
+ qc = config.get("quantization_config", config.get("quantization", {}))
+ bits = qc.get("bits", "?")
+ group_size = qc.get("group_size", "?")
+
+ # Check for packed experts
+ model_dir = config_path.parent
+    has_packed = (
+        (model_dir / "packed_experts").exists()
+        or (model_dir.parent / "packed_experts").exists()
+    )
+
+ # Check for extracted weights
+ # Look relative to cwd (where infer runs from)
+ has_weights = Path("metal_infer/model_weights.bin").exists() or Path("model_weights.bin").exists()
+
+ info = {
+ "model_type": model_type,
+ "hidden_size": tc.get("hidden_size"),
+ "num_layers": tc.get("num_hidden_layers"),
+ "num_experts": tc.get("num_experts"),
+ "experts_per_tok": tc.get("num_experts_per_tok"),
+ "moe_intermediate": tc.get("moe_intermediate_size"),
+ "vocab_size": tc.get("vocab_size"),
+ "bits": bits,
+ "group_size": group_size,
+ "has_packed_experts": has_packed,
+ "has_extracted_weights": has_weights,
+ "config_path": str(config_path),
+ }
+
+ # Estimate sizes
+ ne = tc.get("num_experts", 0)
+ nl = tc.get("num_hidden_layers", 0)
+ mid = tc.get("moe_intermediate_size", 0)
+ hid = tc.get("hidden_size", 0)
+    if isinstance(bits, int) and bits > 0 and isinstance(group_size, int) and group_size > 0:
+ vals_per_u32 = 32 // bits
+ expert_bytes = 0
+ # gate + up: [mid, hid]
+ for _ in range(2):
+ w = mid * ((hid + vals_per_u32 - 1) // vals_per_u32) * 4
+ s = mid * ((hid + group_size - 1) // group_size) * 2
+ expert_bytes += w + s + s # weight + scales + biases
+ # down: [hid, mid]
+ w = hid * ((mid + vals_per_u32 - 1) // vals_per_u32) * 4
+ s = hid * ((mid + group_size - 1) // group_size) * 2
+ expert_bytes += w + s + s
+
+ total_expert_gb = ne * nl * expert_bytes / (1024**3)
+ active_per_token_mb = tc.get("num_experts_per_tok", 0) * expert_bytes / (1024**2)
+ info["expert_size_bytes"] = expert_bytes
+ info["total_expert_disk_gb"] = round(total_expert_gb, 1)
+ info["active_per_token_mb"] = round(active_per_token_mb, 1)
+
+    # Rough total-parameter estimate from expert weights (they dominate MoE models)
+    total_params_b = ne * nl * mid * hid * 3 / 1e9  # gate + up + down projections
+ info["approx_total_params"] = f"~{total_params_b:.0f}B" if total_params_b > 1 else f"~{total_params_b*1000:.0f}M"
+
+ return {"compatible": True, "info": info}
+
+
+def list_local_models():
+ """List locally cached HuggingFace models and check compatibility."""
+ if not HF_CACHE.exists():
+ print("No HuggingFace cache found at", HF_CACHE)
+ return []
+
+ models = []
+ for entry in sorted(HF_CACHE.iterdir()):
+ if not entry.name.startswith("models--"):
+ continue
+ # Convert models--org--name to org/name
+ parts = entry.name.split("--", 2)
+ if len(parts) >= 3:
+ repo_id = f"{parts[1]}/{parts[2]}"
+ else:
+ repo_id = entry.name
+
+ result = check_compatibility(entry)
+ result["repo_id"] = repo_id
+ result["path"] = str(entry)
+ models.append(result)
+
+ return models
+
+
+def search_remote_models():
+ """Search HuggingFace for compatible Qwen3.5 MoE models."""
+ if not requests:
+ print("Install 'requests' to search HuggingFace: pip install requests")
+ return []
+
+ seen = set()
+ results = []
+
+ for query in SEARCH_QUERIES:
+ try:
+ resp = requests.get(
+ f"{HF_API}/models",
+ params={
+ "search": query,
+ "limit": 30,
+ "sort": "downloads",
+ "direction": -1,
+ },
+ timeout=10,
+ )
+ resp.raise_for_status()
+ for model in resp.json():
+ repo_id = model.get("id", "")
+ if repo_id in seen:
+ continue
+ seen.add(repo_id)
+
+ # Filter: must have "qwen" and "moe" or "3.5" indicators
+ lower = repo_id.lower()
+ tags = [t.lower() for t in model.get("tags", [])]
+
+ is_qwen35 = "qwen3.5" in lower or "qwen3_5" in lower
+ is_mlx = "mlx" in lower or "mlx" in " ".join(tags)
+ is_moe = bool(MOE_PATTERN.search(repo_id))
+
+ # We need: Qwen3.5 + MLX quantized + MoE architecture
+ if is_qwen35 and is_mlx and is_moe:
+ # Extract quant info from name
+ quant = ""
+ for q in ["3bit", "4bit", "6bit", "8bit"]:
+ if q in lower:
+ quant = q
+ break
+
+ results.append({
+ "repo_id": repo_id,
+ "downloads": model.get("downloads", 0),
+ "likes": model.get("likes", 0),
+ "quant": quant,
+ "last_modified": model.get("lastModified", "")[:10],
+ })
+ except Exception as e:
+ print(f"Warning: search failed for '{query}': {e}", file=sys.stderr)
+
+ return results
+
+
+def download_model(repo_id: str):
+ """Download a model from HuggingFace."""
+ # Try huggingface-cli first
+ hf_cli = None
+ for cmd in ["huggingface-cli", "hf"]:
+ try:
+ subprocess.run([cmd, "--help"], capture_output=True, check=True)
+ hf_cli = cmd
+ break
+ except (FileNotFoundError, subprocess.CalledProcessError):
+ continue
+
+ if hf_cli:
+ print(f"Downloading {repo_id} via {hf_cli}...")
+ subprocess.run([hf_cli, "download", repo_id], check=True)
+ else:
+ # Fall back to Python
+ try:
+ from huggingface_hub import snapshot_download
+ print(f"Downloading {repo_id} via huggingface_hub...")
+ path = snapshot_download(repo_id)
+ print(f"Downloaded to: {path}")
+ except ImportError:
+ print("ERROR: No download tool available.")
+ print("Install one of:")
+ print(" pip install huggingface-hub # Python library")
+ print(" pip install huggingface-cli # CLI tool")
+ print(f"\nOr manually: git clone https://huggingface.co/{repo_id}")
+ sys.exit(1)
+
+
+def format_size(gb: float) -> str:
+ if gb >= 1:
+ return f"{gb:.1f} GB"
+ return f"{gb * 1024:.0f} MB"
+
+
+def print_model_info(info: dict, indent: str = " "):
+ """Print formatted model info."""
+ print(f"{indent}Architecture: {info['num_layers']} layers, "
+ f"hidden={info['hidden_size']}, "
+ f"{info['num_experts']} experts (K={info['experts_per_tok']})")
+ print(f"{indent}Quantization: {info['bits']}-bit, group_size={info['group_size']}")
+ if "total_expert_disk_gb" in info:
+ print(f"{indent}Expert data: {format_size(info['total_expert_disk_gb'])} on disk, "
+ f"{info['active_per_token_mb']:.1f} MB active/token")
+ if "approx_total_params" in info:
+ print(f"{indent}Parameters: {info['approx_total_params']} total")
+
+ # Readiness indicators
+ ready = True
+ if not info.get("has_packed_experts"):
+ print(f"{indent}Packed experts: NOT FOUND (run repack_experts.py)")
+ ready = False
+ else:
+ print(f"{indent}Packed experts: OK")
+ if not info.get("has_extracted_weights"):
+ print(f"{indent}Weights file: NOT FOUND (run extract_weights.py)")
+ ready = False
+ else:
+ print(f"{indent}Weights file: OK")
+
+ if ready:
+ print(f"{indent}Status: READY TO RUN")
+ else:
+ print(f"{indent}Status: NEEDS PREPARATION (see above)")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Flash-MoE Model Manager — list, search, and download compatible models"
+ )
+ parser.add_argument("--local", action="store_true", help="List local models only")
+ parser.add_argument("--search", action="store_true", help="Search HuggingFace only")
+ parser.add_argument("--download", type=str, metavar="REPO", help="Download a model (e.g. mlx-community/Qwen3.5-35B-A3B-4bit)")
+ parser.add_argument("--check", type=str, metavar="PATH", help="Check if a local model is compatible")
+ args = parser.parse_args()
+
+ if args.check:
+ path = Path(args.check).expanduser()
+ result = check_compatibility(path)
+ if result["compatible"]:
+ print(f"COMPATIBLE: {path}")
+ print_model_info(result["info"])
+ else:
+ print(f"NOT COMPATIBLE: {result['reason']}")
+ return
+
+ if args.download:
+ download_model(args.download)
+ # Check compatibility after download
+ cache_name = "models--" + args.download.replace("/", "--")
+ cache_path = HF_CACHE / cache_name
+ if cache_path.exists():
+ result = check_compatibility(cache_path)
+ if result["compatible"]:
+ print(f"\nModel is compatible with Flash-MoE!")
+ print_model_info(result["info"])
+ print(f"\nNext steps:")
+ print(f" 1. python repack_experts.py --model {cache_path}")
+ print(f" 2. python metal_infer/extract_weights.py --model {cache_path}")
+ print(f" 3. ./metal_infer/infer --model {cache_path} --prompt 'Hello' --tokens 20")
+ return
+
+ # Default: show both local and remote
+ show_local = not args.search
+ show_remote = not args.local
+
+ if show_local:
+ print("=" * 60)
+ print("LOCAL MODELS")
+ print("=" * 60)
+ models = list_local_models()
+ if not models:
+ print(" No models found in", HF_CACHE)
+ else:
+ compatible_count = 0
+ for m in models:
+ if m["compatible"]:
+ compatible_count += 1
+ print(f"\n {m['repo_id']}")
+ print_model_info(m["info"], indent=" ")
+ print(f" Path: {m['path']}")
+ else:
+ print(f"\n {m['repo_id']} (incompatible: {m.get('reason', 'unknown')})")
+ print(f"\n {compatible_count}/{len(models)} compatible models found")
+
+ if show_remote:
+ print()
+ print("=" * 60)
+ print("AVAILABLE ON HUGGINGFACE")
+ print("=" * 60)
+ remote = search_remote_models()
+ if not remote:
+ print(" No compatible models found (or search failed)")
+ else:
+ # Mark which ones are already local
+ local_repos = set()
+ if HF_CACHE.exists():
+ for entry in HF_CACHE.iterdir():
+ if entry.name.startswith("models--"):
+ parts = entry.name.split("--", 2)
+ if len(parts) >= 3:
+ local_repos.add(f"{parts[1]}/{parts[2]}")
+
+ for m in remote:
+ local_tag = " [LOCAL]" if m["repo_id"] in local_repos else ""
+ quant_tag = f" [{m['quant']}]" if m.get("quant") else ""
+ print(f"\n {m['repo_id']}{local_tag}{quant_tag}")
+ print(f" Downloads: {m['downloads']:,} Likes: {m['likes']}")
+
+ print(f"\n {len(remote)} models found")
+ print(f"\n Download with: python model_manager.py --download ")
+
+
+if __name__ == "__main__":
+ main()