diff --git a/README.md b/README.md index 4b7c69ec..b909f58c 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,8 @@ Q4 requires the larger-memory machine class, so M3 Max Q4 numbers are `N/A`. | MacBook Pro M3 Max, 128 GB | q2 | 11709 tokens | 250.11 t/s | 21.47 t/s | | MacBook Pro M3 Max, 128 GB | q4 | short | N/A | N/A | | MacBook Pro M3 Max, 128 GB | q4 | long | N/A | N/A | +| MacBook Pro M5 Max, 128 GB | q2 | short | 87.25 t/s | 34.27 t/s | +| MacBook Pro M5 Max, 128 GB | q2 | 11707 tokens | 463.44 t/s | 25.90 t/s | | Mac Studio M3 Ultra, 512 GB | q2 | short | 84.43 t/s | 36.86 t/s | | Mac Studio M3 Ultra, 512 GB | q2 | 11709 tokens | 468.03 t/s | 27.39 t/s | | Mac Studio M3 Ultra, 512 GB | q4 | short | 78.95 t/s | 35.50 t/s | @@ -194,6 +196,228 @@ exponential sweeps. Output is CSV with one row per frontier: latest prefill interval tokens/sec, generation tokens/sec at that frontier, and `kvcache_bytes`. +Sessions prefill long prompts in 4096-token chunks by default. Set +`DS4_METAL_PREFILL_CHUNK=N` to compare another chunk size, for example `2048` +to reduce transient memory, or `DS4_METAL_PREFILL_CHUNK=0` to prefill a prompt +as one whole batch when memory allows. Changing the chunk changes the KV +checkpoint shape, so compare it as an explicit run configuration. +Chunked Metal prefill reuses the same range-capable layer-major graph for each +chunk, preserving absolute compressor/indexer boundaries while avoiding the old +per-layer chunk dispatch path. + +## Metal 4 and M5 Neural Accelerators + +The current production path is still hand-written Metal compute kernels over +`MTLBuffer` storage. That is intentional: DS4's hot path is dominated by +quantized routed-MoE matvec/matmul, sparse compressed attention, and mmap-backed +model views, which do not map cleanly to a whole-model Core ML package. + +Metal 4 is the right next target, but it should be introduced as a feature-gated +kernel backend rather than a rewrite. On macOS 26+ with `MTLGPUFamilyMetal4`, +Apple exposes tensor resources and Metal 4 command infrastructure that can run +machine-learning work on the same GPU timeline as compute work. On M5 hardware, +Apple describes the per-GPU-core Neural Accelerators as available to developers +through the Metal 4 Tensor APIs. `DS4_METAL_MEMORY_REPORT=1` now reports the +device, Metal 4 family support, MTL4 queue availability, and whether the device +looks like an M5 Neural Accelerator target. + +The implementation follows the same conservative shape used by llama.cpp's +current Metal backend: the tensor API is disabled by default on pre-M5/pre-A19 +devices, can be forced with `DS4_METAL_TENSOR_ENABLE=1`, and can always be +disabled with `DS4_METAL_TENSOR_DISABLE=1`. At startup ds4 compiles a tiny +Metal Performance Primitives tensor matmul probe before it lets the main Metal +shader source see `DS4_METAL_HAS_TENSOR`, so unsupported SDK/device +combinations fall back to the legacy kernels. + +Metal Tensor policy is explicit and guarded. Use `-mt auto` or `--mt auto` for +the default route policy, `-mt on` to force Tensor routes where the Metal tensor +path is available, and `-mt off` for the legacy Metal reference path. The old +`--mpp` spelling remains accepted as a compatibility alias. Auto currently +keeps attention-output Tensor in the validated late-layer window, keeps Q8_0 +prefill in the lower-drift conservative layer window, and runs routed-MoE Tensor +only in its conservative layer window while preserving +same-top1/same-greedy agreement. 
Unguarded Q8_0, attention-output all-layer, +and all-layer routed-MoE Tensor routes remain +opt-in diagnostics. The environment controls +`DS4_METAL_MPP_ENABLE` and `DS4_METAL_MPP_DISABLE` accept `1/true/yes/on` and +`0/false/no/off`; `DS4_METAL_MPP_ENABLE=0` disables Tensor routes instead of +enabling them by mere presence. Passing `--quality` also disables Tensor routes +so strict/debug runs stay on the legacy Metal kernels. Set +`DS4_METAL_MPP_FAST=1` to opt into the current same-top1/same-greedy fast +profile: it widens Q8_0 and attention-output Tensor to all layers while keeping +the routed-MoE all-layer diagnostic window. This profile is not the default because its +top-k overlap is weaker than auto in the current full-model suite. +The default safe-window policy uses the direct-RHS tensor layout for Tensor +routes; set `DS4_METAL_MPP_DIRECT_RHS=0` to compare against the older staged-RHS +layout. Q8_0 and attention-output direct-RHS routes support both 32-token and +64-token Tensor tiles. Auto defaults attention-output to 64-token tiles, while +Q8_0 uses 64-token tiles below 4096-token batches and 32-token tiles for larger +prompt batches on M5. Set `DS4_METAL_MPP_Q8_0_TILE_N=32` or +`DS4_METAL_MPP_ATTN_OUT_TILE_N=32` to force the narrower layout. The +route-specific `DS4_METAL_MPP_Q8_0_DIRECT_RHS=1`, +`DS4_METAL_MPP_F16_DIRECT_RHS=1`, and +`DS4_METAL_MPP_ATTN_OUT_DIRECT_RHS=1` switches isolate that layout without +turning on every direct-RHS route at once when the global +`DS4_METAL_MPP_DIRECT_RHS=0` override is set. + +The Q8_0 prefill Tensor route can be isolated with +`DS4_METAL_MPP_Q8_0_ENABLE=1` or `DS4_METAL_MPP_Q8_0_DISABLE=1`. It only +affects prompt batches larger than eight tokens. **On M5 the Q8_0 Tensor +route is default-off**: bisection on M5 Max showed it was the sole source +of the M5-only `-mt auto` vs `-mt off` logit drift while the other Tensor +routes (F16 compressor, attention-output, MoE) stayed bit-clean on short +prompts. Set `DS4_METAL_MPP_Q8_0_ENABLE=1` to opt back in. On non-M5 +devices Q8_0 stays default-on and uses the late full-model-safe layer +window 38..42 plus `attn_q_b` in layers 32..37 for all prompt batch +sizes. It +uses 64-token tiles below 4096-token batches and 32-token tiles for larger +prompt batches on M5, accepts partial token tails, and falls back to the legacy +kernel when the Metal 4 tensor path is unavailable. When macOS reports Low +Power Mode, auto widens Q8_0 prefill to all Q8_0 contexts because that profile +improves both prefill and generation speed in current M5 Max low-power sweeps. +Set `DS4_METAL_MPP_LOW_POWER_DISABLE=1` to keep the normal guarded Q8_0 +profile, or `DS4_METAL_MPP_LOW_POWER_ENABLE=1` to force the low-power profile +for comparison. +Set `DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=0` to force the old partial-tail +fallback while debugging. Set `DS4_METAL_MPP_Q8_0_FILTER=all` to reproduce the +wider all-context Q8 route, `DS4_METAL_MPP_Q8_0_FILTER=attn_q_b` to reproduce +the broader small-prompt speed profile, or +`DS4_METAL_MPP_Q8_0_FILTER=` to force named +full-graph Q8 modules such as `attn_q_a`, `attn_kv`, `attn_q_b`, `attn_out`, +`shared_gate`, `shared_up`, or `shared_down`. Use +`@layer=A..B` to test one module family only in a layer window, for +example `shared_up@layer=30..37`. Set `DS4_METAL_MPP_Q8_0_TILE_N=32` to +compare against the narrower Tensor token tile. 
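+For a quick A/B between the legacy kernels and a specific Tensor route, the
+policy flag and the environment knobs above compose on one command line. A
+minimal sketch; the prompt and sampling flags follow the CLI examples later in
+this README, and the filter window shown is only an illustration, not a
+recommended setting:
+
+```sh
+# Baseline: legacy Metal kernels only.
+./ds4 --metal -mt off --temp 0 -p "..."
+
+# Default guarded auto policy.
+./ds4 --metal -mt auto --temp 0 -p "..."
+
+# Opt the Q8_0 prefill Tensor route back in on M5, limited to shared_up in
+# layers 30..37, with the narrower 32-token tile.
+DS4_METAL_MPP_Q8_0_ENABLE=1 \
+DS4_METAL_MPP_Q8_0_FILTER=shared_up@layer=30..37 \
+DS4_METAL_MPP_Q8_0_TILE_N=32 \
+./ds4 --metal -mt auto --temp 0 -p "..."
+```
+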
The isolated +`./ds4_test --metal-kernels` regression reports small/medium/model-ish kernel +deltas; the full-model +`./ds4_test --metal-mpp-equivalence` diagnostic compares default auto against +`-mt off`. Set `DS4_TEST_MPP_EQ_FORCE_ON=1` to compare forced Tensor against +`-mt off` while working on a route. `DS4_TEST_MPP_EQ_CASE=` +limits the diagnostic to one prompt, and `DS4_TEST_MPP_EQ_MATRIX=1` prints +separate auto, fast-profile, Q8-only, attention-output-only, MoE gate/up/down-only, +and full-forced summary rows. The equivalence gate requires finite logits, the +same top-1 token, and matching greedy continuation; it also reports top-5/top-20 +overlap, top-20 rank displacement, top-20 logit deltas, and whole-vocab RMS/max +drift so route changes can be judged beyond pass/fail. + +Full-graph route localization is available with +`DS4_METAL_MPP_COMPARE_ROUTE=q8|attn_out|moe_gate|moe_up|moe_down` and optional +`DS4_METAL_MPP_COMPARE_MAX=N`. The comparator snapshots the candidate Tensor +output, runs the legacy Metal route on the same tensor input, and reports the +first comparison that exceeds the kernel target, including module/layer context, +shape, max absolute error, RMS, and the largest element deltas. Set +`DS4_METAL_MPP_COMPARE_VERBOSE=1` to print passing comparisons as well. + +Current Tensor route status balances drift with prefill throughput: `auto` enables +F16 compressor, attention-output low projection, and routed-MoE Tensor. The +Q8_0 prefill Tensor route is enabled by default on pre-M5 devices and +**default-off on M5**, where bisection traced the entire `-mt auto` vs +`-mt off` drift to that single route; opt back in with +`DS4_METAL_MPP_Q8_0_ENABLE=1`. Attention-output low projection uses layers +32..42 by default, Q8_0 (when enabled) uses the narrower `attn_q_b` 32..37 +plus all-Q8 38..42 window by default, and routed-MoE Tensor uses the +lower-drift conservative default window: gate/up from layer 20 and down +from layer 22. This gives up some of the all-layer prefill speedup to +avoid the larger drift seen with the previous broader Q8_0 and layer-0 +routed-MoE Tensor windows. The current auto suite on M5 reports +same-top1/same-greedy agreement on all five fixtures with minimum top-5 +overlap `5/5`, minimum top-20 overlap `20/20`, `worst_rms ~= 0.169`, and +`worst_top20_max_abs ~= 0.306` (three short fixtures are bit-exact; +residual drift is concentrated on the two long-context fixtures and +comes from the still-enabled F16/attn-out/MoE Tensor routes compounding +through 43 layers). The Q8_0 and attention-output low Tensor +kernels stage activation tiles through half to match the legacy Metal matmul +input path, which brings the isolated model-ish Q8_0 regression under the +strict kernel target and removes the first attention-output comparator breach. +Most Q8_0 projection families stay restricted to layers 38..42 because earlier +layers can amplify small local differences through normalization/attention. The +broader `attn_q_b` profile remains available through the filter knob when +prefill speed is more important than logit drift. The current auto policy also +uses Q8_0 partial tails, direct-RHS Tensor inputs, dynamic Q8_0 tile width, and +64-token tiles for attention-output low projections. In a quick local M5 Max +512-token sanity row, this lower-drift auto profile sampled `339.36` prompt +tokens/sec and `32.97` generation tokens/sec, versus `264.09` and `32.62` for +`--quality`; full sweeps still show visible desktop-load variance. 
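+When a route change needs localizing, one workable loop is: run the isolated
+kernel regression, run the full-model equivalence matrix, then point the
+comparator at the suspect route during a forced-Tensor run. A sketch using the
+knobs documented above; the prompt file and the report limit are placeholders:
+
+```sh
+# Isolated kernel deltas, then the per-route full-model equivalence matrix.
+./ds4_test --metal-kernels
+DS4_TEST_MPP_EQ_MATRIX=1 ./ds4_test --metal-mpp-equivalence
+
+# Localize the first attention-output comparisons that exceed the kernel
+# target during a forced-Tensor run; stop after five reports.
+DS4_METAL_MPP_COMPARE_ROUTE=attn_out \
+DS4_METAL_MPP_COMPARE_MAX=5 \
+./ds4 --metal -mt on --temp 0 --nothink --prompt-file prompt.txt
+```
+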
The F16 +compressor route did not introduce measurable drift in the current prompt set. + +The `DS4_METAL_MPP_FAST=1` profile is the measured high-throughput diagnostic +profile under the relaxed same-top1/same-greedy gate. In the current prompt +suite it keeps top-1 and greedy continuations stable, but reports weaker top-k +overlap than auto (`worst_rms ~= 0.951`, `worst_top20_max_abs ~= 4.03`, +minimum top-20 overlap `16/20`). It remains diagnostic-only because it widens +the Q8_0, attention-output, and routed-MoE route windows that produce the +largest full-suite drift. + +The routed-MoE Tensor projections are enabled by default from layer 20 for +gate/up and layer 22 for down. For route isolation, use +`DS4_METAL_MPP_MOE_GATE_ENABLE/DISABLE`, +`DS4_METAL_MPP_MOE_UP_ENABLE/DISABLE`, and +`DS4_METAL_MPP_MOE_DOWN_ENABLE/DISABLE`; `DS4_METAL_MPP_MOE_DISABLE=1` +disables all routed-MoE Tensor projections. Set the common +`DS4_METAL_MPP_MOE_FILTER` or route-specific +`DS4_METAL_MPP_MOE_GATE_FILTER`, `DS4_METAL_MPP_MOE_UP_FILTER`, and +`DS4_METAL_MPP_MOE_DOWN_FILTER` to `all`, `late_safe`, `none`, or +comma-separated full-graph context substrings to localize safe layer windows. +Use `layer=N` for an exact layer match or `layer=A..B` for an inclusive layer +range when testing sparse Tensor windows. The same `@layer=A..B` +syntax can restrict a context substring to a layer window. +Set `DS4_METAL_MPP_MOE_TILE_N=64` to test the experimental wider routed-MoE +Tensor token tile for performance against the default `32`. The routed-MoE Tensor +path uses the faster first-PR threadgroup tensor layout by default inside the +active routed-MoE windows; set `DS4_METAL_MPP_MOE_FAST_LAYOUT=0` to compare +against the newer staged layout. Set +`DS4_METAL_MPP_MOE_START_LAYER=N`, or the route-specific +`DS4_METAL_MPP_MOE_GATE_START_LAYER`, +`DS4_METAL_MPP_MOE_UP_START_LAYER`, and +`DS4_METAL_MPP_MOE_DOWN_START_LAYER`, to test routed-MoE Tensor start layers; the +resolved start layer also defines the route's default `late_safe` filter. Set +`DS4_METAL_MPP_MOE_PAIR_GATE_UP=1` only to profile the experimental fused +gate/up Tensor dispatch; it passes the current equivalence gate but is not a +default path because it is slower than separate gate and up dispatches. + +For the common six-routed-expert prefill shape, the down-projection expert +outputs are summed with a single Metal kernel instead of five chained add +passes. Set `DS4_METAL_MOE_SUM6_DISABLE=1` to compare or temporarily disable +that fused sum route. + +Long-context decode uses the indexed mixed-attention kernel once ratio-4 +compressed rows exceed the dense-attention window. The default decode +specialization stages sixteen selected rows per threadgroup block; set +`DS4_METAL_INDEXED_ATTN_RB4=1` to compare the older four-row staging variant. +Set `DS4_METAL_DECODE_INDEXER_TOP_K=64`, `128`, `256`, or `512` to cap the +decode indexer candidate count for speed/quality diagnostics. The normal +non-quality decode path keeps the legacy dense-attention window until there are +more than `1024` compressed rows, then selects `256` rows in sparse indexed +attention. Set `DS4_METAL_DECODE_INDEXER_SPARSE_THRESHOLD` to `64`, `128`, +`256`, `512`, `1024`, `2048`, or `4096` to tune the sparse-decode crossover +separately. `--quality` keeps the full `512` candidate path unless this +environment override is set explicitly. 
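+Routed-MoE drift can be bisected the same way by shrinking the active route
+windows one filter at a time, and the decode-side knobs compose with a normal
+generation run. A sketch, assuming the equivalence diagnostic honors the same
+Metal route environment as the main binary; the layer window and threshold
+values are illustrative, not tuned defaults:
+
+```sh
+# Keep only the routed-MoE down projection active, and only in layers 30..42.
+DS4_METAL_MPP_MOE_GATE_FILTER=none \
+DS4_METAL_MPP_MOE_UP_FILTER=none \
+DS4_METAL_MPP_MOE_DOWN_FILTER=layer=30..42 \
+./ds4_test --metal-mpp-equivalence
+
+# Long-context decode diagnostics: cap the indexer candidates and move the
+# sparse-decode crossover.
+DS4_METAL_DECODE_INDEXER_TOP_K=128 \
+DS4_METAL_DECODE_INDEXER_SPARSE_THRESHOLD=2048 \
+./ds4 --metal -p "..."
+```
+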
+ +The attention-output low-projection Tensor route applies to full 32-token multiples +in the default safe window, using a 64-token Tensor tile by default and falling +back to the existing indexed simdgroup kernel for shorter or non-32-multiple +tails. Attention-output Tensor is limited to the measured full-model-safe layer +window 32..42 by default. Set +`DS4_METAL_MPP_ATTN_OUT_ENABLE=1` or `DS4_METAL_MPP_ATTN_OUT_DISABLE=1` to +isolate this route. Set `DS4_METAL_MPP_ATTN_OUT_FILTER=all`, `late_safe`, +`none`, or a comma-separated list of full-graph context substrings such as +`layer=42` to localize full-model-safe layer windows. Layer filters are exact, +and `layer=A..B` matches an inclusive range. Set +`DS4_METAL_MPP_ATTN_OUT_TILE_N=32` to compare against the narrower Tensor token +tile. The all-layer +attention-output Tensor route still fails long-prompt full-model equivalence +despite per-layer low-projection differences below the current kernel target. +The ratio-2 F16 compressor route can similarly be controlled with +`DS4_METAL_MPP_F16_ENABLE=1` or `DS4_METAL_MPP_F16_DISABLE=1`. +`DS4_METAL_MPP_F16_PAIR=1` tests a paired KV/gate compressor dispatch that keeps +the standard simdgroup F16 matmul accumulation shape. It passes the current +full-model equivalence gate, but the measured long-code prefill change was +within noise (`~0.4%`), so it remains opt-in. `DS4_METAL_MPP_F16_WIDE=1` tests +wider 512/1024-column compressor Tensor, including the paired Tensor route when both +variables are set. The wide route is diagnostic only: the current long-code +prompt fails full-model equivalence with wide F16 Tensor (`rms ~= 0.569`, +`top20_max_abs ~= 1.48`), so it is not enabled by `auto`. + ## CLI One-shot prompt: @@ -705,6 +929,7 @@ All project tests are driven by the C runner: ```sh make test # ./ds4_test --all ./ds4_test --logprob-vectors +./ds4_test --metal-mpp-equivalence ./ds4_test --server ``` @@ -716,6 +941,8 @@ first answer: ```sh ./ds4 --dump-tokens -p "..." ./ds4 --dump-logprobs /tmp/out.json --logprobs-top-k 20 --temp 0 -p "..." +./ds4 --dump-logits /tmp/q2-off.json --metal -mt off --nothink --prompt-file prompt.txt +python3 speed-bench/compare_logit_drift.py /tmp/q2-off.json /tmp/q2-mt.json /tmp/q4-off.json --labels q2_mt q4_off ./ds4-server --trace /tmp/ds4-trace.txt ... 
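+# Sketch: the same logit-drift flow with one Tensor route isolated. The env
+# switches are the route knobs from the Metal 4 section above, and the
+# comparison call follows the argument pattern of the example just before it.
+DS4_METAL_MPP_Q8_0_DISABLE=1 DS4_METAL_MPP_F16_DISABLE=1 DS4_METAL_MPP_MOE_DISABLE=1 \
+DS4_METAL_MPP_ATTN_OUT_ENABLE=1 \
+  ./ds4 --dump-logits /tmp/q2-attn-out.json --metal -mt auto --nothink --prompt-file prompt.txt
+python3 speed-bench/compare_logit_drift.py /tmp/q2-off.json /tmp/q2-attn-out.json --labels attn_out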
``` diff --git a/ds4.c b/ds4.c index 51410e33..0182acd2 100644 --- a/ds4.c +++ b/ds4.c @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -6110,8 +6111,8 @@ static uint32_t ds4_default_prefill_cap_for_prompt(int prompt_len) { if (v <= 0) return cap; cap = (uint32_t)v; } - } else if (prompt_len > 2048) { - cap = 2048u; + } else if (prompt_len > 4096) { + cap = 4096u; } if (cap == 0) cap = 1; @@ -8910,9 +8911,81 @@ static bool metal_graph_capture_prefix1_index_state(ds4_gpu_graph *g, uint32_t i g->layer_index_state_score[il], 0, bytes) != 0; } +static bool metal_graph_decode_indexer_top_k_override(uint32_t *value) { + static int parsed = -1; + static uint32_t cached = 0; + if (parsed >= 0) { + if (parsed > 0 && value) *value = cached; + return parsed > 0; + } + + parsed = 0; + const char *env = getenv("DS4_METAL_DECODE_INDEXER_TOP_K"); + if (env && env[0]) { + char *end = NULL; + unsigned long v = strtoul(env, &end, 10); + while (end && isspace((unsigned char)*end)) end++; + if (end != env && end && *end == '\0' && + (v == 64ul || v == 128ul || v == 256ul || v == 512ul) && + v <= DS4_N_INDEXER_TOP_K) { + cached = (uint32_t)v; + parsed = 1; + } else { + fprintf(stderr, + "ds4: invalid DS4_METAL_DECODE_INDEXER_TOP_K=%s; " + "expected 64, 128, 256, or 512\n", + env); + } + } + if (parsed > 0 && value) *value = cached; + return parsed > 0; +} + static uint32_t metal_graph_decode_indexer_top_k(const ds4_gpu_graph *g) { + uint32_t value = 0; + if (metal_graph_decode_indexer_top_k_override(&value)) return value; + + const uint32_t speed_default = + DS4_N_INDEXER_TOP_K < 256u ? DS4_N_INDEXER_TOP_K : 256u; + return (g && g->quality) ? DS4_N_INDEXER_TOP_K : speed_default; +} + +static uint32_t metal_graph_decode_indexer_sparse_threshold(const ds4_gpu_graph *g) { (void)g; - return DS4_N_INDEXER_TOP_K; + static int parsed = -1; + static uint32_t cached = 0; + if (parsed < 0) { + parsed = 0; + const char *env = getenv("DS4_METAL_DECODE_INDEXER_SPARSE_THRESHOLD"); + if (env && env[0]) { + char *end = NULL; + unsigned long v = strtoul(env, &end, 10); + while (end && isspace((unsigned char)*end)) end++; + if (end != env && end && *end == '\0' && + (v == 64ul || v == 128ul || v == 256ul || v == 512ul || + v == 1024ul || v == 2048ul || v == 4096ul)) { + cached = (uint32_t)v; + parsed = 1; + } else { + fprintf(stderr, + "ds4: invalid DS4_METAL_DECODE_INDEXER_SPARSE_THRESHOLD=%s; " + "expected 64, 128, 256, 512, 1024, 2048, or 4096\n", + env); + } + } + } + if (parsed > 0) return cached; + + uint32_t value = 0; + if (metal_graph_decode_indexer_top_k_override(&value)) return value; + + /* Keep dense attention longer than the legacy 512-row window by default. + * Around the 2K frontier the sparse path's score/top-k setup dominates + * the smaller attention scan, while larger contexts benefit from sparse + * indexed attention. The speed default + * selects fewer rows only after decode has enough compressed rows for the + * sparse indexed path to pay for its score/top-k overhead. 
*/ + return 1024u; } /* ========================================================================= @@ -9387,7 +9460,9 @@ static bool metal_graph_encode_decode_layer( DS4_RMS_EPS) != 0; if (ok && emit) g->layer_n_index_comp[il]++; const uint32_t decode_top_k = metal_graph_decode_indexer_top_k(g); - if (ok && g->layer_n_comp[il] > decode_top_k) { + const uint32_t decode_sparse_threshold = + metal_graph_decode_indexer_sparse_threshold(g); + if (ok && g->layer_n_comp[il] > decode_sparse_threshold) { const uint64_t indexer_q_dim = (uint64_t)DS4_N_INDEXER_HEAD * DS4_N_INDEXER_HEAD_DIM; if (!layer->indexer_attn_q_b || layer->indexer_attn_q_b->type != DS4_TENSOR_F16 || @@ -9972,6 +10047,30 @@ static bool metal_graph_matmul_plain_tensor( return false; } +static bool metal_graph_matmul_q8_0_named_tensor( + const char *module, + uint32_t il, + uint32_t pos0, + ds4_gpu_tensor *out, + const ds4_model *model, + const ds4_tensor *w, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok) { + ds4_gpu_set_mpp_compare_context(module, il, pos0); + const bool ok = ds4_gpu_matmul_q8_0_tensor(out, + model->map, + model->size, + w->abs_offset, + in_dim, + out_dim, + x, + n_tok) != 0; + ds4_gpu_clear_mpp_compare_context(); + return ok; +} + static bool metal_graph_encode_output_head_mtp( ds4_gpu_graph *g, const ds4_model *base_model, @@ -10970,6 +11069,66 @@ static bool metal_graph_q_stage_profile_boundary( return ds4_gpu_begin_commands() != 0; } +static bool ds4_env_bool_enabled(const char *name) { + const char *v = getenv(name); + if (!v) return false; + + while (isspace((unsigned char)*v)) v++; + size_t n = strlen(v); + while (n > 0 && isspace((unsigned char)v[n - 1])) n--; + if (n == 0) return true; + + if ((n == 1 && v[0] == '0') || + (n == 2 && strncasecmp(v, "no", n) == 0) || + (n == 3 && strncasecmp(v, "off", n) == 0) || + (n == 5 && strncasecmp(v, "false", n) == 0)) { + return false; + } + return true; +} + +static bool metal_graph_matmul_f16_pair_or_separate( + ds4_gpu_tensor *out_a, + ds4_gpu_tensor *out_b, + const ds4_model *model, + uint64_t weight_a_offset, + uint64_t weight_b_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tokens) { + if (ds4_env_bool_enabled("DS4_METAL_MPP_F16_PAIR")) { + if (ds4_gpu_matmul_f16_pair_tensor(out_a, + out_b, + model->map, + model->size, + weight_a_offset, + weight_b_offset, + in_dim, + out_dim, + x, + n_tokens) != 0) { + return true; + } + } + return ds4_gpu_matmul_f16_tensor(out_a, + model->map, + model->size, + weight_a_offset, + in_dim, + out_dim, + x, + n_tokens) != 0 && + ds4_gpu_matmul_f16_tensor(out_b, + model->map, + model->size, + weight_b_offset, + in_dim, + out_dim, + x, + n_tokens) != 0; +} + static bool metal_graph_encode_layer_attention_batch( ds4_gpu_graph *g, const ds4_model *model, @@ -11085,28 +11244,32 @@ static bool metal_graph_encode_layer_attention_batch( } DS4_METAL_PROFILE_ATTN_STAGE("norm"); DS4_METAL_PROFILE_Q_STAGE("pre_q"); - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_qr, - model->map, - model->size, - layer->attn_q_a->abs_offset, - DS4_N_EMBD, - q_rank, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("attn_q_a", + il, + pos0, + g->batch_qr, + model, + layer->attn_q_a, + DS4_N_EMBD, + q_rank, + g->batch_attn_norm, + n_tokens); if (ok) { metal_graph_debug_dump_tensor("q_lora", g->batch_qr, (uint64_t)n_tokens * q_rank, il, pos0); } DS4_METAL_PROFILE_Q_STAGE("q_a"); if (qkv_rms_fused) { - if (ok) ok = 
ds4_gpu_matmul_q8_0_tensor(g->batch_kv_raw, - model->map, - model->size, - layer->attn_kv->abs_offset, - DS4_N_EMBD, - DS4_N_HEAD_DIM, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("attn_kv", + il, + pos0, + g->batch_kv_raw, + model, + layer->attn_kv, + DS4_N_EMBD, + DS4_N_HEAD_DIM, + g->batch_attn_norm, + n_tokens); if (ok) { metal_graph_debug_dump_tensor("KVraw", g->batch_kv_raw, (uint64_t)n_tokens * DS4_N_HEAD_DIM, il, pos0); @@ -11142,14 +11305,16 @@ static bool metal_graph_encode_layer_attention_batch( (uint64_t)n_tokens * DS4_N_HEAD_DIM, il, pos0); } DS4_METAL_PROFILE_Q_STAGE("q_a_norm"); - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_q, - model->map, - model->size, - layer->attn_q_b->abs_offset, - q_rank, - q_dim, - g->batch_qr_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("attn_q_b", + il, + pos0, + g->batch_q, + model, + layer->attn_q_b, + q_rank, + q_dim, + g->batch_qr_norm, + n_tokens); if (ok) { metal_graph_debug_dump_tensor("Qraw", g->batch_q, (uint64_t)n_tokens * q_dim, il, pos0); @@ -11186,14 +11351,16 @@ static bool metal_graph_encode_layer_attention_batch( DS4_METAL_PROFILE_Q_STAGE("rope"); DS4_METAL_PROFILE_ATTN_STAGE("q_path"); if (!qkv_rms_fused) { - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_kv_raw, - model->map, - model->size, - layer->attn_kv->abs_offset, - DS4_N_EMBD, - DS4_N_HEAD_DIM, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("attn_kv", + il, + pos0, + g->batch_kv_raw, + model, + layer->attn_kv, + DS4_N_EMBD, + DS4_N_HEAD_DIM, + g->batch_attn_norm, + n_tokens); if (ok) { metal_graph_debug_dump_tensor("KVraw", g->batch_kv_raw, (uint64_t)n_tokens * DS4_N_HEAD_DIM, il, pos0); @@ -11320,27 +11487,39 @@ static bool metal_graph_encode_layer_attention_batch( fprintf(stderr, "ds4: Metal layer-major prefill needs attention compressor weights\n"); ok = false; } - if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_kv, - model->map, - model->size, - layer->attn_compressor_kv->abs_offset, - DS4_N_EMBD, - comp_width, - g->batch_attn_norm, - n_tokens) != 0; + if (ok && ds4_env_bool_enabled("DS4_METAL_MPP_F16_PAIR")) { + ok = metal_graph_matmul_f16_pair_or_separate(g->batch_comp_kv, + g->batch_comp_sc, + model, + layer->attn_compressor_kv->abs_offset, + layer->attn_compressor_gate->abs_offset, + DS4_N_EMBD, + comp_width, + g->batch_attn_norm, + n_tokens); + } else if (ok) { + ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_kv, + model->map, + model->size, + layer->attn_compressor_kv->abs_offset, + DS4_N_EMBD, + comp_width, + g->batch_attn_norm, + n_tokens) != 0; + if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_sc, + model->map, + model->size, + layer->attn_compressor_gate->abs_offset, + DS4_N_EMBD, + comp_width, + g->batch_attn_norm, + n_tokens) != 0; + } if (ok) metal_graph_debug_dump_tensor("attn_comp_kv_raw", g->batch_comp_kv, (uint64_t)comp_width * n_tokens, il, pos0); - if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_sc, - model->map, - model->size, - layer->attn_compressor_gate->abs_offset, - DS4_N_EMBD, - comp_width, - g->batch_attn_norm, - n_tokens) != 0; if (ok) metal_graph_debug_dump_tensor("attn_comp_score_raw", g->batch_comp_sc, (uint64_t)comp_width * n_tokens, @@ -11598,27 +11777,39 @@ static bool metal_graph_encode_layer_attention_batch( fprintf(stderr, "ds4: Metal layer-major prefill needs indexer weights\n"); ok = false; } - if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_kv, - model->map, - model->size, - 
layer->indexer_compressor_kv->abs_offset, - DS4_N_EMBD, - index_width, - g->batch_attn_norm, - n_tokens) != 0; + if (ok && ds4_env_bool_enabled("DS4_METAL_MPP_F16_PAIR")) { + ok = metal_graph_matmul_f16_pair_or_separate(g->batch_comp_kv, + g->batch_comp_sc, + model, + layer->indexer_compressor_kv->abs_offset, + layer->indexer_compressor_gate->abs_offset, + DS4_N_EMBD, + index_width, + g->batch_attn_norm, + n_tokens); + } else if (ok) { + ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_kv, + model->map, + model->size, + layer->indexer_compressor_kv->abs_offset, + DS4_N_EMBD, + index_width, + g->batch_attn_norm, + n_tokens) != 0; + if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_sc, + model->map, + model->size, + layer->indexer_compressor_gate->abs_offset, + DS4_N_EMBD, + index_width, + g->batch_attn_norm, + n_tokens) != 0; + } if (ok) metal_graph_debug_dump_tensor("indexer_comp_kv_raw", g->batch_comp_kv, (uint64_t)index_width * n_tokens, il, pos0); - if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_sc, - model->map, - model->size, - layer->indexer_compressor_gate->abs_offset, - DS4_N_EMBD, - index_width, - g->batch_attn_norm, - n_tokens) != 0; if (ok) metal_graph_debug_dump_tensor("indexer_comp_score_raw", g->batch_comp_sc, (uint64_t)index_width * n_tokens, @@ -12237,20 +12428,24 @@ static bool metal_graph_encode_layer_attention_batch( (uint64_t)n_tokens * q_dim, il, pos0); } DS4_METAL_PROFILE_ATTN_STAGE("inv_rope"); - if (ok) ok = ds4_gpu_attention_output_q8_batch_tensor(g->batch_attn_out, - g->batch_attn_low, - g->batch_group_tmp, - g->batch_low_tmp, - model->map, - model->size, - layer->attn_output_a->abs_offset, - layer->attn_output_b->abs_offset, - group_dim, - rank, - n_groups, - DS4_N_EMBD, - g->batch_heads, - n_tokens) != 0; + if (ok) { + ds4_gpu_set_mpp_compare_context("attn_out", il, pos0); + ok = ds4_gpu_attention_output_q8_batch_tensor(g->batch_attn_out, + g->batch_attn_low, + g->batch_group_tmp, + g->batch_low_tmp, + model->map, + model->size, + layer->attn_output_a->abs_offset, + layer->attn_output_b->abs_offset, + group_dim, + rank, + n_groups, + DS4_N_EMBD, + g->batch_heads, + n_tokens) != 0; + ds4_gpu_clear_mpp_compare_context(); + } if (ok) { metal_graph_debug_dump_tensor("attn_low", g->batch_attn_low, (uint64_t)n_tokens * n_groups * rank, @@ -12422,32 +12617,37 @@ static bool metal_graph_encode_layer_ffn_batch( } DS4_METAL_PROFILE_FFN_STAGE("router"); - if (ok) ok = ds4_gpu_routed_moe_batch_tensor(g->batch_routed_out, - g->batch_routed_gate, - g->batch_routed_up, - g->batch_routed_mid, - g->batch_routed_down, - model->map, - model->size, - layer->ffn_gate_exps->abs_offset, - layer->ffn_up_exps->abs_offset, - layer->ffn_down_exps->abs_offset, - layer->ffn_gate_exps->type, - layer->ffn_down_exps->type, - gate_expert_bytes, - gate_row_bytes, - down_expert_bytes, - down_row_bytes, - (uint32_t)expert_in_dim, - (uint32_t)down_in_dim, - (uint32_t)routed_out_dim, - g->batch_router_selected, - g->batch_router_weights, - DS4_N_EXPERT_USED, - DS4_SWIGLU_CLAMP_EXP, - g->batch_ffn_norm, - n_tokens, - &g->batch_routed_mid_is_f16) != 0; + if (ok) { + ds4_gpu_set_mpp_compare_context("routed_moe", il, pos0); + ok = ds4_gpu_routed_moe_batch_tensor(g->batch_routed_out, + g->batch_routed_gate, + g->batch_routed_up, + g->batch_routed_mid, + g->batch_routed_down, + model->map, + model->size, + layer->ffn_gate_exps->abs_offset, + layer->ffn_up_exps->abs_offset, + layer->ffn_down_exps->abs_offset, + layer->ffn_gate_exps->type, + layer->ffn_down_exps->type, + gate_expert_bytes, + 
gate_row_bytes, + down_expert_bytes, + down_row_bytes, + (uint32_t)expert_in_dim, + (uint32_t)down_in_dim, + (uint32_t)routed_out_dim, + g->batch_router_selected, + g->batch_router_weights, + DS4_N_EXPERT_USED, + DS4_SWIGLU_CLAMP_EXP, + g->batch_ffn_norm, + il, + n_tokens, + &g->batch_routed_mid_is_f16) != 0; + ds4_gpu_clear_mpp_compare_context(); + } if (ok) { metal_graph_debug_dump_tensor("ffn_moe_gate_clamped", g->batch_routed_gate, (uint64_t)n_tokens * DS4_N_EXPERT_USED * down_in_dim, il, pos0); @@ -12467,22 +12667,26 @@ static bool metal_graph_encode_layer_ffn_batch( (uint64_t)n_tokens * DS4_N_EMBD, il, pos0); } DS4_METAL_PROFILE_FFN_STAGE("routed_moe"); - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_shared_gate, - model->map, - model->size, - layer->ffn_gate_shexp->abs_offset, - DS4_N_EMBD, - shared_dim, - g->batch_ffn_norm, - n_tokens) != 0; - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_shared_up, - model->map, - model->size, - layer->ffn_up_shexp->abs_offset, - DS4_N_EMBD, - shared_dim, - g->batch_ffn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("shared_gate", + il, + pos0, + g->batch_shared_gate, + model, + layer->ffn_gate_shexp, + DS4_N_EMBD, + shared_dim, + g->batch_ffn_norm, + n_tokens); + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("shared_up", + il, + pos0, + g->batch_shared_up, + model, + layer->ffn_up_shexp, + DS4_N_EMBD, + shared_dim, + g->batch_ffn_norm, + n_tokens); DS4_METAL_PROFILE_FFN_STAGE("shared_gate_up"); if (ok) ok = ds4_gpu_swiglu_tensor(g->batch_shared_mid, g->batch_shared_gate, @@ -12490,14 +12694,16 @@ static bool metal_graph_encode_layer_ffn_batch( (uint32_t)((uint64_t)n_tokens * shared_dim), 0.0f, 1.0f) != 0; - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_shared_out, - model->map, - model->size, - layer->ffn_down_shexp->abs_offset, - shared_dim, - DS4_N_EMBD, - g->batch_shared_mid, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("shared_down", + il, + pos0, + g->batch_shared_out, + model, + layer->ffn_down_shexp, + shared_dim, + DS4_N_EMBD, + g->batch_shared_mid, + n_tokens); DS4_METAL_PROFILE_FFN_STAGE("shared_down"); if (ok) { metal_graph_debug_dump_tensor("ffn_shexp", g->batch_shared_out, @@ -13020,16 +13226,19 @@ static bool metal_graph_prefill_layer_major( const ds4_model *model, const ds4_weights *weights, const token_vec *prompt, - int n_tokens, + uint32_t start, + uint32_t n_tokens, float *logits, bool show_progress, ds4_imatrix_collector *imatrix) { - if (n_tokens <= 0 || n_tokens > prompt->len || (uint32_t)n_tokens > g->prefill_cap) return false; + if (n_tokens == 0 || n_tokens > g->prefill_cap) return false; + if (start > (uint32_t)prompt->len) return false; + if (n_tokens > (uint32_t)prompt->len - start) return false; - bool ok = metal_graph_upload_prompt_tokens(g->prefill_tokens, prompt, 0, (uint32_t)n_tokens); + bool ok = metal_graph_upload_prompt_tokens(g->prefill_tokens, prompt, start, n_tokens); if (!ok) return false; - if (!metal_graph_warmup_prefill_kernels(g, model, weights, (uint32_t)n_tokens)) return false; + if (!metal_graph_warmup_prefill_kernels(g, model, weights, n_tokens)) return false; const bool split_profile = getenv("DS4_METAL_GRAPH_PREFILL_SPLIT_PROFILE") != NULL; /* @@ -13050,16 +13259,16 @@ static bool metal_graph_prefill_layer_major( model, weights, prompt, - 0, - (uint32_t)n_tokens); + start, + n_tokens); if (ok) ok = ds4_gpu_begin_commands() != 0; for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { ok = metal_graph_encode_layer_batch(g, model, 
&weights->layer[il], il, - 0, - (uint32_t)n_tokens); + start, + n_tokens); if (show_progress) { fprintf(stderr, "ds4: gpu prefill layer %u/%u\r", il + 1, (uint32_t)DS4_N_LAYER); fflush(stderr); @@ -13077,13 +13286,13 @@ static bool metal_graph_prefill_layer_major( output_row = (uint32_t)v; } } - ds4_gpu_tensor *last_hc = NULL; ds4_gpu_tensor *saved_cur = g->cur_hc; - if (ok) { + ds4_gpu_tensor *last_hc = NULL; + if (ok && logits) { last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, output_row, hc_dim); ok = last_hc != NULL; } - if (ok) { + if (ok && logits) { g->cur_hc = last_hc; ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); g->cur_hc = saved_cur; @@ -13108,7 +13317,7 @@ static bool metal_graph_prefill_layer_major( if (profile) { const double t_read = now_sec(); fprintf(stderr, - "ds4: gpu graph prefill total tokens=%d encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", + "ds4: gpu graph prefill total tokens=%u encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", n_tokens, (t_encoded - t0) * 1000.0, (t_done - t_encoded) * 1000.0, @@ -13124,8 +13333,8 @@ static bool metal_graph_prefill_layer_major( model, weights, prompt, - 0, - (uint32_t)n_tokens); + start, + n_tokens); const double t_embed_encoded = profile ? now_sec() : 0.0; const double t_embed_done = profile ? now_sec() : 0.0; if (profile) { @@ -13153,8 +13362,8 @@ static bool metal_graph_prefill_layer_major( model, &weights->layer[il], il, - 0, - (uint32_t)n_tokens); + start, + n_tokens); const double t_attn_encoded = now_sec(); if (ok) ok = ds4_gpu_end_commands() != 0; const double t_attn_done = now_sec(); @@ -13165,8 +13374,8 @@ static bool metal_graph_prefill_layer_major( model, &weights->layer[il], il, - 0, - (uint32_t)n_tokens); + start, + n_tokens); if (ok) { ds4_gpu_tensor *tmp = g->batch_cur_hc; g->batch_cur_hc = g->batch_next_hc; @@ -13193,8 +13402,8 @@ static bool metal_graph_prefill_layer_major( model, &weights->layer[il], il, - 0, - (uint32_t)n_tokens); + start, + n_tokens); const double t_encoded = profile ? now_sec() : 0.0; if (ok) ok = ds4_gpu_end_commands() != 0; const double t_done = profile ? now_sec() : 0.0; @@ -13232,21 +13441,26 @@ static bool metal_graph_prefill_layer_major( output_row = (uint32_t)v; } } - ds4_gpu_tensor *last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, - output_row, - hc_dim); - if (!last_hc) return false; ds4_gpu_tensor *saved_cur = g->cur_hc; - g->cur_hc = last_hc; + ds4_gpu_tensor *last_hc = NULL; const double t_head0 = profile ? now_sec() : 0.0; - ok = ds4_gpu_begin_commands() != 0; - if (ok) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); + if (logits) { + last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, + output_row, + hc_dim); + ok = last_hc != NULL; + } + if (ok && logits) { + g->cur_hc = last_hc; + ok = ds4_gpu_begin_commands() != 0; + } + if (ok && logits) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); const double t_head_encoded = profile ? now_sec() : 0.0; - if (ok) ok = ds4_gpu_end_commands() != 0; + if (ok && logits) ok = ds4_gpu_end_commands() != 0; const double t_head_done = profile ? now_sec() : 0.0; g->cur_hc = saved_cur; - ds4_gpu_tensor_free(last_hc); + if (last_hc) ds4_gpu_tensor_free(last_hc); if (!ok) return false; const double t_before_read = profile ? 
now_sec() : 0.0; @@ -13264,7 +13478,7 @@ static bool metal_graph_prefill_layer_major( (t_head_done - t_head_encoded) * 1000.0); } fprintf(stderr, - "ds4: gpu layer-major prefill total tokens=%d encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", + "ds4: gpu layer-major prefill total tokens=%u encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", n_tokens, encode_s * 1000.0, execute_s * 1000.0, @@ -13284,32 +13498,15 @@ static bool metal_graph_prefill_raw_swa( bool show_progress) { if (n_tokens <= 0 || n_tokens > prompt->len) return false; if ((uint32_t)n_tokens > g->prefill_cap) return false; - return metal_graph_prefill_layer_major(g, model, weights, prompt, n_tokens, logits, show_progress, NULL); -} - -static bool metal_graph_prefill_batch_row_logits( - ds4_gpu_graph *g, - const ds4_model *model, - const ds4_weights *weights, - uint32_t batch_row, - float *logits) { - if (!logits) return true; - const uint64_t hc_dim = (uint64_t)DS4_N_HC * DS4_N_EMBD; - ds4_gpu_tensor *last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, - batch_row, - hc_dim); - if (!last_hc) return false; - ds4_gpu_tensor *saved_cur = g->cur_hc; - g->cur_hc = last_hc; - bool ok = ds4_gpu_begin_commands() != 0; - if (ok) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); - if (ok) ok = ds4_gpu_end_commands() != 0; - else (void)ds4_gpu_synchronize(); - g->cur_hc = saved_cur; - ds4_gpu_tensor_free(last_hc); - if (!ok) return false; - return ds4_gpu_tensor_read(g->logits, 0, logits, - (uint64_t)DS4_N_VOCAB * sizeof(float)) != 0; + return metal_graph_prefill_layer_major(g, + model, + weights, + prompt, + 0, + (uint32_t)n_tokens, + logits, + show_progress, + NULL); } /* Prefill a contiguous token range in fixed-size chunks. @@ -13340,21 +13537,8 @@ static bool metal_graph_prefill_chunked_range( if (start != 0 && chunk_cap > g->raw_cap) chunk_cap = g->raw_cap; if (chunk_cap == 0) return false; - uint32_t first_chunk = n_tokens < chunk_cap ? n_tokens : chunk_cap; - if (start != 0 && g->prefill_cap != 0) { - const uint32_t mod = start % g->prefill_cap; - if (mod != 0) { - const uint32_t to_boundary = g->prefill_cap - mod; - if (to_boundary < first_chunk) first_chunk = to_boundary; - } - } - if (!metal_graph_warmup_prefill_kernels(g, model, weights, first_chunk)) return false; - const bool profile = getenv("DS4_METAL_GRAPH_PREFILL_PROFILE") != NULL; const double t0 = profile ? now_sec() : 0.0; - double encode_s = 0.0; - double execute_s = 0.0; - uint32_t last_chunk_tokens = 0; const uint32_t end = start + n_tokens; if (progress) { @@ -13372,109 +13556,39 @@ static bool metal_graph_prefill_chunked_range( } } const uint32_t chunk = remaining < local_cap ? remaining : local_cap; - last_chunk_tokens = chunk; - - bool ok = metal_graph_upload_prompt_tokens(g->prefill_tokens, prompt, pos0, chunk); - if (ok) ok = metal_graph_upload_prompt_embeddings_hc(g->batch_cur_hc, - g->prefill_tokens, - model, - weights, - prompt, - pos0, - chunk); - if (!ok) return false; - - for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { - const double t_layer0 = profile ? now_sec() : 0.0; - ok = ds4_gpu_begin_commands() != 0; - if (ok) ok = metal_graph_encode_layer_batch(g, - model, - &weights->layer[il], - il, - pos0, - chunk); - const double t_encoded = profile ? now_sec() : 0.0; - if (ok) ok = ds4_gpu_end_commands() != 0; - const double t_done = profile ? 
now_sec() : 0.0; - if (ok && imatrix) ok = imatrix_collect_layer_batch(imatrix, g, il, chunk); - if (profile) { - encode_s += t_encoded - t_layer0; - execute_s += t_done - t_encoded; - fprintf(stderr, - "ds4: gpu chunked prefill pos=%u tokens=%u layer %u encode=%.3f ms execute=%.3f ms\n", - pos0, - chunk, - il, - (t_encoded - t_layer0) * 1000.0, - (t_done - t_encoded) * 1000.0); - } - if (show_progress) { - fprintf(stderr, - "ds4: gpu prefill token %u/%u layer %u/%u\r", - pos0 + chunk, - (uint32_t)prompt->len, - il + 1, - (uint32_t)DS4_N_LAYER); - fflush(stderr); - } - } + const uint32_t chunk_end = pos0 + chunk; + float *chunk_logits = (progress || chunk_end == end) ? logits : NULL; + bool ok = metal_graph_prefill_layer_major(g, + model, + weights, + prompt, + pos0, + chunk, + chunk_logits, + show_progress, + imatrix); if (!ok) { if (ds4_gpu_synchronize() == 0) { fprintf(stderr, "ds4: Metal synchronize after chunked prefill failure also failed\n"); } return false; } - if (progress && !metal_graph_prefill_batch_row_logits(g, model, weights, - chunk - 1u, - logits)) - { - return false; - } if (progress) { - progress(progress_ud, "prefill_chunk", (int)(pos0 + chunk), prompt->len); + progress(progress_ud, "prefill_chunk", (int)chunk_end, prompt->len); } - pos0 += chunk; + pos0 = chunk_end; } if (show_progress) fputc('\n', stderr); - if (last_chunk_tokens == 0) return false; - - const uint64_t hc_dim = (uint64_t)DS4_N_HC * DS4_N_EMBD; - ds4_gpu_tensor *last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, - last_chunk_tokens - 1u, - hc_dim); - if (!last_hc) return false; - ds4_gpu_tensor *saved_cur = g->cur_hc; - g->cur_hc = last_hc; - - const double t_head0 = profile ? now_sec() : 0.0; - bool ok = ds4_gpu_begin_commands() != 0; - if (ok) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); - const double t_head_encoded = profile ? now_sec() : 0.0; - if (ok) ok = ds4_gpu_end_commands() != 0; - const double t_head_done = profile ? now_sec() : 0.0; - g->cur_hc = saved_cur; - ds4_gpu_tensor_free(last_hc); - if (!ok) return false; - - const double t_before_read = profile ? now_sec() : 0.0; - if (logits) { - ok = ds4_gpu_tensor_read(g->logits, 0, logits, (uint64_t)DS4_N_VOCAB * sizeof(float)) != 0; - } if (profile) { const double t_read = now_sec(); - encode_s += t_head_encoded - t_head0; - execute_s += t_head_done - t_head_encoded; fprintf(stderr, - "ds4: gpu chunked prefill start=%u tokens=%u chunk=%u encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", + "ds4: gpu chunked prefill start=%u tokens=%u chunk=%u total=%.3f ms\n", start, n_tokens, chunk_cap, - encode_s * 1000.0, - execute_s * 1000.0, - (t_read - t_before_read) * 1000.0, (t_read - t0) * 1000.0); } - return ok; + return true; } /* Long prompts are prefetched in fixed-size chunks. Chunks bound transient @@ -13772,7 +13886,7 @@ static uint32_t metal_graph_raw_cap_for_context(int ctx_size, uint32_t prefill_c } /* Choose the prefill ubatch size. Whole-batch is fastest for normal prompts; - * long prompts default to 2048-token chunks. */ + * long prompts default to 4096-token chunks. 
*/ static uint32_t metal_graph_prefill_cap_for_prompt(int prompt_len) { return ds4_default_prefill_cap_for_prompt(prompt_len); } @@ -14176,6 +14290,7 @@ struct ds4_engine { float *directional_steering_dirs; float directional_steering_attn_scale; float directional_steering_ffn_scale; + ds4_mpp_mode mpp_mode; bool quality; bool metal_ready; bool mtp_ready; @@ -15417,6 +15532,15 @@ const char *ds4_backend_name(ds4_backend backend) { return "unknown"; } +const char *ds4_mpp_mode_name(ds4_mpp_mode mode) { + switch (mode) { + case DS4_MPP_AUTO: return "auto"; + case DS4_MPP_ON: return "on"; + case DS4_MPP_OFF: return "off"; + } + return "unknown"; +} + bool ds4_think_mode_enabled(ds4_think_mode mode) { return mode == DS4_THINK_HIGH || mode == DS4_THINK_MAX; } @@ -16668,7 +16792,8 @@ int ds4_engine_collect_imatrix(ds4_engine *e, &collector); } else { ok = metal_graph_prefill_layer_major(&g, model, weights, - &prompt, prompt.len, + &prompt, 0, + (uint32_t)prompt.len, NULL, false, &collector); } @@ -16953,6 +17078,7 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { e->mtp_model.fd = -1; e->backend = opt->backend; e->quality = opt->quality; + e->mpp_mode = opt->mpp_mode; e->mtp_draft_tokens = opt->mtp_draft_tokens > 0 ? opt->mtp_draft_tokens : 1; if (e->mtp_draft_tokens > 16) e->mtp_draft_tokens = 16; e->mtp_margin = opt->mtp_margin >= 0.0f ? opt->mtp_margin : 3.0f; @@ -17018,6 +17144,7 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { *out = NULL; return 1; } + ds4_gpu_set_mpp_mode(e->mpp_mode); ds4_gpu_set_quality(e->quality); (void)ds4_gpu_set_model_fd(e->model.fd); if (!ds4_gpu_set_model_map_range(e->model.map, @@ -17075,6 +17202,10 @@ void ds4_engine_summary(ds4_engine *e) { model_summary(&e->model); } +int ds4_engine_vocab_size(ds4_engine *e) { + return e ? 
e->vocab.n_vocab : 0; +} + void ds4_engine_close(ds4_engine *e) { if (!e) return; weights_free(&e->weights); @@ -17484,6 +17615,12 @@ int ds4_session_token_logprob(ds4_session *s, int token, ds4_token_score *out) { return 1; } +int ds4_session_copy_logits(ds4_session *s, float *out, int cap) { + if (!s || !out || cap < (int)DS4_N_VOCAB) return 0; + memcpy(out, s->logits, (size_t)DS4_N_VOCAB * sizeof(out[0])); + return (int)DS4_N_VOCAB; +} + static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, char *err, size_t errlen) { if (!s) return 1; diff --git a/ds4.h b/ds4.h index 950d8dca..c60105f7 100644 --- a/ds4.h +++ b/ds4.h @@ -20,6 +20,12 @@ typedef enum { DS4_BACKEND_CPU, } ds4_backend; +typedef enum { + DS4_MPP_AUTO = 0, + DS4_MPP_ON, + DS4_MPP_OFF, +} ds4_mpp_mode; + typedef enum { DS4_THINK_NONE, DS4_THINK_HIGH, @@ -67,6 +73,7 @@ typedef struct { float directional_steering_ffn; bool warm_weights; bool quality; + ds4_mpp_mode mpp_mode; } ds4_engine_options; typedef void (*ds4_token_emit_fn)(void *ud, int token); @@ -91,7 +98,9 @@ typedef struct { int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt); void ds4_engine_close(ds4_engine *e); void ds4_engine_summary(ds4_engine *e); +int ds4_engine_vocab_size(ds4_engine *e); const char *ds4_backend_name(ds4_backend backend); +const char *ds4_mpp_mode_name(ds4_mpp_mode mode); bool ds4_think_mode_enabled(ds4_think_mode mode); const char *ds4_think_mode_name(ds4_think_mode mode); const char *ds4_think_max_prefix(void); @@ -168,6 +177,7 @@ int ds4_session_argmax_excluding(ds4_session *s, int excluded_id); int ds4_session_sample(ds4_session *s, float temperature, int top_k, float top_p, float min_p, uint64_t *rng); int ds4_session_top_logprobs(ds4_session *s, ds4_token_score *out, int k); int ds4_session_token_logprob(ds4_session *s, int token, ds4_token_score *out); +int ds4_session_copy_logits(ds4_session *s, float *out, int cap); int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen); int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token, int max_tokens, int eos_token, diff --git a/ds4_cli.c b/ds4_cli.c index bc70e659..887e4b1e 100644 --- a/ds4_cli.c +++ b/ds4_cli.c @@ -32,6 +32,7 @@ typedef struct { float top_p; uint64_t seed; bool dump_tokens; + const char *dump_logits_path; const char *dump_logprobs_path; int dump_logprobs_top_k; const char *imatrix_dataset_path; @@ -102,7 +103,10 @@ static void usage(FILE *fp) { " -t, --threads N\n" " CPU helper threads for host-side or reference work.\n" " --quality\n" - " Prefer exact kernels where faster approximate paths exist; MTP uses strict verification.\n" + " Prefer exact kernels where faster approximate paths exist; disables Metal Tensor routes; MTP uses strict verification.\n" + " -mt MODE, --mt MODE\n" + " Metal Tensor policy: auto, on, or off. Default: auto. 
Auto enables validated safe routes; 'on' is a route diagnostic and may change output.\n" + " Legacy alias: --mpp MODE.\n" " --dir-steering-file FILE\n" " Load one f32 direction vector per layer for directional steering.\n" " --dir-steering-ffn F\n" @@ -153,6 +157,8 @@ static void usage(FILE *fp) { " Load the model and print a summary only.\n" " --dump-tokens\n" " Tokenize -p/--prompt-file exactly as written, then exit without inference.\n" + " --dump-logits FILE\n" + " Write full next-token logits as JSON after prompt prefill, then exit.\n" " --dump-logprobs FILE\n" " Write greedy continuation top-logprobs as JSON without printing text.\n" " --logprobs-top-k N\n" @@ -240,6 +246,15 @@ static ds4_backend default_backend(void) { #endif } +static ds4_mpp_mode parse_mpp_mode(const char *s) { + if (!strcmp(s, "auto")) return DS4_MPP_AUTO; + if (!strcmp(s, "on")) return DS4_MPP_ON; + if (!strcmp(s, "off")) return DS4_MPP_OFF; + fprintf(stderr, "ds4: invalid Metal Tensor mode: %s\n", s); + fprintf(stderr, "ds4: valid Metal Tensor modes are: auto, on, off\n"); + exit(2); +} + static void log_context_memory(ds4_backend backend, int ctx_size) { ds4_context_memory m = ds4_context_memory_estimate(backend, ctx_size); fprintf(stderr, @@ -629,6 +644,86 @@ static void json_write_token(FILE *fp, ds4_engine *engine, int token) { free(text); } +static int run_logits_dump(ds4_engine *engine, const cli_config *cfg, const ds4_tokens *prompt) { + ds4_session *session = NULL; + if (ds4_session_create(&session, engine, cfg->gen.ctx_size) != 0) { + fprintf(stderr, "ds4: --dump-logits requires a graph session backend\n"); + return 1; + } + + char err[160]; + cli_prefill_progress progress = { + .base_tokens = 0, + .input_tokens = prompt->len, + .use_color = ds4_log_is_tty(stderr), + }; + ds4_session_set_progress(session, cli_prefill_progress_cb, &progress); + if (ds4_session_sync(session, prompt, err, sizeof(err)) != 0) { + ds4_session_set_progress(session, NULL, NULL); + fprintf(stderr, "ds4: prompt processing failed: %s\n", err); + ds4_session_free(session); + return 1; + } + ds4_session_set_progress(session, NULL, NULL); + + const int vocab = ds4_engine_vocab_size(engine); + float *logits = malloc((size_t)vocab * sizeof(logits[0])); + if (!logits) { + ds4_session_free(session); + return 1; + } + if (ds4_session_copy_logits(session, logits, vocab) != vocab) { + fprintf(stderr, "ds4: failed to copy session logits\n"); + free(logits); + ds4_session_free(session); + return 1; + } + + FILE *fp = fopen(cfg->gen.dump_logits_path, "wb"); + if (!fp) { + fprintf(stderr, "ds4: failed to open --dump-logits file: %s\n", cfg->gen.dump_logits_path); + free(logits); + ds4_session_free(session); + return 1; + } + + fprintf(fp, "{\n \"source\":\"ds4\",\n \"model\":"); + json_write_string(fp, cfg->engine.model_path, strlen(cfg->engine.model_path)); + fprintf(fp, + ",\n \"backend\":\"%s\",\n \"mt\":\"%s\",\n \"quant_bits\":%d,\n" + " \"prompt_tokens\":%d,\n \"ctx\":%d,\n \"vocab\":%d,\n", + ds4_backend_name(cfg->engine.backend), + ds4_mpp_mode_name(cfg->engine.mpp_mode), + ds4_engine_routed_quant_bits(engine), + prompt->len, + cfg->gen.ctx_size, + vocab); + const int argmax = ds4_session_argmax(session); + fputs(" \"argmax_token\":", fp); + json_write_token(fp, engine, argmax); + fprintf(fp, ",\n \"argmax_logit\":%.9g,\n \"logits\":[", logits[argmax]); + for (int i = 0; i < vocab; i++) { + if (i) fputc(',', fp); + if ((i % 8) == 0) fputs("\n ", fp); + if (isfinite(logits[i])) { + fprintf(fp, "%.9g", logits[i]); + } else { + 
fputs("null", fp); + } + } + fputs("\n ]\n}\n", fp); + if (fclose(fp) != 0) { + fprintf(stderr, "ds4: failed to close --dump-logits file: %s\n", cfg->gen.dump_logits_path); + free(logits); + ds4_session_free(session); + return 1; + } + + free(logits); + ds4_session_free(session); + return 0; +} + static int run_logprob_dump(ds4_engine *engine, const cli_config *cfg, const ds4_tokens *prompt) { ds4_session *session = NULL; if (ds4_session_create(&session, engine, cfg->gen.ctx_size) != 0) { @@ -730,6 +825,11 @@ static int run_generation(ds4_engine *engine, const cli_config *cfg) { ds4_tokens_free(&prompt); return rc; } + if (cfg->gen.dump_logits_path) { + rc = run_logits_dump(engine, cfg, &prompt); + ds4_tokens_free(&prompt); + return rc; + } if (cfg->gen.dump_logprobs_path) { rc = run_logprob_dump(engine, cfg, &prompt); ds4_tokens_free(&prompt); @@ -1244,6 +1344,8 @@ static cli_config parse_options(int argc, char **argv) { c.gen.seed = parse_u64(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--quality")) { c.engine.quality = true; + } else if (!strcmp(arg, "-mt") || !strcmp(arg, "--mt") || !strcmp(arg, "--mpp")) { + c.engine.mpp_mode = parse_mpp_mode(need_arg(&i, argc, argv, arg)); } else if (!strcmp(arg, "--dir-steering-file")) { c.engine.directional_steering_file = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--dir-steering-ffn")) { @@ -1264,6 +1366,8 @@ static cli_config parse_options(int argc, char **argv) { c.engine.backend = DS4_BACKEND_CUDA; } else if (!strcmp(arg, "--dump-tokens")) { c.gen.dump_tokens = true; + } else if (!strcmp(arg, "--dump-logits")) { + c.gen.dump_logits_path = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--dump-logprobs")) { c.gen.dump_logprobs_path = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--logprobs-top-k")) { diff --git a/ds4_gpu.h b/ds4_gpu.h index 2d16c9c9..b000af9f 100644 --- a/ds4_gpu.h +++ b/ds4_gpu.h @@ -4,6 +4,8 @@ #include #include +#include "ds4.h" + /* ========================================================================= * GPU Tensor and Command Lifetime. 
* ========================================================================= @@ -41,6 +43,9 @@ int ds4_gpu_set_model_map_range(const void *model_map, uint64_t model_size, uint int ds4_gpu_cache_model_range(const void *model_map, uint64_t model_size, uint64_t offset, uint64_t bytes, const char *label); int ds4_gpu_cache_q8_f16_range(const void *model_map, uint64_t model_size, uint64_t offset, uint64_t bytes, uint64_t in_dim, uint64_t out_dim, const char *label); void ds4_gpu_set_quality(bool quality); +void ds4_gpu_set_mpp_mode(ds4_mpp_mode mode); +void ds4_gpu_set_mpp_compare_context(const char *module, uint32_t layer_index, uint32_t pos0); +void ds4_gpu_clear_mpp_compare_context(void); void ds4_gpu_print_memory_report(const char *label); /* ========================================================================= @@ -139,6 +144,16 @@ int ds4_gpu_matmul_q8_0_tensor( const ds4_gpu_tensor *x, uint64_t n_tok); +int ds4_gpu_matmul_q8_0_mpp_tensor( + ds4_gpu_tensor *out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok); + int ds4_gpu_shared_gate_up_swiglu_q8_0_tensor( ds4_gpu_tensor *gate, ds4_gpu_tensor *up, @@ -665,6 +680,7 @@ int ds4_gpu_routed_moe_batch_tensor( uint32_t n_expert, float clamp, const ds4_gpu_tensor *x, + uint32_t layer_index, uint32_t n_tokens, bool *mid_is_f16); diff --git a/ds4_metal.m b/ds4_metal.m index 0a6ae748..c03925fa 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -48,6 +49,7 @@ static id g_cpy_f16_f32_pipeline; static id g_swiglu_pipeline; static id g_add_pipeline; +static id g_moe_sum6_pipeline; static id g_mul_pipeline; static id g_rms_norm_pipeline; static id g_rms_norm_plain_pipeline; @@ -76,9 +78,6 @@ static id g_moe_mul_mv_id_q4_k_pair_pipeline; static id g_moe_mul_mv_id_q4_k_pair_swiglu_pipeline; static id g_moe_mul_mv_id_q4_k_sum6_pipeline; -static id g_moe_mul_mm_id_iq2_xxs_pipeline; -static id g_moe_mul_mm_id_q2_k_pipeline; -static id g_moe_mul_mm_id_q4_k_pipeline; static id g_rope_tail_batch_pipeline; static id g_dsv4_fp8_kv_quantize_pipeline; static id g_dsv4_kv_fp8_store_pipeline; @@ -97,6 +96,7 @@ static id g_dsv4_sort_i32_rows_asc_pipeline; static id g_dsv4_indexed_attention_heads8_pipeline; static id g_dsv4_indexed_attention_heads8_rb4_pipeline; +static id g_dsv4_indexed_attention_heads8_rb16_pipeline; static id g_dsv4_softplus_sqrt_pipeline; static id g_dsv4_router_finalize_one_pipeline; static id g_dsv4_router_weights_one_pipeline; @@ -140,6 +140,13 @@ static uint64_t g_model_wrap_bytes; static uint64_t g_model_wrap_max_bytes; static uint64_t g_model_residency_count; +static int g_metal4_runtime_available; +static int g_metal4_family_supported; +static int g_metal4_queue_supported; +static int g_metal4_m5_neural_accelerators_hint; +static int g_metal4_tensor_api_enabled; +static int g_metal4_tensor_api_compile_supported; +static char g_metal_device_name[128]; static NSUInteger g_flash_attn_mask_bytes; static NSUInteger g_flash_attn_pad_bytes; static NSUInteger g_flash_attn_tmp_bytes; @@ -167,6 +174,38 @@ static NSUInteger g_attn_out_group_ids_bytes; static int g_initialized; static int g_quality_mode; +static ds4_mpp_mode g_mpp_mode = DS4_MPP_AUTO; +static int g_mpp_q8_reported; +static int g_mpp_q8_partial_skip_reported; +static int g_mpp_f16_reported; +static int g_mpp_f16_pair_reported; +static int g_mpp_attn_out_reported; +static int g_mpp_moe_reported; +static 
int g_mpp_moe_ranges_reported; +static int g_mpp_invalid_env_reported; +static char g_mpp_compare_context[128]; + +#define DS4_METAL_MPP_COMPARE_PENDING_MAX 64 +#define DS4_METAL_MPP_COMPARE_DELTAS 5 + +typedef struct { + __strong id ref_buffer; + __strong id cand_buffer; + NSUInteger ref_offset; + NSUInteger cand_offset; + uint64_t elements; + uint64_t dim0; + uint64_t dim1; + uint64_t dim2; + char route[16]; + char label[128]; +} ds4_gpu_mpp_compare_item; + +static ds4_gpu_mpp_compare_item g_mpp_compare_pending[DS4_METAL_MPP_COMPARE_PENDING_MAX]; +static int g_mpp_compare_pending_count; +static int g_mpp_compare_done_count; +static int g_mpp_compare_stopped; +static int g_mpp_compare_limit_reported; static uint64_t ds4_gpu_system_memory_bytes(void) { uint64_t bytes = 0; @@ -278,12 +317,260 @@ static int ds4_gpu_wait_pending_command_buffers(const char *label) { return ok; } +static int ds4_gpu_mpp_compare_max(void) { + const char *env = getenv("DS4_METAL_MPP_COMPARE_MAX"); + if (!env || !env[0]) return 20; + char *end = NULL; + unsigned long v = strtoul(env, &end, 10); + if (end == env) return 20; + if (v > 1000000ul) v = 1000000ul; + return (int)v; +} + +static int ds4_gpu_mpp_compare_verbose(void) { + const char *env = getenv("DS4_METAL_MPP_COMPARE_VERBOSE"); + return env && env[0] && strcmp(env, "0") != 0 && + strcmp(env, "false") != 0 && strcmp(env, "off") != 0; +} + +static int ds4_gpu_mpp_compare_route_matches(const char *route) { + if (g_mpp_compare_stopped) return 0; + const char *want = getenv("DS4_METAL_MPP_COMPARE_ROUTE"); + if (!want || !want[0] || !route || !route[0]) return 0; + if (strcmp(want, "all") == 0) return 1; + return strcmp(want, route) == 0; +} + +static const char *ds4_gpu_mpp_compare_label(const char *fallback, + char *buf, + size_t buflen) { + if (g_mpp_compare_context[0]) return g_mpp_compare_context; + snprintf(buf, buflen, "%s", fallback && fallback[0] ? 
fallback : "unknown"); + return buf; +} + +static void ds4_gpu_mpp_compare_note_delta( + uint64_t *idx, + float *ref_vals, + float *cand_vals, + float *abs_vals, + uint64_t id, + float ref, + float cand) { + const float abs_delta = fabsf(cand - ref); + for (int i = 0; i < DS4_METAL_MPP_COMPARE_DELTAS; i++) { + if (idx[i] == UINT64_MAX || abs_delta > abs_vals[i]) { + for (int j = DS4_METAL_MPP_COMPARE_DELTAS - 1; j > i; j--) { + idx[j] = idx[j - 1]; + ref_vals[j] = ref_vals[j - 1]; + cand_vals[j] = cand_vals[j - 1]; + abs_vals[j] = abs_vals[j - 1]; + } + idx[i] = id; + ref_vals[i] = ref; + cand_vals[i] = cand; + abs_vals[i] = abs_delta; + return; + } + } +} + +static void ds4_gpu_mpp_compare_clear_pending(void) { + for (int i = 0; i < g_mpp_compare_pending_count; i++) { + g_mpp_compare_pending[i].ref_buffer = nil; + g_mpp_compare_pending[i].cand_buffer = nil; + g_mpp_compare_pending[i].elements = 0; + g_mpp_compare_pending[i].route[0] = '\0'; + g_mpp_compare_pending[i].label[0] = '\0'; + } + g_mpp_compare_pending_count = 0; +} + +static void ds4_gpu_mpp_compare_reset(void) { + ds4_gpu_mpp_compare_clear_pending(); + g_mpp_compare_done_count = 0; + g_mpp_compare_stopped = 0; + g_mpp_compare_limit_reported = 0; +} + +static void ds4_gpu_mpp_compare_drain(const char *finish_label) { + (void)finish_label; + const int max_reports = ds4_gpu_mpp_compare_max(); + for (int i = 0; i < g_mpp_compare_pending_count; i++) { + ds4_gpu_mpp_compare_item *item = &g_mpp_compare_pending[i]; + if (g_mpp_compare_stopped || g_mpp_compare_done_count >= max_reports || + !item->ref_buffer || !item->cand_buffer || item->elements == 0) { + continue; + } + + const float *ref = (const float *)((const uint8_t *)[item->ref_buffer contents] + item->ref_offset); + const float *cand = (const float *)((const uint8_t *)[item->cand_buffer contents] + item->cand_offset); + double sumsq = 0.0; + float max_abs = 0.0f; + uint64_t max_index = 0; + int nonfinite = 0; + uint64_t delta_idx[DS4_METAL_MPP_COMPARE_DELTAS]; + float delta_ref[DS4_METAL_MPP_COMPARE_DELTAS]; + float delta_cand[DS4_METAL_MPP_COMPARE_DELTAS]; + float delta_abs[DS4_METAL_MPP_COMPARE_DELTAS]; + for (int j = 0; j < DS4_METAL_MPP_COMPARE_DELTAS; j++) { + delta_idx[j] = UINT64_MAX; + delta_ref[j] = 0.0f; + delta_cand[j] = 0.0f; + delta_abs[j] = 0.0f; + } + + for (uint64_t j = 0; j < item->elements; j++) { + if (!isfinite(ref[j]) || !isfinite(cand[j])) { + nonfinite++; + continue; + } + const float delta = cand[j] - ref[j]; + const float abs_delta = fabsf(delta); + sumsq += (double)delta * (double)delta; + if (abs_delta > max_abs) { + max_abs = abs_delta; + max_index = j; + } + ds4_gpu_mpp_compare_note_delta(delta_idx, delta_ref, delta_cand, delta_abs, + j, ref[j], cand[j]); + } + + const float rms = (float)sqrt(sumsq / (double)item->elements); + const int exceeds_target = (nonfinite != 0 || max_abs > 1.0e-3f || rms > 1.0e-4f); + if (ds4_gpu_mpp_compare_verbose() || exceeds_target) { + fprintf(stderr, + "ds4: Metal Tensor compare route=%s module=%s shape=%llux%llux%llu max_abs=%g rms=%g nonfinite=%d max_index=%llu\n", + item->route, + item->label, + (unsigned long long)item->dim0, + (unsigned long long)item->dim1, + (unsigned long long)item->dim2, + max_abs, + rms, + nonfinite, + (unsigned long long)max_index); + fprintf(stderr, "ds4: Metal Tensor compare route=%s module=%s largest deltas:", + item->route, item->label); + for (int j = 0; j < DS4_METAL_MPP_COMPARE_DELTAS && delta_idx[j] != UINT64_MAX; j++) { + fprintf(stderr, " idx=%llu ref=%g cand=%g abs=%g", + 
(unsigned long long)delta_idx[j], + delta_ref[j], + delta_cand[j], + delta_abs[j]); + } + fputc('\n', stderr); + } + + g_mpp_compare_done_count++; + if (exceeds_target) { + fprintf(stderr, + "ds4: Metal Tensor compare route=%s module=%s exceeded target max_abs<=0.001 rms<=0.0001; stopping comparisons\n", + item->route, + item->label); + g_mpp_compare_stopped = 1; + } + } + if (!g_mpp_compare_stopped && !g_mpp_compare_limit_reported && + g_mpp_compare_done_count >= max_reports) { + fprintf(stderr, + "ds4: Metal Tensor compare reached DS4_METAL_MPP_COMPARE_MAX=%d without a target breach\n", + max_reports); + g_mpp_compare_limit_reported = 1; + } + ds4_gpu_mpp_compare_clear_pending(); +} + +static void ds4_gpu_mpp_compare_register( + const char *route, + const char *fallback_label, + const ds4_gpu_tensor *ref, + const ds4_gpu_tensor *cand, + uint64_t elements, + uint64_t dim0, + uint64_t dim1, + uint64_t dim2) { + if (!ds4_gpu_mpp_compare_route_matches(route)) return; + if (g_mpp_compare_done_count + g_mpp_compare_pending_count >= ds4_gpu_mpp_compare_max()) return; + if (g_mpp_compare_pending_count >= DS4_METAL_MPP_COMPARE_PENDING_MAX) return; + id ref_buffer = ds4_gpu_tensor_buffer(ref); + id cand_buffer = ds4_gpu_tensor_buffer(cand); + if (!ref_buffer || !cand_buffer || elements == 0) return; + + ds4_gpu_mpp_compare_item *item = &g_mpp_compare_pending[g_mpp_compare_pending_count++]; + item->ref_buffer = nil; + item->cand_buffer = nil; + item->ref_offset = 0; + item->cand_offset = 0; + item->elements = 0; + item->dim0 = 0; + item->dim1 = 0; + item->dim2 = 0; + item->route[0] = '\0'; + item->label[0] = '\0'; + item->ref_buffer = ref_buffer; + item->cand_buffer = cand_buffer; + item->ref_offset = ds4_gpu_tensor_offset(ref); + item->cand_offset = ds4_gpu_tensor_offset(cand); + item->elements = elements; + item->dim0 = dim0; + item->dim1 = dim1; + item->dim2 = dim2; + snprintf(item->route, sizeof(item->route), "%s", route); + char label_buf[128]; + snprintf(item->label, sizeof(item->label), "%s", + ds4_gpu_mpp_compare_label(fallback_label, label_buf, sizeof(label_buf))); +} + +static ds4_gpu_tensor *ds4_gpu_mpp_compare_make_buffer_view( + id buffer, + NSUInteger offset, + uint64_t bytes) { + if (!buffer || bytes > (uint64_t)NSUIntegerMax) return NULL; + DS4MetalTensor *view = [DS4MetalTensor new]; + view.buffer = buffer; + view.offset = (uint64_t)offset; + view.bytes = bytes; + view.owner = 0; + return (__bridge_retained ds4_gpu_tensor *)view; +} + +static ds4_gpu_tensor *ds4_gpu_mpp_compare_snapshot_buffer( + id buffer, + NSUInteger offset, + uint64_t bytes) { + ds4_gpu_tensor *view = ds4_gpu_mpp_compare_make_buffer_view(buffer, offset, bytes); + ds4_gpu_tensor *snapshot = ds4_gpu_tensor_alloc(bytes); + if (!view || !snapshot) { + ds4_gpu_tensor_free(view); + ds4_gpu_tensor_free(snapshot); + return NULL; + } + + int ok = 0; + if (g_batch_cb) { + ok = ds4_gpu_tensor_copy(snapshot, 0, view, 0, bytes); + } else { + memcpy(ds4_gpu_tensor_contents(snapshot), + (const uint8_t *)[buffer contents] + offset, + (size_t)bytes); + ok = 1; + } + ds4_gpu_tensor_free(view); + if (!ok) { + ds4_gpu_tensor_free(snapshot); + return NULL; + } + return snapshot; +} + static int ds4_gpu_finish_command_buffer(id cb, int owned, const char *label) { if (!owned) return 1; [cb commit]; int ok = ds4_gpu_wait_pending_command_buffers(label); if (!ds4_gpu_wait_command_buffer(cb, label)) ok = 0; + if (ok) ds4_gpu_mpp_compare_drain(label); [g_transient_buffers removeAllObjects]; return ok; } @@ -589,14 +876,16 @@ static int 
ds4_gpu_map_model_views( static id ds4_gpu_get_mul_mm_id_pipeline( const char *function_name, - bool bc_inp) { - NSString *key = [NSString stringWithFormat:@"%s_bci=%d", - function_name, bc_inp ? 1 : 0]; + bool bc_inp, + bool use_mpp) { + NSString *key = [NSString stringWithFormat:@"%s_bci=%d_mpp=%d", + function_name, bc_inp ? 1 : 0, use_mpp ? 1 : 0]; id cached = [g_pipeline_cache objectForKey:key]; if (cached) return cached; MTLFunctionConstantValues *constants = [[MTLFunctionConstantValues alloc] init]; [constants setConstantValue:&bc_inp type:MTLDataTypeBool atIndex:700]; + [constants setConstantValue:&use_mpp type:MTLDataTypeBool atIndex:702]; NSError *error = nil; NSString *name = [NSString stringWithUTF8String:function_name]; @@ -673,6 +962,702 @@ static int ds4_gpu_use_compressor_pair_nr4(void) { return enabled; } +static int ds4_gpu_device_name_contains(const char *needle); + +static int ds4_gpu_mpp_q8_0_default_target(void) { + // The Metal 4 cooperative-tensor Q8_0 matmul on M5 Max produces logprob + // drift versus the legacy simdgroup_multiply_accumulate path (measured + // rms=0.150, max_abs=0.75 on the short reasoning prompt; bit-exact match + // recovered by disabling just this route). The other Tensor routes + // (F16 compressor, attention-output, MoE) are bit-clean. Default the + // Q8_0 Tensor matmul to OFF on M5; opt back in with DS4_METAL_MPP_Q8_0_ENABLE=1. + if (ds4_gpu_device_name_contains("M5")) return 0; + return 1; +} + +// F16 compressor Tensor matmul default. Bit-clean on M5 vs the legacy +// simdgroup path, so this stays default-on independent of device. +// Kept as a separate helper to avoid coupling the F16 default to the +// Q8_0 carve-out above. +static int ds4_gpu_mpp_f16_default_target(void) { + return 1; +} + +static int ds4_gpu_env_value_eq(const char *v, size_t n, const char *literal) { + size_t m = strlen(literal); + if (n != m) return 0; + for (size_t i = 0; i < n; i++) { + if (tolower((unsigned char)v[i]) != tolower((unsigned char)literal[i])) return 0; + } + return 1; +} + +static int ds4_gpu_env_bool(const char *name) { + const char *v = getenv(name); + if (!v) return -1; + + while (isspace((unsigned char)*v)) v++; + size_t n = strlen(v); + while (n > 0 && isspace((unsigned char)v[n - 1])) n--; + if (n == 0) return 1; + + if (ds4_gpu_env_value_eq(v, n, "1") || + ds4_gpu_env_value_eq(v, n, "true") || + ds4_gpu_env_value_eq(v, n, "yes") || + ds4_gpu_env_value_eq(v, n, "on")) { + return 1; + } + if (ds4_gpu_env_value_eq(v, n, "0") || + ds4_gpu_env_value_eq(v, n, "false") || + ds4_gpu_env_value_eq(v, n, "no") || + ds4_gpu_env_value_eq(v, n, "off")) { + return 0; + } + + if (!g_mpp_invalid_env_reported) { + fprintf(stderr, + "ds4: invalid Metal Tensor boolean environment value %s=%.*s; treating presence as enabled\n", + name, (int)n, v); + g_mpp_invalid_env_reported = 1; + } + return 1; +} + +static int ds4_gpu_mpp_low_power_profile(void) { + const int disabled = ds4_gpu_env_bool("DS4_METAL_MPP_LOW_POWER_DISABLE"); + if (disabled > 0) return 0; + + const int enabled = ds4_gpu_env_bool("DS4_METAL_MPP_LOW_POWER_ENABLE"); + if (enabled >= 0) return enabled > 0; + + static int detected = -1; + static int reported; + if (detected < 0) { + detected = 0; + @autoreleasepool { + NSProcessInfo *info = [NSProcessInfo processInfo]; + if ([info respondsToSelector:@selector(isLowPowerModeEnabled)]) { + detected = [info isLowPowerModeEnabled] ? 
1 : 0; + } + } + } + if (detected && !reported) { + fprintf(stderr, + "ds4: Metal low-power Tensor profile active; widening Q8_0 prefill route\n"); + reported = 1; + } + return detected; +} + +static int ds4_gpu_use_indexed_attention_rb4(void) { + static int enabled = -1; + if (enabled < 0) { + enabled = ds4_gpu_env_bool("DS4_METAL_INDEXED_ATTN_RB4") > 0; + } + return enabled; +} + +typedef enum { + DS4_METAL_MPP_GLOBAL_OFF, + DS4_METAL_MPP_GLOBAL_AUTO, + DS4_METAL_MPP_GLOBAL_ON, +} ds4_gpu_mpp_global_policy; + +static ds4_gpu_mpp_global_policy ds4_gpu_mpp_global_policy_mode(void) { + if (!g_metal4_tensor_api_enabled || g_quality_mode) return DS4_METAL_MPP_GLOBAL_OFF; + if (g_mpp_mode == DS4_MPP_OFF) return DS4_METAL_MPP_GLOBAL_OFF; + if (g_mpp_mode == DS4_MPP_ON) return DS4_METAL_MPP_GLOBAL_ON; + + const int disabled = ds4_gpu_env_bool("DS4_METAL_MPP_DISABLE"); + if (disabled > 0) return DS4_METAL_MPP_GLOBAL_OFF; + + const int enabled = ds4_gpu_env_bool("DS4_METAL_MPP_ENABLE"); + if (enabled >= 0) return enabled ? DS4_METAL_MPP_GLOBAL_ON : DS4_METAL_MPP_GLOBAL_OFF; + + return DS4_METAL_MPP_GLOBAL_AUTO; +} + +static int ds4_gpu_mpp_route_switch(const char *enable_env, const char *disable_env) { + const int disabled = ds4_gpu_env_bool(disable_env); + if (disabled > 0) return 0; + + const int enabled = ds4_gpu_env_bool(enable_env); + if (enabled >= 0) return enabled ? 1 : 0; + + return -1; +} + +static int ds4_gpu_mpp_route_enabled( + int default_target, + const char *enable_env, + const char *disable_env) { + const ds4_gpu_mpp_global_policy policy = ds4_gpu_mpp_global_policy_mode(); + if (policy == DS4_METAL_MPP_GLOBAL_OFF) return 0; + + const int route = ds4_gpu_mpp_route_switch(enable_env, disable_env); + if (route >= 0) return route; + + if (policy == DS4_METAL_MPP_GLOBAL_ON) return 1; + return default_target; +} + +static int ds4_gpu_mpp_fast_profile(void) { + return ds4_gpu_env_bool("DS4_METAL_MPP_FAST") > 0; +} + +static const char *ds4_gpu_mpp_enabled_reason(void) { + if (g_mpp_mode == DS4_MPP_ON) return " by -mt on"; + if (ds4_gpu_mpp_fast_profile()) return " by DS4_METAL_MPP_FAST"; + if (ds4_gpu_env_bool("DS4_METAL_MPP_ENABLE") > 0) return " by DS4_METAL_MPP_ENABLE"; + return " by default"; +} + +static int ds4_gpu_mpp_q8_0_policy_enabled(void) { + return ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_q8_0_default_target(), + "DS4_METAL_MPP_Q8_0_ENABLE", + "DS4_METAL_MPP_Q8_0_DISABLE"); +} + +static int ds4_gpu_use_mpp_q8_0_matmul(void) { + const int enabled = ds4_gpu_mpp_q8_0_policy_enabled(); + if (enabled && !g_mpp_q8_reported) { + fprintf(stderr, "ds4: Metal Tensor Q8_0 prefill matmul enabled%s\n", + ds4_gpu_mpp_enabled_reason()); + g_mpp_q8_reported = 1; + } + return enabled; +} + +static int ds4_gpu_mpp_q8_0_partial_tiles_enabled(void) { + if (ds4_gpu_mpp_fast_profile()) return 1; + const int enabled = ds4_gpu_env_bool("DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE"); + if (enabled >= 0) return enabled > 0; + return 1; +} + +static uint32_t ds4_gpu_mpp_tile_n_env(const char *name, uint32_t fallback) { + const char *env = getenv(name); + if (!env || !env[0]) return fallback; + char *end = NULL; + long v = strtol(env, &end, 10); + while (end && isspace((unsigned char)*end)) end++; + if (end && *end == '\0' && v == 64) return 64; + if (end && *end == '\0' && v == 32) return 32; + fprintf(stderr, + "ds4: invalid %s=%s; expected 32 or 64, using %u\n", + name, env, fallback); + return fallback; +} + +static uint32_t ds4_gpu_mpp_q8_0_tile_n(void) { + return 
ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_Q8_0_TILE_N", 64); +} + +static uint32_t ds4_gpu_mpp_q8_0_tile_n_for_tokens(uint64_t n_tok) { + const char *env = getenv("DS4_METAL_MPP_Q8_0_TILE_N"); + if (env && env[0]) return ds4_gpu_mpp_q8_0_tile_n(); + return n_tok >= 4096u ? 32u : 64u; +} + +static uint32_t ds4_gpu_mpp_attn_out_tile_n(void) { + return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_ATTN_OUT_TILE_N", 64); +} + +static uint32_t ds4_gpu_mpp_moe_tile_n(void) { + return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_MOE_TILE_N", 32); +} + +static int ds4_gpu_mpp_moe_fast_layout(void) { + const int enabled = ds4_gpu_env_bool("DS4_METAL_MPP_MOE_FAST_LAYOUT"); + if (enabled >= 0) return enabled > 0; + return 1; +} + +static int ds4_gpu_mpp_moe_pair_gate_up(void) { + return ds4_gpu_env_bool("DS4_METAL_MPP_MOE_PAIR_GATE_UP") > 0; +} + +static int ds4_gpu_mpp_direct_rhs(void) { + const int enabled = ds4_gpu_env_bool("DS4_METAL_MPP_DIRECT_RHS"); + if (enabled >= 0) return enabled > 0; + return 1; +} + +static int ds4_gpu_mpp_q8_0_direct_rhs(void) { + return ds4_gpu_mpp_direct_rhs() || + ds4_gpu_env_bool("DS4_METAL_MPP_Q8_0_DIRECT_RHS") > 0; +} + +static int ds4_gpu_mpp_f16_direct_rhs(void) { + return ds4_gpu_mpp_direct_rhs() || + ds4_gpu_env_bool("DS4_METAL_MPP_F16_DIRECT_RHS") > 0; +} + +static int ds4_gpu_mpp_f16_wide_matmul(void) { + return ds4_gpu_env_bool("DS4_METAL_MPP_F16_WIDE") > 0; +} + +static int ds4_gpu_mpp_f16_pair_matmul(void) { + return ds4_gpu_env_bool("DS4_METAL_MPP_F16_PAIR") > 0; +} + +static int ds4_gpu_mpp_attn_out_direct_rhs(void) { + return ds4_gpu_mpp_direct_rhs() || + ds4_gpu_env_bool("DS4_METAL_MPP_ATTN_OUT_DIRECT_RHS") > 0; +} + +static int ds4_gpu_mpp_layer_env(const char *name, int fallback) { + const char *env = getenv(name); + if (!env || !env[0]) return fallback; + char *end = NULL; + long v = strtol(env, &end, 10); + while (end && isspace((unsigned char)*end)) end++; + if (end && *end == '\0' && v >= 0 && v <= 255) return (int)v; + fprintf(stderr, + "ds4: invalid %s=%s; expected layer index 0..255, using %d\n", + name, env, fallback); + return fallback; +} + +static int ds4_gpu_mpp_context_layer(void) { + if (!g_mpp_compare_context[0]) return -1; + int layer = -1; + if (sscanf(g_mpp_compare_context, "layer=%d", &layer) == 1) return layer; + return -1; +} + +static int ds4_gpu_mpp_late_safe_context_range(int first_layer) { + const int layer = ds4_gpu_mpp_context_layer(); + return layer >= first_layer && layer <= 42; +} + +static int ds4_gpu_mpp_q8_0_late_safe_context(void) { + const int layer = ds4_gpu_mpp_context_layer(); + if (layer >= 38 && layer <= 42) return 1; + if (layer >= 32 && layer <= 37 && + strstr(g_mpp_compare_context, "attn_q_b") != NULL) { + return 1; + } + return 0; +} + +static int ds4_gpu_mpp_attn_out_late_safe_context(void) { + return ds4_gpu_mpp_late_safe_context_range(32); +} + +static int ds4_gpu_mpp_layer_expr_matches(const char *layer_expr) { + if (!layer_expr || !*layer_expr) return 0; + const int layer = ds4_gpu_mpp_context_layer(); + char *parse_end = NULL; + long first = strtol(layer_expr, &parse_end, 10); + while (parse_end && isspace((unsigned char)*parse_end)) parse_end++; + if (!parse_end || parse_end == layer_expr || + first < 0 || first > 255 || + !(parse_end[0] == '\0' || + (parse_end[0] == '-' && parse_end[1] != '\0') || + (parse_end[0] == '.' && parse_end[1] == '.' 
&& parse_end[2] != '\0'))) { + return 0; + } + + long last = first; + if (parse_end[0] == '-') { + const char *range_end = parse_end + 1; + while (isspace((unsigned char)*range_end)) range_end++; + char *end2 = NULL; + last = strtol(range_end, &end2, 10); + while (end2 && isspace((unsigned char)*end2)) end2++; + if (!end2 || end2 == range_end || *end2 != '\0') return 0; + } else if (parse_end[0] == '.') { + const char *range_end = parse_end + 2; + while (isspace((unsigned char)*range_end)) range_end++; + char *end2 = NULL; + last = strtol(range_end, &end2, 10); + while (end2 && isspace((unsigned char)*end2)) end2++; + if (!end2 || end2 == range_end || *end2 != '\0') return 0; + } + if (last < first || last < 0 || last > 255) return 0; + return layer >= first && layer <= last; +} + +static int ds4_gpu_mpp_context_matches_filter( + const char *env_name, + int default_match, + int late_safe_match) { + const char *filter = getenv(env_name); + if (!filter || !filter[0]) return default_match; + if (!g_mpp_compare_context[0]) return 0; + + const char *p = filter; + while (*p) { + while (*p == ',' || isspace((unsigned char)*p)) p++; + const char *start = p; + while (*p && *p != ',') p++; + const char *end = p; + while (end > start && isspace((unsigned char)end[-1])) end--; + if (end > start) { + char token[64]; + size_t n = (size_t)(end - start); + if (n >= sizeof(token)) n = sizeof(token) - 1u; + memcpy(token, start, n); + token[n] = '\0'; + if (ds4_gpu_env_value_eq(token, n, "all")) return 1; + if (ds4_gpu_env_value_eq(token, n, "none")) return 0; + if (ds4_gpu_env_value_eq(token, n, "late_safe")) return late_safe_match; + char *at = strchr(token, '@'); + if (at) { + *at = '\0'; + const char *module = token; + const char *expr = at + 1; + if (strncmp(expr, "layer=", 6) == 0) { + expr += 6; + } else if (strncmp(expr, "layer:", 6) == 0) { + expr += 6; + } else { + continue; + } + if (*module && + strstr(g_mpp_compare_context, module) != NULL && + ds4_gpu_mpp_layer_expr_matches(expr)) { + return 1; + } + continue; + } + const char *layer_expr = NULL; + if (strncmp(token, "layer=", 6) == 0) { + layer_expr = token + 6; + } else if (strncmp(token, "layer:", 6) == 0) { + layer_expr = token + 6; + } + if (layer_expr && *layer_expr) { + if (ds4_gpu_mpp_layer_expr_matches(layer_expr)) return 1; + continue; + } + if (strstr(g_mpp_compare_context, token) != NULL) return 1; + } + } + return 0; +} + +static int ds4_gpu_mpp_q8_0_context_matches_filter(uint64_t n_tok) { + (void)n_tok; + const char *filter = getenv("DS4_METAL_MPP_Q8_0_FILTER"); + const int filter_set = filter && filter[0]; + const int default_match = + (ds4_gpu_mpp_fast_profile() || + (!filter_set && ds4_gpu_mpp_low_power_profile())) + ? 
1 + : ds4_gpu_mpp_q8_0_late_safe_context(); + return ds4_gpu_mpp_context_matches_filter("DS4_METAL_MPP_Q8_0_FILTER", + default_match, + ds4_gpu_mpp_q8_0_late_safe_context()); +} + +static int ds4_gpu_can_use_mpp_q8_0_matmul(uint64_t n_tok) { + if (n_tok <= 8) return 0; + if (!ds4_gpu_use_mpp_q8_0_matmul()) return 0; + if (!ds4_gpu_mpp_q8_0_context_matches_filter(n_tok)) return 0; + if ((n_tok % 32u) == 0 || ds4_gpu_mpp_q8_0_partial_tiles_enabled()) return 1; + + if (!g_mpp_q8_partial_skip_reported) { + fprintf(stderr, + "ds4: Metal Tensor Q8_0 prefill matmul skipping partial token tiles; " + "set DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=1 to test them\n"); + g_mpp_q8_partial_skip_reported = 1; + } + return 0; +} + +static int ds4_gpu_use_mpp_f16_compressor_matmul(void) { + const int enabled = ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_f16_default_target(), + "DS4_METAL_MPP_F16_ENABLE", + "DS4_METAL_MPP_F16_DISABLE"); + if (enabled && !g_mpp_f16_reported) { + fprintf(stderr, "ds4: Metal Tensor F16 compressor prefill matmul enabled%s\n", + ds4_gpu_mpp_enabled_reason()); + g_mpp_f16_reported = 1; + } + return enabled; +} + +static int ds4_gpu_use_mpp_attn_out_low_matmul(void) { + const int default_match = ds4_gpu_mpp_fast_profile() + ? 1 + : ds4_gpu_mpp_attn_out_late_safe_context(); + const int enabled = + ds4_gpu_mpp_route_enabled(1, + "DS4_METAL_MPP_ATTN_OUT_ENABLE", + "DS4_METAL_MPP_ATTN_OUT_DISABLE") && + ds4_gpu_mpp_context_matches_filter("DS4_METAL_MPP_ATTN_OUT_FILTER", + default_match, + ds4_gpu_mpp_attn_out_late_safe_context()); + if (enabled && !g_mpp_attn_out_reported) { + fprintf(stderr, "ds4: Metal Tensor attention-output low projection enabled%s\n", + ds4_gpu_mpp_enabled_reason()); + g_mpp_attn_out_reported = 1; + } + return enabled; +} + +enum { + DS4_METAL_MOE_MPP_GATE = 1 << 0, + DS4_METAL_MOE_MPP_UP = 1 << 1, + DS4_METAL_MOE_MPP_DOWN = 1 << 2, + + DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER = 20, + DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER = 20, + DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER = 22, + DS4_METAL_MOE_MPP_FAST_GATE_LAYER = 0, + DS4_METAL_MOE_MPP_FAST_UP_LAYER = 0, + DS4_METAL_MOE_MPP_FAST_DOWN_LAYER = 0, +}; + +static int ds4_gpu_mpp_routed_moe_default_target(void) { + return 1; +} + +static int ds4_gpu_mpp_routed_moe_default_policy(void) { + const ds4_gpu_mpp_global_policy policy = ds4_gpu_mpp_global_policy_mode(); + if (policy == DS4_METAL_MPP_GLOBAL_OFF) return 0; + if (policy == DS4_METAL_MPP_GLOBAL_ON) return 1; + + const int group = ds4_gpu_mpp_route_switch("DS4_METAL_MPP_MOE_ENABLE", + "DS4_METAL_MPP_MOE_DISABLE"); + if (group >= 0) return group; + + return ds4_gpu_mpp_routed_moe_default_target(); +} + +static int ds4_gpu_mpp_moe_route_enabled(const char *enable_env, const char *disable_env) { + const ds4_gpu_mpp_global_policy policy = ds4_gpu_mpp_global_policy_mode(); + if (policy == DS4_METAL_MPP_GLOBAL_OFF) return 0; + + const int group = ds4_gpu_mpp_route_switch("DS4_METAL_MPP_MOE_ENABLE", + "DS4_METAL_MPP_MOE_DISABLE"); + if (group == 0) return 0; + + const int route = ds4_gpu_mpp_route_switch(enable_env, disable_env); + if (route >= 0) return route; + + if (group == 1 || policy == DS4_METAL_MPP_GLOBAL_ON) return 1; + return ds4_gpu_mpp_routed_moe_default_target(); +} + +static int ds4_gpu_mpp_routed_moe_stage_mask(void) { + int mask = 0; + if (ds4_gpu_mpp_moe_route_enabled("DS4_METAL_MPP_MOE_GATE_ENABLE", + "DS4_METAL_MPP_MOE_GATE_DISABLE")) { + mask |= DS4_METAL_MOE_MPP_GATE; + } + if (ds4_gpu_mpp_moe_route_enabled("DS4_METAL_MPP_MOE_UP_ENABLE", + "DS4_METAL_MPP_MOE_UP_DISABLE")) { + 
mask |= DS4_METAL_MOE_MPP_UP; + } + if (ds4_gpu_mpp_moe_route_enabled("DS4_METAL_MPP_MOE_DOWN_ENABLE", + "DS4_METAL_MPP_MOE_DOWN_DISABLE")) { + mask |= DS4_METAL_MOE_MPP_DOWN; + } + if (mask && !g_mpp_moe_reported) { + fprintf(stderr, "ds4: Metal Tensor routed MoE projections enabled%s\n", + ds4_gpu_mpp_enabled_reason()); + g_mpp_moe_reported = 1; + } + return mask; +} + +static int ds4_gpu_mpp_moe_late_safe_context(int first_layer) { + return ds4_gpu_mpp_late_safe_context_range(first_layer); +} + +static int ds4_gpu_mpp_moe_context_matches_filter(const char *route_filter_env, + int first_layer) { + return ds4_gpu_mpp_context_matches_filter("DS4_METAL_MPP_MOE_FILTER", + 1, + ds4_gpu_mpp_moe_late_safe_context(first_layer)) && + ds4_gpu_mpp_context_matches_filter(route_filter_env, + 1, + ds4_gpu_mpp_moe_late_safe_context(first_layer)); +} + +static int ds4_gpu_mpp_moe_start_layer(const char *route_env, int fallback) { + const int common = ds4_gpu_mpp_layer_env("DS4_METAL_MPP_MOE_START_LAYER", fallback); + return ds4_gpu_mpp_layer_env(route_env, common); +} + +static int ds4_gpu_mpp_routed_moe_mask_for_layer(uint32_t layer_index) { + const int requested_mask = ds4_gpu_mpp_routed_moe_stage_mask(); + if (!requested_mask) return 0; + + if (ds4_gpu_mpp_routed_moe_default_policy()) { + const int fast_profile = ds4_gpu_mpp_fast_profile(); + const int down_fallback = fast_profile ? + DS4_METAL_MOE_MPP_FAST_DOWN_LAYER : + DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER; + const int up_fallback = fast_profile ? + DS4_METAL_MOE_MPP_FAST_UP_LAYER : + DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER; + const int gate_fallback = fast_profile ? + DS4_METAL_MOE_MPP_FAST_GATE_LAYER : + DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER; + const int down_start = ds4_gpu_mpp_moe_start_layer( + "DS4_METAL_MPP_MOE_DOWN_START_LAYER", + down_fallback); + const int up_start = ds4_gpu_mpp_moe_start_layer( + "DS4_METAL_MPP_MOE_UP_START_LAYER", + up_fallback); + const int gate_start = ds4_gpu_mpp_moe_start_layer( + "DS4_METAL_MPP_MOE_GATE_START_LAYER", + gate_fallback); + if (!g_mpp_moe_ranges_reported) { + fprintf(stderr, + "ds4: Metal Tensor routed MoE default ranges down=%d..end up=%d..end gate=%d..end\n", + down_start, + up_start, + gate_start); + g_mpp_moe_ranges_reported = 1; + } + int mask = 0; + if ((int)layer_index >= down_start) mask |= DS4_METAL_MOE_MPP_DOWN; + if ((int)layer_index >= up_start) mask |= DS4_METAL_MOE_MPP_UP; + if ((int)layer_index >= gate_start) mask |= DS4_METAL_MOE_MPP_GATE; + if ((mask & DS4_METAL_MOE_MPP_DOWN) && + !ds4_gpu_mpp_moe_context_matches_filter("DS4_METAL_MPP_MOE_DOWN_FILTER", + down_start)) { + mask &= ~DS4_METAL_MOE_MPP_DOWN; + } + if ((mask & DS4_METAL_MOE_MPP_UP) && + !ds4_gpu_mpp_moe_context_matches_filter("DS4_METAL_MPP_MOE_UP_FILTER", + up_start)) { + mask &= ~DS4_METAL_MOE_MPP_UP; + } + if ((mask & DS4_METAL_MOE_MPP_GATE) && + !ds4_gpu_mpp_moe_context_matches_filter("DS4_METAL_MPP_MOE_GATE_FILTER", + gate_start)) { + mask &= ~DS4_METAL_MOE_MPP_GATE; + } + return mask & requested_mask; + } + + return 0; +} + +static void ds4_gpu_warn_mpp_fallback(void) { + static int warned; + if (!warned) { + fprintf(stderr, "ds4: Metal Tensor prefill matmul unavailable; falling back to legacy kernel\n"); + warned = 1; + } +} + +static int ds4_gpu_device_name_contains(const char *needle) { + return g_metal_device_name[0] != '\0' && strstr(g_metal_device_name, needle) != NULL; +} + +static int ds4_gpu_compile_tensor_probe(void) { +#if defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 260000 + if 
(!g_device) return 0; + if (@available(macOS 26.0, *)) { + const char *src = + "#include \n" + "#include \n" + "#include \n" + "using namespace metal;\n" + "using namespace mpp::tensor_ops;\n" + "kernel void ds4_tensor_probe(\n" + " tensor> A [[buffer(0)]],\n" + " tensor> B [[buffer(1)]],\n" + " device float *C [[buffer(2)]],\n" + " uint2 tgid [[threadgroup_position_in_grid]]) {\n" + " auto tA = A.slice(0, (int)tgid.y);\n" + " auto tB = B.slice((int)tgid.x, 0);\n" + " matmul2d> mm;\n" + " auto cT = mm.get_destination_cooperative_tensor();\n" + " auto sA = tA.slice(0, 0);\n" + " auto sB = tB.slice(0, 0);\n" + " mm.run(sB, sA, cT);\n" + " auto tC = tensor, tensor_inline>(C, dextents(16, 16));\n" + " cT.store(tC);\n" + "}\n"; + + NSError *error = nil; + NSString *source = [NSString stringWithUTF8String:src]; + id probe_library = [g_device newLibraryWithSource:source options:[MTLCompileOptions new] error:&error]; + if (!probe_library) { + fprintf(stderr, "ds4: Metal 4 tensor API probe compile failed: %s\n", + error ? [[error localizedDescription] UTF8String] : "(unknown)"); + return 0; + } + id fn = [probe_library newFunctionWithName:@"ds4_tensor_probe"]; + if (!fn) { + fprintf(stderr, "ds4: Metal 4 tensor API probe function missing\n"); + return 0; + } + error = nil; + id pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!pipeline) { + fprintf(stderr, "ds4: Metal 4 tensor API probe pipeline failed: %s\n", + error ? [[error localizedDescription] UTF8String] : "(unknown)"); + return 0; + } + return 1; + } +#endif + return 0; +} + +static void ds4_gpu_detect_metal4_features(void) { + g_metal4_runtime_available = 0; + g_metal4_family_supported = 0; + g_metal4_queue_supported = 0; + g_metal4_m5_neural_accelerators_hint = 0; + g_metal4_tensor_api_enabled = 0; + g_metal4_tensor_api_compile_supported = 0; + g_metal_device_name[0] = '\0'; + + if (!g_device) return; + + const char *name = [[g_device name] UTF8String]; + if (name) { + snprintf(g_metal_device_name, sizeof(g_metal_device_name), "%s", name); + } + +#if defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 260000 + if (@available(macOS 26.0, *)) { + g_metal4_runtime_available = 1; + g_metal4_family_supported = [g_device supportsFamily:MTLGPUFamilyMetal4] ? 1 : 0; + g_metal4_queue_supported = [g_device respondsToSelector:@selector(newMTL4CommandQueue)] ? 1 : 0; + + /* + * Apple does not currently expose a separate "Neural Accelerator" bit + * through Metal. On public M5 systems the hardware signal is the device + * generation plus Metal 4 support, so keep this as a conservative hint + * for diagnostics and future opt-in MPP/tensor kernels. 
+ */ + if (g_metal4_family_supported && ds4_gpu_device_name_contains("M5")) { + g_metal4_m5_neural_accelerators_hint = 1; + } + + if (g_metal4_family_supported && getenv("DS4_METAL_TENSOR_DISABLE") == NULL) { + const int explicit_enable = getenv("DS4_METAL_TENSOR_ENABLE") != NULL; + const int default_enable = + ds4_gpu_device_name_contains("M5") || + ds4_gpu_device_name_contains("M6") || + ds4_gpu_device_name_contains("A19") || + ds4_gpu_device_name_contains("A20"); + + if (explicit_enable || default_enable) { + g_metal4_tensor_api_compile_supported = ds4_gpu_compile_tensor_probe(); + g_metal4_tensor_api_enabled = g_metal4_tensor_api_compile_supported; + if (!g_metal4_tensor_api_enabled) { + fprintf(stderr, "ds4: Metal 4 tensor API probe failed; using legacy Metal kernels\n"); + } + } else { + fprintf(stderr, "ds4: Metal 4 tensor API disabled for pre-M5/pre-A19 devices (set DS4_METAL_TENSOR_ENABLE=1 to experiment)\n"); + } + } + } +#endif +} + static int ds4_gpu_warm_model_views(void) { if (g_model_view_count == 0) return 1; @@ -1112,6 +2097,36 @@ void ds4_gpu_print_memory_report(const char *label) { "ds4: model residency requests %llu%s\n", (unsigned long long)g_model_residency_count, getenv("DS4_METAL_NO_RESIDENCY") != NULL ? " (disabled)" : ""); + fprintf(stderr, + "ds4: device %s, Metal 4 runtime %s, family %s, MTL4 queue %s, tensor API %s, M5 neural accelerators %s\n", + g_metal_device_name[0] ? g_metal_device_name : "(unknown)", + g_metal4_runtime_available ? "yes" : "no", + g_metal4_family_supported ? "yes" : "no", + g_metal4_queue_supported ? "yes" : "no", + g_metal4_tensor_api_enabled ? "enabled" : + (g_metal4_tensor_api_compile_supported ? "available" : "disabled"), + g_metal4_m5_neural_accelerators_hint ? "likely" : "not detected"); + const int mpp_q8 = ds4_gpu_mpp_q8_0_policy_enabled(); + const int mpp_f16 = ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_f16_default_target(), + "DS4_METAL_MPP_F16_ENABLE", + "DS4_METAL_MPP_F16_DISABLE"); + const int mpp_attn_out = ds4_gpu_mpp_route_enabled(0, + "DS4_METAL_MPP_ATTN_OUT_ENABLE", + "DS4_METAL_MPP_ATTN_OUT_DISABLE"); + const int mpp_moe = ds4_gpu_mpp_routed_moe_stage_mask(); + fprintf(stderr, + "ds4: Metal Tensor policy %s%s%s\n", + ds4_mpp_mode_name(g_mpp_mode), + g_quality_mode ? " (disabled by --quality)" : "", + !g_metal4_tensor_api_enabled ? " (tensor API unavailable)" : ""); + fprintf(stderr, + "ds4: Metal Tensor routes q8_0=%s f16_compressor=%s attn_out=%s moe_gate=%s moe_up=%s moe_down=%s\n", + mpp_q8 ? "on" : "off", + mpp_f16 ? "on" : "off", + mpp_attn_out ? "on" : "off", + (mpp_moe & DS4_METAL_MOE_MPP_GATE) ? "on" : "off", + (mpp_moe & DS4_METAL_MOE_MPP_UP) ? "on" : "off", + (mpp_moe & DS4_METAL_MOE_MPP_DOWN) ? "on" : "off"); fprintf(stderr, "ds4: scratch %.2f MiB (flash mask %.2f, pad %.2f, tmp %.2f, blk %.2f, ring %.2f, kv %.2f, compressor %.2f, router %.2f, indexer %.2f, moe %.2f, f16 %.2f, raw-store %.2f)\n", ds4_gpu_mib(scratch), @@ -1141,8 +2156,47 @@ void ds4_gpu_print_memory_report(const char *label) { ds4_gpu_mib((uint64_t)g_raw_store_round_bytes)); } +static void ds4_gpu_mpp_reset_reports(void) { + g_mpp_q8_reported = 0; + g_mpp_q8_partial_skip_reported = 0; + g_mpp_f16_reported = 0; + g_mpp_f16_pair_reported = 0; + g_mpp_attn_out_reported = 0; + g_mpp_moe_reported = 0; + g_mpp_moe_ranges_reported = 0; +} + void ds4_gpu_set_quality(bool quality) { - g_quality_mode = quality ? 1 : 0; + const int next = quality ? 
1 : 0; + if (g_quality_mode != next) { + ds4_gpu_mpp_reset_reports(); + ds4_gpu_mpp_compare_reset(); + } + g_quality_mode = next; +} + +void ds4_gpu_set_mpp_mode(ds4_mpp_mode mode) { + if (mode != DS4_MPP_AUTO && mode != DS4_MPP_ON && mode != DS4_MPP_OFF) { + mode = DS4_MPP_AUTO; + } + if (g_mpp_mode != mode) { + ds4_gpu_mpp_reset_reports(); + ds4_gpu_mpp_compare_reset(); + } + g_mpp_mode = mode; +} + +void ds4_gpu_set_mpp_compare_context(const char *module, uint32_t layer_index, uint32_t pos0) { + if (!module || !module[0]) { + g_mpp_compare_context[0] = '\0'; + return; + } + snprintf(g_mpp_compare_context, sizeof(g_mpp_compare_context), + "layer=%u pos=%u %s", layer_index, pos0, module); +} + +void ds4_gpu_clear_mpp_compare_context(void) { + g_mpp_compare_context[0] = '\0'; } static id ds4_gpu_wrap_model_range( @@ -1154,7 +2208,14 @@ void ds4_gpu_set_quality(bool quality) { static const char *ds4_gpu_source = "#include \n" +"#ifdef DS4_METAL_HAS_TENSOR\n" +"#include \n" +"#include \n" +"#endif\n" "using namespace metal;\n" +"#ifdef DS4_METAL_HAS_TENSOR\n" +"using namespace mpp::tensor_ops;\n" +"#endif\n" "\n" "#define MAX(x, y) ((x) > (y) ? (x) : (y))\n" "#define MIN(x, y) ((x) < (y) ? (x) : (y))\n" @@ -2191,6 +3252,17 @@ static int ds4_gpu_encode_attn_out_low_q8_direct( NSUInteger threadgroup_bytes, NSUInteger nsg); +static int ds4_gpu_encode_attn_out_low_q8_mpp( + id cb, + id pipeline, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0, + NSUInteger src0_off, + id src1, + NSUInteger src1_off, + id dst, + NSUInteger dst_off); + static ds4_gpu_mul_mm_id_map_args ds4_gpu_make_mul_mm_id_map_args( uint32_t src0_cols, uint32_t src0_experts, @@ -2251,6 +3323,17 @@ static int ds4_gpu_encode_mul_mm_id_mapped( NSUInteger src1_off, id dst, NSUInteger dst_off); +static int ds4_gpu_encode_mul_mm_id_mapped_tile( + id cb, + id mm_pipeline, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0, + NSUInteger src0_off, + id src1, + NSUInteger src1_off, + id dst, + NSUInteger dst_off, + uint32_t tile_n); typedef struct { int32_t ne11; @@ -2654,6 +3737,13 @@ static int ds4_gpu_encode_rope_tail_inplace( float clamp_value; } ds4_gpu_dsv4_moe_swiglu_weight_args; +typedef struct { + uint32_t width; + uint32_t tokens; + uint64_t src_token_stride; + uint64_t dst_token_stride; +} ds4_gpu_dsv4_moe_sum6_args; + /* Compile the single in-repo Metal source and create the pipelines that every * session uses. 
Shape-dependent kernels with function constants are built * lazily by the small ds4_gpu_get_* caches, so startup stays predictable @@ -2668,6 +3758,7 @@ int ds4_gpu_init(void) { return 0; } ds4_gpu_print_device_summary(); + ds4_gpu_detect_metal4_features(); g_queue = [g_device newCommandQueue]; if (!g_queue) { @@ -2698,6 +3789,65 @@ int ds4_gpu_init(void) { return 0; } MTLCompileOptions *options = [MTLCompileOptions new]; + NSMutableDictionary *macros = [NSMutableDictionary new]; + if (g_metal4_tensor_api_enabled) { + macros[@"DS4_METAL_HAS_TENSOR"] = @"1"; + fprintf(stderr, "ds4: Metal 4 tensor API enabled for Tensor kernels\n"); + } + + const int drift_hc_stable = ds4_gpu_env_bool("DS4_METAL_HC_STABLE") != 0; // default ON + const int drift_norm_unify = ds4_gpu_env_bool("DS4_METAL_NORM_RSQRT_DISABLE") != 0; // default ON + const int drift_kv_raw_f32 = ds4_gpu_env_bool("DS4_METAL_KV_RAW_F32") > 0; // default OFF + const int drift_rope_exp2_log2 = ds4_gpu_env_bool("DS4_METAL_ROPE_EXP2_LOG2") > 0; // default OFF + const int drift_math_safe = ds4_gpu_env_bool("DS4_METAL_MATH_SAFE") > 0; // default OFF + const int drift_tensor_matmul_off = g_metal4_tensor_api_enabled && + ds4_gpu_env_bool("DS4_METAL_TENSOR_MATMUL_DISABLE") > 0; + + if (drift_math_safe) { + // MTLCompileOptions.fastMathEnabled defaults to YES and Apple's + // headers explicitly say this "may violate the IEEE 754 standard". + // Different fast-math optimizations get applied across the + // matmul2d cooperative-tensor path and the legacy + // simdgroup_multiply_accumulate path on M5, amplifying the + // mismatch. MTLMathModeSafe pins the entire library to strict + // IEEE-754 semantics. Diagnostic-only: it also moves the + // -mt off output away from the fast-math reference, so this is + // useful to localize drift sources but not to ship as a default. + if (@available(macOS 15.0, *)) { + options.mathMode = MTLMathModeSafe; + fprintf(stderr, "ds4: Metal shader library math mode = safe (strict IEEE-754) by DS4_METAL_MATH_SAFE\n"); + } else { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + options.fastMathEnabled = NO; +#pragma clang diagnostic pop + fprintf(stderr, "ds4: Metal shader library fast-math disabled by DS4_METAL_MATH_SAFE (pre-macOS 15)\n"); + } + } + + if (drift_hc_stable) macros[@"DS4_METAL_HC_STABLE"] = @"1"; + if (drift_norm_unify) macros[@"DS4_METAL_NORM_RSQRT_DISABLE"] = @"1"; + if (drift_kv_raw_f32) macros[@"DS4_METAL_KV_RAW_F32"] = @"1"; + if (drift_rope_exp2_log2) macros[@"DS4_METAL_ROPE_EXP2_LOG2"] = @"1"; + if (drift_tensor_matmul_off) { + // Recompile without DS4_METAL_HAS_TENSOR so the cooperative-tensor + // matmul branches are excluded from this build, isolating the + // simdgroup_float8x8 path for an A/B vs the Tensor matmul on M5. + // Also flip g_metal4_tensor_api_enabled so the host dispatch + // skips _mpp kernel lookups that are no longer compiled. + [macros removeObjectForKey:@"DS4_METAL_HAS_TENSOR"]; + g_metal4_tensor_api_enabled = 0; + fprintf(stderr, "ds4: Metal 4 cooperative-tensor matmul disabled by DS4_METAL_TENSOR_MATMUL_DISABLE\n"); + } + fprintf(stderr, + "ds4: drift-patch flags hc_stable=%s norm_unify=%s kv_raw_f32=%s rope_exp2_log2=%s math_safe=%s tensor_matmul=%s\n", + drift_hc_stable ? "on" : "off", + drift_norm_unify ? "on" : "off", + drift_kv_raw_f32 ? "on" : "off", + drift_rope_exp2_log2 ? "on" : "off", + drift_math_safe ? "on" : "off", + (g_metal4_tensor_api_enabled && !drift_tensor_matmul_off) ? 
"on" : "off"); + options.preprocessorMacros = macros; id library = [g_device newLibraryWithSource:source options:options error:&error]; if (!library) { fprintf(stderr, "ds4: Metal shader compilation failed: %s\n", @@ -2926,6 +4076,23 @@ int ds4_gpu_init(void) { return 0; } + fn = [library newFunctionWithName:@"kernel_dsv4_moe_sum6_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_dsv4_moe_sum6_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + + g_moe_sum6_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_moe_sum6_pipeline) { + fprintf(stderr, "ds4: Metal kernel_dsv4_moe_sum6_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + MTLFunctionConstantValues *bin_constants = [[MTLFunctionConstantValues alloc] init]; int16_t bin_op = 0; int16_t bin_f = 1; @@ -3736,6 +4903,8 @@ int ds4_gpu_init(void) { ds4_gpu_get_pipeline("kernel_dsv4_indexed_mixed_attention_heads8"); g_dsv4_indexed_attention_heads8_rb4_pipeline = ds4_gpu_get_pipeline("kernel_dsv4_indexed_mixed_attention_heads8_rb4"); + g_dsv4_indexed_attention_heads8_rb16_pipeline = + ds4_gpu_get_pipeline("kernel_dsv4_indexed_mixed_attention_heads8_rb16"); g_dsv4_softplus_sqrt_pipeline = ds4_gpu_get_pipeline("kernel_dsv4_softplus_sqrt_f32_4"); g_dsv4_router_finalize_one_pipeline = @@ -3749,6 +4918,7 @@ int ds4_gpu_init(void) { !g_dsv4_sort_i32_rows_asc_pipeline || !g_dsv4_indexed_attention_heads8_pipeline || !g_dsv4_indexed_attention_heads8_rb4_pipeline || + !g_dsv4_indexed_attention_heads8_rb16_pipeline || !g_dsv4_softplus_sqrt_pipeline || !g_dsv4_router_finalize_one_pipeline || !g_dsv4_router_weights_one_pipeline || @@ -3939,6 +5109,7 @@ int ds4_gpu_synchronize(void) { if (g_batch_cb) return ds4_gpu_end_commands(); if ([g_pending_cbs count] != 0) { int ok = ds4_gpu_wait_pending_command_buffers("synchronize"); + if (ok) ds4_gpu_mpp_compare_drain("synchronize"); [g_transient_buffers removeAllObjects]; return ok; } @@ -3971,6 +5142,7 @@ void ds4_gpu_cleanup(void) { g_cpy_f16_f32_pipeline = nil; g_swiglu_pipeline = nil; g_add_pipeline = nil; + g_moe_sum6_pipeline = nil; g_mul_pipeline = nil; g_bin_mul_scalar_pipeline = nil; g_bin_div_row_pipeline = nil; @@ -3999,9 +5171,6 @@ void ds4_gpu_cleanup(void) { g_moe_mul_mv_id_q4_k_pair_pipeline = nil; g_moe_mul_mv_id_q4_k_pair_swiglu_pipeline = nil; g_moe_mul_mv_id_q4_k_sum6_pipeline = nil; - g_moe_mul_mm_id_iq2_xxs_pipeline = nil; - g_moe_mul_mm_id_q2_k_pipeline = nil; - g_moe_mul_mm_id_q4_k_pipeline = nil; g_rope_tail_batch_pipeline = nil; g_dsv4_fp8_kv_quantize_pipeline = nil; g_dsv4_kv_fp8_store_pipeline = nil; @@ -4020,6 +5189,7 @@ void ds4_gpu_cleanup(void) { g_dsv4_sort_i32_rows_asc_pipeline = nil; g_dsv4_indexed_attention_heads8_pipeline = nil; g_dsv4_indexed_attention_heads8_rb4_pipeline = nil; + g_dsv4_indexed_attention_heads8_rb16_pipeline = nil; g_dsv4_softplus_sqrt_pipeline = nil; g_dsv4_router_finalize_one_pipeline = nil; g_dsv4_router_weights_one_pipeline = nil; @@ -4095,6 +5265,8 @@ void ds4_gpu_cleanup(void) { g_queue = nil; g_device = nil; g_initialized = 0; + ds4_gpu_mpp_reset_reports(); + ds4_gpu_mpp_compare_reset(); } } @@ -4916,7 +6088,7 @@ int ds4_gpu_dsv4_topk_mask_tensor( return 1; } -int ds4_gpu_matmul_q8_0_tensor( +static int ds4_gpu_matmul_q8_0_legacy_tensor( ds4_gpu_tensor *out, const void *model_map, uint64_t model_size, @@ -5015,18 +6187,178 @@ int ds4_gpu_matmul_q8_0_tensor( threadsPerThreadgroup:MTLSizeMake(32, (NSUInteger)nsg, 
1)]; ds4_gpu_end_compute_encoder(cb, enc); - if (!ds4_gpu_finish_command_buffer(cb, owned, "Q8_0 tensor mul_mv_ext")) { - return 0; - } + if (!ds4_gpu_finish_command_buffer(cb, owned, "Q8_0 tensor mul_mv_ext")) { + return 0; + } + return 1; + } + + const bool bc_inp = (in_dim % 32u) != 0; + const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0; + id pipeline = + ds4_gpu_get_mul_mm_pipeline("kernel_mul_mm_q8_0_f32", bc_inp, bc_out); + if (!pipeline) return 0; + + ds4_gpu_mul_mm_args args = ds4_gpu_make_mm_args(in_dim, out_dim, n_tok, row_bytes); + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:wbuf offset:(NSUInteger)inner_offset atIndex:1]; + [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:2]; + [enc setBuffer:outbuf offset:ds4_gpu_tensor_offset(out) atIndex:3]; + [enc setThreadgroupMemoryLength:(bc_out ? 8192u : 6144u) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u, + ((NSUInteger)out_dim + 63u) / 64u, + 1) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "Q8_0 tensor matmul")) { + return 0; + } + } + + return 1; +} + +static void ds4_gpu_mpp_compare_q8_0_matmul( + ds4_gpu_tensor *out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok) { + if (!ds4_gpu_mpp_compare_route_matches("q8")) return; + const uint64_t out_bytes = n_tok * out_dim * sizeof(float); + ds4_gpu_tensor *ref = ds4_gpu_tensor_alloc(out_bytes); + ds4_gpu_tensor *cand = ds4_gpu_mpp_compare_snapshot_buffer(ds4_gpu_tensor_buffer(out), + ds4_gpu_tensor_offset(out), + out_bytes); + if (!ref || !cand) { + ds4_gpu_tensor_free(ref); + ds4_gpu_tensor_free(cand); + return; + } + + if (ds4_gpu_matmul_q8_0_legacy_tensor(ref, model_map, model_size, + weight_offset, in_dim, out_dim, + x, n_tok)) { + char fallback[128]; + snprintf(fallback, sizeof(fallback), + "q8 weight_off=%llu in=%llu out=%llu tok=%llu", + (unsigned long long)weight_offset, + (unsigned long long)in_dim, + (unsigned long long)out_dim, + (unsigned long long)n_tok); + ds4_gpu_mpp_compare_register("q8", + fallback, + ref, + cand, + n_tok * out_dim, + n_tok, + out_dim, + in_dim); + if (!g_batch_cb) ds4_gpu_mpp_compare_drain("q8 compare"); + } + ds4_gpu_tensor_free(cand); + ds4_gpu_tensor_free(ref); +} + +int ds4_gpu_matmul_q8_0_tensor( + ds4_gpu_tensor *out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if ((in_dim & 31u) != 0 || + in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok > UINT32_MAX) { + return 0; + } + + if (ds4_gpu_can_use_mpp_q8_0_matmul(n_tok)) { + if (ds4_gpu_matmul_q8_0_mpp_tensor(out, model_map, model_size, weight_offset, + in_dim, out_dim, x, n_tok)) { + ds4_gpu_mpp_compare_q8_0_matmul(out, model_map, model_size, + weight_offset, in_dim, out_dim, + x, n_tok); return 1; } + ds4_gpu_warn_mpp_fallback(); + } + + return ds4_gpu_matmul_q8_0_legacy_tensor(out, model_map, model_size, + weight_offset, in_dim, out_dim, + x, n_tok); +} + +int ds4_gpu_matmul_q8_0_mpp_tensor( + ds4_gpu_tensor *out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t 
n_tok) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!g_metal4_tensor_api_enabled) return 0; + if ((in_dim & 31u) != 0 || n_tok <= 8 || + in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok > UINT32_MAX) { + return 0; + } + + @autoreleasepool { + id xbuf = ds4_gpu_tensor_buffer(x); + id outbuf = ds4_gpu_tensor_buffer(out); + const uint64_t x_bytes = n_tok * in_dim * sizeof(float); + const uint64_t out_bytes = n_tok * out_dim * sizeof(float); + if (!xbuf || !outbuf || + ds4_gpu_tensor_bytes(x) < x_bytes || + ds4_gpu_tensor_bytes(out) < out_bytes) { + fprintf(stderr, "ds4: Metal Tensor Q8_0 matmul received undersized activation buffers\n"); + return 0; + } + const uint64_t blocks = in_dim / 32; + const uint64_t row_bytes = blocks * 34; + const uint64_t weight_bytes = out_dim * row_bytes; + if (weight_offset > model_size || weight_bytes > model_size - weight_offset) { + fprintf(stderr, "ds4: Metal Tensor Q8_0 matmul range is outside the mapped model\n"); + return 0; + } + + uint64_t inner_offset = 0; + id wbuf = ds4_gpu_wrap_model_range(model_map, model_size, weight_offset, weight_bytes, &inner_offset); + if (!wbuf) return 0; + + const uint32_t tile_n = ds4_gpu_mpp_q8_0_tile_n_for_tokens(n_tok); + const bool direct_rhs = + (tile_n == 32u || tile_n == 64u) && + ds4_gpu_mpp_q8_0_direct_rhs(); const bool bc_inp = (in_dim % 32u) != 0; - const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0; + const bool bc_out = (out_dim % 64u) != 0 || (n_tok % tile_n) != 0; + const char *pipeline_name = direct_rhs ? + (tile_n == 64u ? + "kernel_mul_mm_q8_0_f32_mpp_direct_rhs_n64" : + "kernel_mul_mm_q8_0_f32_mpp_direct_rhs") : + (tile_n == 64u ? + "kernel_mul_mm_q8_0_f32_mpp_n64" : + "kernel_mul_mm_q8_0_f32_mpp"); id pipeline = - ds4_gpu_get_mul_mm_pipeline("kernel_mul_mm_q8_0_f32", bc_inp, bc_out); + ds4_gpu_get_mul_mm_pipeline(pipeline_name, bc_inp, bc_out); if (!pipeline) return 0; + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + ds4_gpu_mul_mm_args args = ds4_gpu_make_mm_args(in_dim, out_dim, n_tok, row_bytes); id enc = ds4_gpu_compute_encoder(cb); @@ -5035,16 +6367,14 @@ int ds4_gpu_matmul_q8_0_tensor( [enc setBuffer:wbuf offset:(NSUInteger)inner_offset atIndex:1]; [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:2]; [enc setBuffer:outbuf offset:ds4_gpu_tensor_offset(out) atIndex:3]; - [enc setThreadgroupMemoryLength:(bc_out ? 8192u : 6144u) atIndex:0]; - [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u, + [enc setThreadgroupMemoryLength:(direct_rhs ? 4096u : (tile_n == 64 ? 8192u : 6144u)) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + (NSUInteger)tile_n - 1u) / (NSUInteger)tile_n, ((NSUInteger)out_dim + 63u) / 64u, 1) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; ds4_gpu_end_compute_encoder(cb, enc); - if (!ds4_gpu_finish_command_buffer(cb, owned, "Q8_0 tensor matmul")) { - return 0; - } + if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal Tensor Q8_0 matmul")) return 0; } return 1; @@ -5241,6 +6571,41 @@ int ds4_gpu_matmul_f16_tensor( const bool bc_inp = (in_dim % 32u) != 0; const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0; + const bool mpp_f16_shape = + in_dim == 4096u && !bc_inp && + (out_dim == 128u || + (ds4_gpu_mpp_f16_wide_matmul() && (out_dim % 64u) == 0)); + /* Keep wider compressor MPP opt-in until full-model drift and speed are measured. 
*/ + if (mpp_f16_shape && + ds4_gpu_use_mpp_f16_compressor_matmul()) { + const bool direct_rhs = ds4_gpu_mpp_f16_direct_rhs(); + id pipeline = + ds4_gpu_get_mul_mm_pipeline(direct_rhs ? + "kernel_mul_mm_f16_f32_mpp_direct_rhs" : + "kernel_mul_mm_f16_f32_mpp", + false, + bc_out); + if (pipeline) { + ds4_gpu_mul_mm_args args = ds4_gpu_make_mm_args(in_dim, out_dim, n_tok, row_bytes); + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:wbuf offset:(NSUInteger)inner_offset atIndex:1]; + [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:2]; + [enc setBuffer:outbuf offset:ds4_gpu_tensor_offset(out) atIndex:3]; + [enc setThreadgroupMemoryLength:(direct_rhs ? 4096u : 6144u) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u, + ((NSUInteger)out_dim + 63u) / 64u, + 1) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal Tensor F16 compressor matmul")) return 0; + return 1; + } + } + id pipeline = ds4_gpu_get_mul_mm_pipeline("kernel_mul_mm_f16_f32", bc_inp, bc_out); if (!pipeline) return 0; @@ -5278,12 +6643,93 @@ int ds4_gpu_matmul_f16_pair_tensor( const ds4_gpu_tensor *x, uint64_t n_tok) { if (!g_initialized && !ds4_gpu_init()) return 0; - if (in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok != 1 || (in_dim & 3u) != 0) return 0; + if (in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok == 0 || (in_dim & 3u) != 0) return 0; @autoreleasepool { id xbuf = ds4_gpu_tensor_buffer(x); id outabuf = ds4_gpu_tensor_buffer(out_a); id outbbuf = ds4_gpu_tensor_buffer(out_b); + if (n_tok != 1) { + const bool use_wide_mpp_pair = ds4_gpu_mpp_f16_wide_matmul(); + const bool pair_shape = + in_dim == 4096u && (out_dim % 64u) == 0; + if (n_tok <= 8 || + !pair_shape || + !ds4_gpu_mpp_f16_pair_matmul() || + !ds4_gpu_use_mpp_f16_compressor_matmul()) { + return 0; + } + + const uint64_t x_bytes = n_tok * in_dim * sizeof(float); + const uint64_t out_bytes = n_tok * out_dim * sizeof(float); + if (!xbuf || !outabuf || !outbbuf || + ds4_gpu_tensor_bytes(x) < x_bytes || + ds4_gpu_tensor_bytes(out_a) < out_bytes || + ds4_gpu_tensor_bytes(out_b) < out_bytes) { + fprintf(stderr, "ds4: Metal F16 paired Tensor matmul received undersized activation buffers\n"); + return 0; + } + + const uint64_t row_bytes = in_dim * sizeof(uint16_t); + const uint64_t weight_bytes = row_bytes * out_dim; + if (weight_a_offset > model_size || weight_bytes > model_size - weight_a_offset || + weight_b_offset > model_size || weight_bytes > model_size - weight_b_offset) { + fprintf(stderr, "ds4: Metal F16 paired Tensor matmul range is outside the mapped model\n"); + return 0; + } + + uint64_t inner_a = 0; + uint64_t inner_b = 0; + id wabuf = ds4_gpu_wrap_model_range(model_map, model_size, + weight_a_offset, weight_bytes, + &inner_a); + id wbbuf = ds4_gpu_wrap_model_range(model_map, model_size, + weight_b_offset, weight_bytes, + &inner_b); + if (!wabuf || !wbbuf) return 0; + + const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0; + id pipeline = + ds4_gpu_get_mul_mm_pipeline(use_wide_mpp_pair ? + "kernel_mul_mm_f16_f32_pair_mpp" : + "kernel_mul_mm_f16_f32_pair", + false, + bc_out); + if (!pipeline) return 0; + if (!g_mpp_f16_pair_reported) { + fprintf(stderr, "ds4: Metal paired F16 compressor matmul enabled%s\n", + use_wide_mpp_pair ? 
" with Tensor wide route" : ""); + g_mpp_f16_pair_reported = 1; + } + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + ds4_gpu_mul_mm_args args = ds4_gpu_make_mm_args(in_dim, out_dim, n_tok, row_bytes); + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:wabuf offset:(NSUInteger)inner_a atIndex:1]; + [enc setBuffer:wbbuf offset:(NSUInteger)inner_b atIndex:2]; + [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:3]; + [enc setBuffer:outabuf offset:ds4_gpu_tensor_offset(out_a) atIndex:4]; + [enc setBuffer:outbbuf offset:ds4_gpu_tensor_offset(out_b) atIndex:5]; + const NSUInteger smem = use_wide_mpp_pair ? + (NSUInteger)((64u * 32u * 2u + 32u * 32u) * sizeof(uint16_t)) : + (NSUInteger)12288u; + [enc setThreadgroupMemoryLength:smem atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u, + ((NSUInteger)out_dim + 63u) / 64u, + 1) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal F16 paired matmul")) return 0; + return 1; + } + const uint64_t x_bytes = in_dim * sizeof(float); const uint64_t out_bytes = out_dim * sizeof(float); if (!xbuf || !outabuf || !outbbuf || @@ -7949,6 +9395,73 @@ static int ds4_gpu_encode_fill_f32_rows( return 1; } +static void ds4_gpu_mpp_compare_attn_out_low( + id cb, + const ds4_gpu_mul_mm_id_args *mm_args, + id out_a_buf, + NSUInteger out_a_inner, + const ds4_gpu_tensor *heads, + ds4_gpu_tensor *low, + uint32_t group_dim, + uint32_t rank, + uint32_t n_groups, + uint32_t n_tokens) { + if (!ds4_gpu_mpp_compare_route_matches("attn_out")) return; + const NSUInteger ids_bytes = (NSUInteger)n_tokens * (NSUInteger)n_groups * sizeof(int32_t); + id ids_buffer = ds4_gpu_new_transient_buffer(ids_bytes, "attention output compare group ids"); + ds4_gpu_tensor *ref = ds4_gpu_tensor_alloc((uint64_t)n_tokens * n_groups * rank * sizeof(float)); + ds4_gpu_tensor *cand = ds4_gpu_mpp_compare_snapshot_buffer(ds4_gpu_tensor_buffer(low), + ds4_gpu_tensor_offset(low), + (uint64_t)n_tokens * n_groups * rank * sizeof(float)); + if (!ids_buffer || !ref || !cand) { + ds4_gpu_tensor_free(ref); + ds4_gpu_tensor_free(cand); + return; + } + int32_t *ids = (int32_t *)[ids_buffer contents]; + for (uint32_t t = 0; t < n_tokens; t++) { + for (uint32_t group = 0; group < n_groups; group++) { + ids[(uint64_t)t * n_groups + group] = (int32_t)group; + } + } + + ds4_gpu_mul_mm_id_map_args map_args = + ds4_gpu_make_mul_mm_id_map_args(group_dim, + n_groups, + n_groups, + n_groups, + n_tokens); + id map_pipeline = + ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_groups)); + id legacy_pipeline = + ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q8_0_f32", false, false); + if (map_pipeline && legacy_pipeline && + ds4_gpu_encode_mul_mm_id(cb, + map_pipeline, + legacy_pipeline, + &map_args, + mm_args, + out_a_buf, + out_a_inner, + ds4_gpu_tensor_buffer(heads), + ds4_gpu_tensor_offset(heads), + ds4_gpu_tensor_buffer(ref), + ds4_gpu_tensor_offset(ref), + ids_buffer, + 0)) { + ds4_gpu_mpp_compare_register("attn_out", + "attn_out_low", + ref, + cand, + (uint64_t)n_tokens * n_groups * rank, + n_tokens, + (uint64_t)n_groups * rank, + group_dim); + } + ds4_gpu_tensor_free(cand); + ds4_gpu_tensor_free(ref); +} + int ds4_gpu_attention_output_q8_batch_tensor( ds4_gpu_tensor *out, ds4_gpu_tensor *low, @@ -8001,9 +9514,14 @@ int 
ds4_gpu_attention_output_q8_batch_tensor( const bool use_direct_low = n_tokens < 32u && getenv("DS4_METAL_DISABLE_ATTN_OUT_LOW_DIRECT") == NULL; + /* The tensor tile store is only used on full token tiles; partial tails use the legacy path. */ + const bool use_mpp_low = + n_tokens >= 32u && + (n_tokens % 32u) == 0 && + ds4_gpu_use_mpp_attn_out_low_matmul(); const NSUInteger ids_bytes = (NSUInteger)n_tokens * (NSUInteger)n_groups * sizeof(int32_t); id group_ids_buffer = nil; - if (!use_direct_low) { + if (!use_direct_low && !use_mpp_low) { if (getenv("DS4_METAL_DISABLE_ATTN_OUT_IDS_CACHE") != NULL) { group_ids_buffer = ds4_gpu_new_transient_buffer(ids_bytes, "attention output group ids"); @@ -8073,7 +9591,98 @@ int ds4_gpu_attention_output_q8_batch_tensor( * tokens. This preserves the single-token generation path while * keeping prefill accumulation stable. */ - if (n_tokens >= 32u && ds4_gpu_mul_mm_id_map0_name(n_groups) != NULL) { + if (use_mpp_low) { + ds4_gpu_mul_mm_id_args mm_args = + ds4_gpu_make_mul_mm_id_args((uint32_t)group_dim, + (uint32_t)rank, + n_groups, + row_a_bytes, + (uint64_t)rank * row_a_bytes, + n_groups, + n_groups, + n_tokens); + const uint32_t attn_out_tile_n = ds4_gpu_mpp_attn_out_tile_n(); + const bool attn_out_direct_rhs = + (attn_out_tile_n == 32u || attn_out_tile_n == 64u) && + ds4_gpu_mpp_attn_out_direct_rhs(); + const char *attn_out_pipeline_name = attn_out_direct_rhs ? + (attn_out_tile_n == 64u ? + "kernel_attn_out_low_q8_0_mpp_direct_rhs_n64" : + "kernel_attn_out_low_q8_0_mpp_direct_rhs") : + (attn_out_tile_n == 64u ? + "kernel_attn_out_low_q8_0_mpp_n64" : + "kernel_attn_out_low_q8_0_mpp"); + id mm_pipeline = + ds4_gpu_get_mul_mm_id_pipeline(attn_out_pipeline_name, + false, + false); + ok = ds4_gpu_encode_attn_out_low_q8_mpp(cb, + mm_pipeline, + &mm_args, + out_a_buf, + (NSUInteger)out_a_inner, + ds4_gpu_tensor_buffer(heads), + ds4_gpu_tensor_offset(heads), + ds4_gpu_tensor_buffer(low), + ds4_gpu_tensor_offset(low)) != 0; + if (ok) { + ds4_gpu_mpp_compare_attn_out_low(cb, + &mm_args, + out_a_buf, + (NSUInteger)out_a_inner, + heads, + low, + (uint32_t)group_dim, + (uint32_t)rank, + n_groups, + n_tokens); + } + if (!ok) { + ds4_gpu_warn_mpp_fallback(); + if (ds4_gpu_mul_mm_id_map0_name(n_groups) != NULL) { + if (getenv("DS4_METAL_DISABLE_ATTN_OUT_IDS_CACHE") != NULL) { + group_ids_buffer = + ds4_gpu_new_transient_buffer(ids_bytes, "attention output group ids"); + } else if (ds4_gpu_ensure_scratch_buffer(&g_attn_out_group_ids_buffer, + &g_attn_out_group_ids_bytes, + ids_bytes, + "ds4_attention_output_group_ids")) { + group_ids_buffer = g_attn_out_group_ids_buffer; + } + if (group_ids_buffer) { + int32_t *ids = (int32_t *)[group_ids_buffer contents]; + for (uint32_t t = 0; t < n_tokens; t++) { + for (uint32_t group = 0; group < n_groups; group++) { + ids[(uint64_t)t * n_groups + group] = (int32_t)group; + } + } + ds4_gpu_mul_mm_id_map_args map_args = + ds4_gpu_make_mul_mm_id_map_args((uint32_t)group_dim, + n_groups, + n_groups, + n_groups, + n_tokens); + id map_pipeline = + ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_groups)); + id fallback_pipeline = + ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q8_0_f32", false, false); + ok = ds4_gpu_encode_mul_mm_id(cb, + map_pipeline, + fallback_pipeline, + &map_args, + &mm_args, + out_a_buf, + (NSUInteger)out_a_inner, + ds4_gpu_tensor_buffer(heads), + ds4_gpu_tensor_offset(heads), + ds4_gpu_tensor_buffer(low), + ds4_gpu_tensor_offset(low), + group_ids_buffer, + 0) != 0; + } + } + } + } else if (n_tokens >= 32u 
&& ds4_gpu_mul_mm_id_map0_name(n_groups) != NULL) { ds4_gpu_mul_mm_id_map_args map_args = ds4_gpu_make_mul_mm_id_map_args((uint32_t)group_dim, n_groups, @@ -8092,7 +9701,7 @@ int ds4_gpu_attention_output_q8_batch_tensor( id map_pipeline = ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_groups)); id mm_pipeline = - ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q8_0_f32", false); + ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q8_0_f32", false, false); ok = ds4_gpu_encode_mul_mm_id(cb, map_pipeline, mm_pipeline, @@ -10815,10 +12424,14 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( ds4_gpu_hot_pipeline(g_dsv4_sort_i32_rows_asc_pipeline, "kernel_dsv4_sort_i32_rows_asc"); const bool decode_one_token = n_tokens == 1u; + const bool decode_rb4 = decode_one_token && ds4_gpu_use_indexed_attention_rb4(); id attn_pipeline = - decode_one_token ? + decode_rb4 ? ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_rb4_pipeline, "kernel_dsv4_indexed_mixed_attention_heads8_rb4") : + decode_one_token ? + ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_rb16_pipeline, + "kernel_dsv4_indexed_mixed_attention_heads8_rb16") : ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_pipeline, "kernel_dsv4_indexed_mixed_attention_heads8"); if (!sort_pipeline || !attn_pipeline) return 0; @@ -10899,7 +12512,8 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( atIndex:4]; [enc setBuffer:sinks_buf offset:(NSUInteger)sinks_inner atIndex:5]; [enc setBuffer:headsbuf offset:ds4_gpu_tensor_offset(heads) atIndex:6]; - [enc setThreadgroupMemoryLength:(decode_one_token ? 4u : 1u) * 128u * 4u * sizeof(float) + [enc setThreadgroupMemoryLength:(decode_one_token ? (decode_rb4 ? 4u : 16u) : 1u) * + 128u * 4u * sizeof(float) atIndex:0]; [enc dispatchThreadgroups:MTLSizeMake((NSUInteger)n_tokens, ((NSUInteger)n_head + 7u) / 8u, 1) threadsPerThreadgroup:MTLSizeMake(32, 8, 1)]; @@ -11590,44 +13204,140 @@ static NSUInteger ds4_gpu_routed_mv_smem(uint32_t type) { } } -static id ds4_gpu_routed_mm_pipeline(uint32_t type) { +static id ds4_gpu_routed_mm_pipeline(uint32_t type, bool use_mpp) { + const bool tile_n64 = use_mpp && ds4_gpu_mpp_moe_tile_n() == 64; + const bool fast_layout = use_mpp && !tile_n64 && ds4_gpu_mpp_moe_fast_layout(); switch (type) { case DS4_METAL_TENSOR_IQ2_XXS: - if (!g_moe_mul_mm_id_iq2_xxs_pipeline) { - g_moe_mul_mm_id_iq2_xxs_pipeline = - ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32", false); - } - return g_moe_mul_mm_id_iq2_xxs_pipeline; + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_iq2_xxs_f32_fast_mpp" : + tile_n64 ? + "kernel_mul_mm_id_iq2_xxs_f32_n64" : + "kernel_mul_mm_id_iq2_xxs_f32", + false, + use_mpp); case DS4_METAL_TENSOR_Q2_K: - if (!g_moe_mul_mm_id_q2_k_pipeline) { - g_moe_mul_mm_id_q2_k_pipeline = - ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32", false); - } - return g_moe_mul_mm_id_q2_k_pipeline; + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_q2_K_f32_fast_mpp" : + tile_n64 ? + "kernel_mul_mm_id_q2_K_f32_n64" : + "kernel_mul_mm_id_q2_K_f32", + false, + use_mpp); case DS4_METAL_TENSOR_Q4_K: - if (!g_moe_mul_mm_id_q4_k_pipeline) { - g_moe_mul_mm_id_q4_k_pipeline = - ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32", false); - } - return g_moe_mul_mm_id_q4_k_pipeline; + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_q4_K_f32_fast_mpp" : + tile_n64 ? 
+ "kernel_mul_mm_id_q4_K_f32_n64" : + "kernel_mul_mm_id_q4_K_f32", + false, + use_mpp); + default: + return nil; + } +} + +static id ds4_gpu_routed_mm_pair_mpp_pipeline(uint32_t type) { + switch (type) { + case DS4_METAL_TENSOR_IQ2_XXS: + return ds4_gpu_get_pipeline("kernel_mul_mm_id_iq2_xxs_f32_pair_mpp"); + case DS4_METAL_TENSOR_Q2_K: + return ds4_gpu_get_pipeline("kernel_mul_mm_id_q2_K_f32_pair_mpp"); + case DS4_METAL_TENSOR_Q4_K: + return ds4_gpu_get_pipeline("kernel_mul_mm_id_q4_K_f32_pair_mpp"); default: return nil; } } -static id ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type) { +static id ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type, bool use_mpp) { + const bool tile_n64 = use_mpp && ds4_gpu_mpp_moe_tile_n() == 64; + const bool fast_layout = use_mpp && !tile_n64 && ds4_gpu_mpp_moe_fast_layout(); switch (type) { case DS4_METAL_TENSOR_IQ2_XXS: - return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16", false); + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_iq2_xxs_f16_fast_mpp" : + tile_n64 ? + "kernel_mul_mm_id_iq2_xxs_f16_n64" : + "kernel_mul_mm_id_iq2_xxs_f16", + false, + use_mpp); case DS4_METAL_TENSOR_Q2_K: - return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16", false); + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_q2_K_f16_fast_mpp" : + tile_n64 ? + "kernel_mul_mm_id_q2_K_f16_n64" : + "kernel_mul_mm_id_q2_K_f16", + false, + use_mpp); case DS4_METAL_TENSOR_Q4_K: - return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16", false); + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_q4_K_f16_fast_mpp" : + tile_n64 ? + "kernel_mul_mm_id_q4_K_f16_n64" : + "kernel_mul_mm_id_q4_K_f16", + false, + use_mpp); default: return nil; } } +static void ds4_gpu_mpp_compare_moe_mm( + const char *route, + const char *stage, + uint32_t type, + bool f16_rhs, + id cb, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0, + NSUInteger src0_off, + id src1, + NSUInteger src1_off, + id cand, + NSUInteger cand_off, + uint64_t elements, + uint64_t dim0, + uint64_t dim1, + uint64_t dim2) { + if (!ds4_gpu_mpp_compare_route_matches(route)) return; + if (elements == 0) return; + ds4_gpu_tensor *ref = ds4_gpu_tensor_alloc(elements * sizeof(float)); + ds4_gpu_tensor *cand_snapshot = ds4_gpu_mpp_compare_snapshot_buffer(cand, + cand_off, + elements * sizeof(float)); + if (!ref || !cand_snapshot) { + ds4_gpu_tensor_free(ref); + ds4_gpu_tensor_free(cand_snapshot); + return; + } + + id legacy_pipeline = f16_rhs ? 
+ ds4_gpu_routed_mm_f16_rhs_pipeline(type, false) : + ds4_gpu_routed_mm_pipeline(type, false); + if (legacy_pipeline && + ds4_gpu_encode_mul_mm_id_mapped(cb, + legacy_pipeline, + mm_args, + src0, + src0_off, + src1, + src1_off, + ds4_gpu_tensor_buffer(ref), + ds4_gpu_tensor_offset(ref))) { + ds4_gpu_mpp_compare_register(route, + stage, + ref, + cand_snapshot, + elements, + dim0, + dim1, + dim2); + } + ds4_gpu_tensor_free(cand_snapshot); + ds4_gpu_tensor_free(ref); +} + static int ds4_gpu_encode_mul_mv_id( id cb, id pipeline, @@ -11919,7 +13629,7 @@ static int ds4_gpu_encode_mul_mm_id_map( return 1; } -static int ds4_gpu_encode_mul_mm_id_mapped( +static int ds4_gpu_encode_mul_mm_id_mapped_tile( id cb, id mm_pipeline, const ds4_gpu_mul_mm_id_args *mm_args, @@ -11928,13 +13638,15 @@ static int ds4_gpu_encode_mul_mm_id_mapped( id src1, NSUInteger src1_off, id dst, - NSUInteger dst_off) { + NSUInteger dst_off, + uint32_t tile_n) { if (!cb || !mm_pipeline || !mm_args || !src0 || !src1 || !dst || !g_moe_id_map_buffer || mm_args->ne00 <= 0 || mm_args->ne0 <= 0 || mm_args->ne20 <= 0 || mm_args->ne21 <= 0 || mm_args->ne02 <= 0) { return 0; } + if (tile_n != 64u) tile_n = 32u; const NSUInteger tpe_bytes = (NSUInteger)mm_args->ne02 * sizeof(int32_t); const NSUInteger hids_bytes = (NSUInteger)mm_args->ne02 * (NSUInteger)mm_args->ne21 * sizeof(int32_t); @@ -11951,6 +13663,53 @@ static int ds4_gpu_encode_mul_mm_id_mapped( [enc setBuffer:g_moe_id_map_buffer offset:0 atIndex:3]; [enc setBuffer:g_moe_id_map_buffer offset:tpe_bytes atIndex:4]; [enc setBuffer:dst offset:dst_off atIndex:5]; + [enc setThreadgroupMemoryLength:(tile_n == 64u ? 16384u : 8192u) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + (NSUInteger)tile_n - 1u) / (NSUInteger)tile_n, + ((NSUInteger)mm_args->ne0 + 63u) / 64u, + (NSUInteger)mm_args->ne02) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + return 1; +} + +static int ds4_gpu_encode_mul_mm_id_pair_mpp( + id cb, + id pipeline, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0_gate, + NSUInteger src0_gate_off, + id src0_up, + NSUInteger src0_up_off, + id src1, + NSUInteger src1_off, + id dst_gate, + NSUInteger dst_gate_off, + id dst_up, + NSUInteger dst_up_off) { + if (!cb || !pipeline || !mm_args || !src0_gate || !src0_up || !src1 || + !dst_gate || !dst_up || !g_moe_id_map_buffer || + mm_args->ne00 <= 0 || mm_args->ne0 <= 0 || + mm_args->ne20 <= 0 || mm_args->ne21 <= 0 || mm_args->ne02 <= 0) { + return 0; + } + + const NSUInteger tpe_bytes = (NSUInteger)mm_args->ne02 * sizeof(int32_t); + const NSUInteger hids_bytes = (NSUInteger)mm_args->ne02 * (NSUInteger)mm_args->ne21 * sizeof(int32_t); + if (tpe_bytes > NSUIntegerMax - hids_bytes || + g_moe_id_map_bytes < tpe_bytes + hids_bytes) { + return 0; + } + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:mm_args length:sizeof(*mm_args) atIndex:0]; + [enc setBuffer:src0_gate offset:src0_gate_off atIndex:1]; + [enc setBuffer:src0_up offset:src0_up_off atIndex:2]; + [enc setBuffer:src1 offset:src1_off atIndex:3]; + [enc setBuffer:g_moe_id_map_buffer offset:0 atIndex:4]; + [enc setBuffer:g_moe_id_map_buffer offset:tpe_bytes atIndex:5]; + [enc setBuffer:dst_gate offset:dst_gate_off atIndex:6]; + [enc setBuffer:dst_up offset:dst_up_off atIndex:7]; [enc setThreadgroupMemoryLength:8192u atIndex:0]; [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + 31u) / 32u, ((NSUInteger)mm_args->ne0 + 63u) / 64u, @@ 
-11960,6 +13719,64 @@ static int ds4_gpu_encode_mul_mm_id_mapped( return 1; } +static int ds4_gpu_encode_mul_mm_id_mapped( + id cb, + id mm_pipeline, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0, + NSUInteger src0_off, + id src1, + NSUInteger src1_off, + id dst, + NSUInteger dst_off) { + return ds4_gpu_encode_mul_mm_id_mapped_tile(cb, + mm_pipeline, + mm_args, + src0, + src0_off, + src1, + src1_off, + dst, + dst_off, + 32u); +} + +static int ds4_gpu_encode_attn_out_low_q8_mpp( + id cb, + id pipeline, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0, + NSUInteger src0_off, + id src1, + NSUInteger src1_off, + id dst, + NSUInteger dst_off) { + if (!cb || !pipeline || !mm_args || !src0 || !src1 || !dst || + mm_args->ne00 <= 0 || mm_args->ne0 <= 0 || + mm_args->ne02 <= 0 || mm_args->ne1 <= 0 || mm_args->ne21 <= 0) { + return 0; + } + + const uint32_t tile_n = ds4_gpu_mpp_attn_out_tile_n(); + const bool direct_rhs = + (tile_n == 32u || tile_n == 64u) && + ds4_gpu_mpp_attn_out_direct_rhs(); + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:mm_args length:sizeof(*mm_args) atIndex:0]; + [enc setBuffer:src0 offset:src0_off atIndex:1]; + [enc setBuffer:src1 offset:src1_off atIndex:2]; + [enc setBuffer:dst offset:dst_off atIndex:3]; + [enc setThreadgroupMemoryLength:(direct_rhs ? 4096u : (tile_n == 64 ? 8192u : 6144u)) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + (NSUInteger)tile_n - 1u) / (NSUInteger)tile_n, + ((NSUInteger)mm_args->ne0 + 63u) / 64u, + (NSUInteger)mm_args->ne02) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + return 1; +} + static int ds4_gpu_encode_swiglu_flat( id cb, id gate, @@ -12050,6 +13867,42 @@ static int ds4_gpu_encode_moe_swiglu_weight( return 1; } +static int ds4_gpu_encode_moe_sum6( + id cb, + id experts, + NSUInteger experts_off, + id out, + NSUInteger out_off, + uint32_t out_dim, + uint32_t n_tokens) { + if (!cb || !experts || !out || out_dim == 0 || n_tokens == 0) return 0; + + if (!g_moe_sum6_pipeline) return 0; + + const uint64_t out_row_bytes = (uint64_t)out_dim * sizeof(float); + ds4_gpu_dsv4_moe_sum6_args args = { + .width = out_dim, + .tokens = n_tokens, + .src_token_stride = 6u * out_row_bytes, + .dst_token_stride = out_row_bytes, + }; + + NSUInteger nth = g_moe_sum6_pipeline.maxTotalThreadsPerThreadgroup; + if (nth > 256u) nth = 256u; + if (nth > out_dim) nth = out_dim; + if (nth == 0) nth = 1u; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_moe_sum6_pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:experts offset:experts_off atIndex:1]; + [enc setBuffer:out offset:out_off atIndex:2]; + [enc dispatchThreadgroups:MTLSizeMake((NSUInteger)n_tokens, 1, 1) + threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + return 1; +} + static ds4_gpu_bin_args ds4_gpu_make_moe_add_args( uint32_t out_dim, uint32_t n_tokens, @@ -12100,6 +13953,18 @@ static int ds4_gpu_encode_moe_sum_experts( const uint64_t out_row_bytes = (uint64_t)out_dim * sizeof(float); const uint64_t expert_token_stride = (uint64_t)n_expert * out_row_bytes; + if (n_expert == 6 && + getenv("DS4_METAL_MOE_SUM6_DISABLE") == NULL && + ds4_gpu_encode_moe_sum6(cb, + experts, + experts_off, + out, + out_off, + out_dim, + n_tokens)) { + return 1; + } + ds4_gpu_bin_args first = ds4_gpu_make_moe_add_args(out_dim, n_tokens, expert_token_stride, expert_token_stride, out_row_bytes); if 
(!ds4_gpu_encode_bin_f32_rows(cb, @@ -13064,6 +14929,7 @@ int ds4_gpu_routed_moe_batch_tensor( uint32_t n_expert, float clamp, const ds4_gpu_tensor *x, + uint32_t layer_index, uint32_t n_tokens, bool *mid_is_f16) { if (!g_initialized && !ds4_gpu_init()) return 0; @@ -13130,6 +14996,8 @@ int ds4_gpu_routed_moe_batch_tensor( id gate_mv_pipeline = ds4_gpu_routed_mv_pipeline(gate_type); id down_mv_pipeline = ds4_gpu_routed_mv_pipeline(down_type); id gate_mm_pipeline = nil; + id up_mm_pipeline = nil; + id gate_up_pair_mm_pipeline = nil; id down_mm_pipeline = nil; if (gate_nr0 == 0 || down_nr0 == 0 || !gate_mv_pipeline || !down_mv_pipeline) { fprintf(stderr, "ds4: unsupported Metal routed batch MoE quant types gate=%u down=%u\n", @@ -13166,6 +15034,7 @@ int ds4_gpu_routed_moe_batch_tensor( ds4_gpu_mul_mm_id_args gate_mm_args = { 0 }; ds4_gpu_mul_mm_id_args down_mm_args = { 0 }; id map_pipeline = nil; + const int moe_mpp_mask = ds4_gpu_mpp_routed_moe_mask_for_layer(layer_index); /* * The grouped routed-MoE matmul loads activation tiles as half before * using SIMD-group MMA. Store the SwiGLU/route-weight intermediate in @@ -13175,6 +15044,19 @@ int ds4_gpu_routed_moe_batch_tensor( */ const bool request_mid_f16 = !g_quality_mode && getenv("DS4_METAL_MOE_MID_F32") == NULL; + const uint32_t moe_mpp_tile_n = ds4_gpu_mpp_moe_tile_n(); + const uint32_t gate_mm_tile_n = + (moe_mpp_mask & DS4_METAL_MOE_MPP_GATE) != 0 ? moe_mpp_tile_n : 32u; + const uint32_t up_mm_tile_n = + (moe_mpp_mask & DS4_METAL_MOE_MPP_UP) != 0 ? moe_mpp_tile_n : 32u; + const uint32_t down_mm_tile_n = + (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0 ? moe_mpp_tile_n : 32u; + const bool use_gate_up_pair_mpp = + ds4_gpu_mpp_moe_pair_gate_up() && + (moe_mpp_mask & (DS4_METAL_MOE_MPP_GATE | DS4_METAL_MOE_MPP_UP)) == + (DS4_METAL_MOE_MPP_GATE | DS4_METAL_MOE_MPP_UP) && + gate_mm_tile_n == 32u && + up_mm_tile_n == 32u; if (use_mm_id) { gate_map_args = ds4_gpu_make_mul_mm_id_map_args(expert_in_dim, 256, 1, n_expert, n_tokens); @@ -13189,11 +15071,22 @@ int ds4_gpu_routed_moe_batch_tensor( request_mid_f16 ? sizeof(uint16_t) : sizeof(float)); map_pipeline = ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_expert)); - gate_mm_pipeline = ds4_gpu_routed_mm_pipeline(gate_type); + if (use_gate_up_pair_mpp) { + gate_up_pair_mm_pipeline = ds4_gpu_routed_mm_pair_mpp_pipeline(gate_type); + } else { + gate_mm_pipeline = ds4_gpu_routed_mm_pipeline( + gate_type, + (moe_mpp_mask & DS4_METAL_MOE_MPP_GATE) != 0); + up_mm_pipeline = ds4_gpu_routed_mm_pipeline( + gate_type, + (moe_mpp_mask & DS4_METAL_MOE_MPP_UP) != 0); + } down_mm_pipeline = request_mid_f16 ? - ds4_gpu_routed_mm_f16_rhs_pipeline(down_type) : - ds4_gpu_routed_mm_pipeline(down_type); - if (!map_pipeline || !gate_mm_pipeline || !down_mm_pipeline) { + ds4_gpu_routed_mm_f16_rhs_pipeline(down_type, (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0) : + ds4_gpu_routed_mm_pipeline(down_type, (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0); + if (!map_pipeline || + (use_gate_up_pair_mpp ? 
!gate_up_pair_mm_pipeline : (!gate_mm_pipeline || !up_mm_pipeline)) || + !down_mm_pipeline) { return 0; } } @@ -13260,8 +15153,57 @@ int ds4_gpu_routed_moe_batch_tensor( selectedbuf, ds4_gpu_tensor_offset(selected)); DS4_METAL_PROFILE_MOE_STAGE("map"); - if (ok) { - ok = ds4_gpu_encode_mul_mm_id_mapped(cb, + if (ok && use_gate_up_pair_mpp) { + ok = ds4_gpu_encode_mul_mm_id_pair_mpp(cb, + gate_up_pair_mm_pipeline, + &gate_mm_args, + gate_buf, + (NSUInteger)gate_inner, + up_buf, + (NSUInteger)up_inner, + xbuf, + ds4_gpu_tensor_offset(x), + gatebuf, + ds4_gpu_tensor_offset(gate), + upbuf, + ds4_gpu_tensor_offset(up)); + if (ok) { + ds4_gpu_mpp_compare_moe_mm("moe_gate", + "moe_gate", + gate_type, + false, + cb, + &gate_mm_args, + gate_buf, + (NSUInteger)gate_inner, + xbuf, + ds4_gpu_tensor_offset(x), + gatebuf, + ds4_gpu_tensor_offset(gate), + (uint64_t)pair_rows * expert_mid_dim, + n_tokens, + (uint64_t)n_expert * expert_mid_dim, + expert_in_dim); + ds4_gpu_mpp_compare_moe_mm("moe_up", + "moe_up", + gate_type, + false, + cb, + &gate_mm_args, + up_buf, + (NSUInteger)up_inner, + xbuf, + ds4_gpu_tensor_offset(x), + upbuf, + ds4_gpu_tensor_offset(up), + (uint64_t)pair_rows * expert_mid_dim, + n_tokens, + (uint64_t)n_expert * expert_mid_dim, + expert_in_dim); + } + DS4_METAL_PROFILE_MOE_STAGE("gate_up_pair"); + } else if (ok) { + ok = ds4_gpu_encode_mul_mm_id_mapped_tile(cb, gate_mm_pipeline, &gate_mm_args, gate_buf, @@ -13269,19 +15211,57 @@ int ds4_gpu_routed_moe_batch_tensor( xbuf, ds4_gpu_tensor_offset(x), gatebuf, - ds4_gpu_tensor_offset(gate)); + ds4_gpu_tensor_offset(gate), + gate_mm_tile_n); + if (ok && (moe_mpp_mask & DS4_METAL_MOE_MPP_GATE) != 0) { + ds4_gpu_mpp_compare_moe_mm("moe_gate", + "moe_gate", + gate_type, + false, + cb, + &gate_mm_args, + gate_buf, + (NSUInteger)gate_inner, + xbuf, + ds4_gpu_tensor_offset(x), + gatebuf, + ds4_gpu_tensor_offset(gate), + (uint64_t)pair_rows * expert_mid_dim, + n_tokens, + (uint64_t)n_expert * expert_mid_dim, + expert_in_dim); + } DS4_METAL_PROFILE_MOE_STAGE("gate"); } - if (ok) { - ok = ds4_gpu_encode_mul_mm_id_mapped(cb, - gate_mm_pipeline, + if (ok && !use_gate_up_pair_mpp) { + ok = ds4_gpu_encode_mul_mm_id_mapped_tile(cb, + up_mm_pipeline, &gate_mm_args, up_buf, (NSUInteger)up_inner, xbuf, ds4_gpu_tensor_offset(x), upbuf, - ds4_gpu_tensor_offset(up)); + ds4_gpu_tensor_offset(up), + up_mm_tile_n); + if (ok && (moe_mpp_mask & DS4_METAL_MOE_MPP_UP) != 0) { + ds4_gpu_mpp_compare_moe_mm("moe_up", + "moe_up", + gate_type, + false, + cb, + &gate_mm_args, + up_buf, + (NSUInteger)up_inner, + xbuf, + ds4_gpu_tensor_offset(x), + upbuf, + ds4_gpu_tensor_offset(up), + (uint64_t)pair_rows * expert_mid_dim, + n_tokens, + (uint64_t)n_expert * expert_mid_dim, + expert_in_dim); + } DS4_METAL_PROFILE_MOE_STAGE("up"); } } else if (use_tiny_pair_mv) { @@ -13453,7 +15433,7 @@ int ds4_gpu_routed_moe_batch_tensor( down_smem, 2); } else if (use_mm_id) { - ok = ds4_gpu_encode_mul_mm_id_mapped(cb, + ok = ds4_gpu_encode_mul_mm_id_mapped_tile(cb, down_mm_pipeline, &down_mm_args, down_buf, @@ -13461,7 +15441,26 @@ int ds4_gpu_routed_moe_batch_tensor( midbuf, ds4_gpu_tensor_offset(mid), down_dst, - down_dst_off); + down_dst_off, + down_mm_tile_n); + if (ok && (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0) { + ds4_gpu_mpp_compare_moe_mm("moe_down", + "moe_down", + down_type, + request_mid_f16, + cb, + &down_mm_args, + down_buf, + (NSUInteger)down_inner, + midbuf, + ds4_gpu_tensor_offset(mid), + down_dst, + down_dst_off, + (uint64_t)pair_rows * out_dim, + n_tokens, + 
(uint64_t)n_expert * out_dim, + expert_mid_dim); + } } else { ok = ds4_gpu_encode_mul_mv_id(cb, down_mv_pipeline, diff --git a/ds4_server.c b/ds4_server.c index bc8abbbd..33c434fd 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -7840,6 +7840,15 @@ static float parse_float_arg(const char *s, const char *opt, float minv, float m return v; } +static ds4_mpp_mode parse_mpp_mode_arg(const char *s) { + if (!strcmp(s, "auto")) return DS4_MPP_AUTO; + if (!strcmp(s, "on")) return DS4_MPP_ON; + if (!strcmp(s, "off")) return DS4_MPP_OFF; + server_log(DS4_LOG_DEFAULT, "ds4-server: invalid Metal Tensor mode: %s", s); + server_log(DS4_LOG_DEFAULT, "ds4-server: valid Metal Tensor modes are: auto, on, off"); + exit(2); +} + static const char *need_arg(int *i, int argc, char **argv, const char *opt) { if (*i + 1 >= argc) { server_log(DS4_LOG_DEFAULT, "ds4-server: missing value for %s", opt); @@ -7897,7 +7906,10 @@ static void usage(FILE *fp) { " -t, --threads N\n" " CPU helper threads for lightweight host-side work.\n" " --quality\n" - " Prefer exact kernels where faster approximate paths exist; MTP uses strict verification.\n" + " Prefer exact kernels where faster approximate paths exist; disables Metal Tensor routes; MTP uses strict verification.\n" + " -mt MODE, --mt MODE\n" + " Metal Tensor policy: auto, on, or off. Default: auto. Auto enables validated safe routes; 'on' is a route diagnostic and may change output.\n" + " Legacy alias: --mpp MODE.\n" " --dir-steering-file FILE\n" " Load one f32 direction vector per layer for directional steering.\n" " --dir-steering-ffn F\n" @@ -8020,6 +8032,8 @@ static server_config parse_options(int argc, char **argv) { c.default_tokens = parse_int_arg(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "-t") || !strcmp(arg, "--threads")) { c.engine.n_threads = parse_int_arg(need_arg(&i, argc, argv, arg), arg); + } else if (!strcmp(arg, "-mt") || !strcmp(arg, "--mt") || !strcmp(arg, "--mpp")) { + c.engine.mpp_mode = parse_mpp_mode_arg(need_arg(&i, argc, argv, arg)); } else if (!strcmp(arg, "--host")) { c.host = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--port")) { diff --git a/metal/dense.metal b/metal/dense.metal index a84927e9..6400c69d 100644 --- a/metal/dense.metal +++ b/metal/dense.metal @@ -910,6 +910,354 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; +#ifdef DS4_METAL_HAS_TENSOR +template< + short NR0, short NR1, + typename SA, typename SA_4x4, typename block_q, short nl, + void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &), + typename T0, typename T0_4x4, typename T1> +kernel void kernel_mul_mm_mpp( + constant ds4_metal_args_mul_mm & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + constexpr int NK = 32; + constexpr int NL = NK/16; + constexpr int NUM_THREADS = 128; + + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + const int im = tgpig.z; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + threadgroup SA *sa = 
(threadgroup SA *)shmem; + threadgroup SA *sb = sa + NR0*NK; + auto tA = tensor(sa, dextents(NK, NR0)); + auto tB = tensor(sb, dextents(NK, NR1)); + + device const T1 *ptrB = (device const T1 *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11/sizeof(T1); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, false, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) { + cT[i] = 0.0f; + } + } + + for (int loop_k = 0; loop_k < K; loop_k += NK) { + for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { + const int row = work/NL; + const int k_chunk = work%NL; + const int k_pos = loop_k + k_chunk*16; + const short k_base = k_chunk*16; + + if (!FC_mul_mm_bc_out || r0 + row < M) { + if (is_same::value && FC_mul_mm_bc_inp) { + device const T0 *row_ptr = (device const T0 *)(srcA + args.nb01*(r0 + row) + offset0); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (k_pos + i < K) ? (SA)row_ptr[k_pos + i] : (SA)0; + } + } else { + const int block_idx = k_pos/(16*nl); + const short il = (k_pos/16)%nl; + device const block_q *row_ptr = (device const block_q *)(srcA + args.nb01*(r0 + row) + offset0); + + SA_4x4 temp_a; + dequantize_func(row_ptr + block_idx, il, temp_a); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0; + } + } + } else { + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (SA)0; + } + } + } + for (int work = tiitg; work < NK*NR1; work += NUM_THREADS) { + const int col = work/NK; + const int k = work%NK; + if ((!FC_mul_mm_bc_out && !FC_mul_mm_bc_inp) || + (r1 + col < N && loop_k + k < K)) { + sb[col*NK + k] = (SA)ptrB[(uint64_t)(r1 + col)*strideB + loop_k + k]; + } else { + sb[col*NK + k] = (SA)0; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto mA = tA.slice(0, 0); + auto mB = tB.slice(0, 0); + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float *dst_batch = (device float *)dst + im*N*M; + if (!FC_mul_mm_bc_out) { + device float *dst_tile = dst_batch + r0 + (uint64_t)r1*M; + auto tD = tensor(dst_tile, dextents(NR0, NR1), array({1, M})); + cT.store(tD); + } else { + auto tD = tensor(dst_batch, dextents(M, N), array({1, M})); + auto mD = tD.slice(r0, r1); + cT.store(mD); + } +} + +typedef decltype(kernel_mul_mm_mpp<64, 32, half, half4x4, float4x4, 1, dequantize_f32, float, float4x4, float>) mul_mm_mpp_t; +typedef decltype(kernel_mul_mm_mpp<64, 64, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>) mul_mm_mpp_q8_n64_t; + +template [[host_name("kernel_mul_mm_f16_f32_mpp")]] kernel mul_mm_mpp_t kernel_mul_mm_mpp<64, 32, half, half4x4, half4x4, 1, dequantize_f16, half, half4x4, float>; +template [[host_name("kernel_mul_mm_q8_0_f32_mpp")]] kernel mul_mm_mpp_t kernel_mul_mm_mpp<64, 32, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>; +template [[host_name("kernel_mul_mm_q8_0_f32_mpp_n64")]] kernel mul_mm_mpp_q8_n64_t kernel_mul_mm_mpp<64, 64, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>; + +kernel void kernel_mul_mm_f16_f32_pair_mpp( + constant ds4_metal_args_mul_mm & args, + device const char * srcA0, + device const char * srcA1, + device const char * srcB, + device char * dst0, + device char * dst1, + threadgroup char * shmem 
[[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + constexpr int NR0 = 64; + constexpr int NR1 = 32; + constexpr int NK = 32; + constexpr int NL = NK/16; + constexpr int NUM_THREADS = 128; + + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + const int im = tgpig.z; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + threadgroup half *sa0 = (threadgroup half *)shmem; + threadgroup half *sa1 = sa0 + NR0*NK; + threadgroup half *sb = sa1 + NR0*NK; + auto tA0 = tensor(sa0, dextents(NK, NR0)); + auto tA1 = tensor(sa1, dextents(NK, NR0)); + auto tB = tensor(sb, dextents(NK, NR1)); + + device const float *ptrB = (device const float *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11/sizeof(float); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, false, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto c0 = mm.template get_destination_cooperative_tensor(); + auto c1 = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < c0.get_capacity(); ++i) { + if (c0.is_valid_element(i)) { + c0[i] = 0.0f; + c1[i] = 0.0f; + } + } + + for (int loop_k = 0; loop_k < K; loop_k += NK) { + for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { + const int row = work/NL; + const int k_chunk = work%NL; + const int k_pos = loop_k + k_chunk*16; + const short k_base = k_chunk*16; + + if (!FC_mul_mm_bc_out || r0 + row < M) { + device const half *row0 = (device const half *)(srcA0 + args.nb01*(r0 + row) + offset0); + device const half *row1 = (device const half *)(srcA1 + args.nb01*(r0 + row) + offset0); + FOR_UNROLL (short i = 0; i < 16; i++) { + const bool in_bounds = k_pos + i < K; + sa0[row*NK + k_base + i] = in_bounds ? row0[k_pos + i] : (half)0; + sa1[row*NK + k_base + i] = in_bounds ? 
row1[k_pos + i] : (half)0; + } + } else { + FOR_UNROLL (short i = 0; i < 16; i++) { + sa0[row*NK + k_base + i] = (half)0; + sa1[row*NK + k_base + i] = (half)0; + } + } + } + for (int work = tiitg; work < NK*NR1; work += NUM_THREADS) { + const int col = work/NK; + const int k = work%NK; + if (!FC_mul_mm_bc_out || (r1 + col < N && loop_k + k < K)) { + sb[col*NK + k] = (half)ptrB[(uint64_t)(r1 + col)*strideB + loop_k + k]; + } else { + sb[col*NK + k] = (half)0; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto mA0 = tA0.slice(0, 0); + auto mA1 = tA1.slice(0, 0); + auto mB = tB.slice(0, 0); + mm.run(mB, mA0, c0); + mm.run(mB, mA1, c1); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float *dst0_batch = (device float *)dst0 + im*N*M; + device float *dst1_batch = (device float *)dst1 + im*N*M; + if (!FC_mul_mm_bc_out) { + device float *dst0_tile = dst0_batch + r0 + (uint64_t)r1*M; + device float *dst1_tile = dst1_batch + r0 + (uint64_t)r1*M; + auto tD0 = tensor(dst0_tile, dextents(NR0, NR1), array({1, M})); + auto tD1 = tensor(dst1_tile, dextents(NR0, NR1), array({1, M})); + c0.store(tD0); + c1.store(tD1); + } else { + auto tD0 = tensor(dst0_batch, dextents(M, N), array({1, M})); + auto tD1 = tensor(dst1_batch, dextents(M, N), array({1, M})); + auto mD0 = tD0.slice(r0, r1); + auto mD1 = tD1.slice(r0, r1); + c0.store(mD0); + c1.store(mD1); + } +} + +template< + short NR1, + typename SA, typename SA_4x4, typename block_q, short nl, + void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &), + typename T0, typename T0_4x4, typename T1> +kernel void kernel_mul_mm_mpp_direct_rhs( + constant ds4_metal_args_mul_mm & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + constexpr int NR0 = 64; + constexpr int NK = 32; + constexpr int NL = NK/16; + constexpr int NUM_THREADS = 128; + + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + const int im = tgpig.z; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + threadgroup SA *sa = (threadgroup SA *)shmem; + auto tA = tensor(sa, dextents(NK, NR0)); + + device T1 *ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11/sizeof(T1); + auto tB = tensor(ptrB, dextents(K, N), array({1, strideB})); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, true, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) { + cT[i] = 0.0f; + } + } + + for (int loop_k = 0; loop_k < K; loop_k += NK) { + for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { + const int row = work/NL; + const int k_chunk = work%NL; + const int k_pos = loop_k + k_chunk*16; + const short k_base = k_chunk*16; + + if (r0 + row < M) { + if (is_same::value && FC_mul_mm_bc_inp) { + device const T0 *row_ptr = (device const T0 *)(srcA + args.nb01*(r0 + row) + offset0); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (k_pos + i < K) ? 
(SA)row_ptr[k_pos + i] : (SA)0; + } + } else { + const int block_idx = k_pos/(16*nl); + const short il = (k_pos/16)%nl; + device const block_q *row_ptr = (device const block_q *)(srcA + args.nb01*(r0 + row) + offset0); + + SA_4x4 temp_a; + dequantize_func(row_ptr + block_idx, il, temp_a); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0; + } + } + } else { + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (SA)0; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto mA = tA.slice(0, 0); + auto mB = tB.slice(loop_k, r1); + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float *dst_batch = (device float *)dst + im*N*M; + auto tD = tensor(dst_batch, dextents(M, N), array({1, M})); + auto mD = tD.slice(r0, r1); + cT.store(mD); +} + +typedef decltype(kernel_mul_mm_mpp_direct_rhs<32, half, half4x4, float4x4, 1, dequantize_f32, float, float4x4, float>) mul_mm_mpp_direct_rhs_t; +typedef decltype(kernel_mul_mm_mpp_direct_rhs<64, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>) mul_mm_mpp_direct_rhs_q8_n64_t; + +template [[host_name("kernel_mul_mm_f16_f32_mpp_direct_rhs")]] kernel mul_mm_mpp_direct_rhs_t kernel_mul_mm_mpp_direct_rhs<32, half, half4x4, half4x4, 1, dequantize_f16, half, half4x4, float>; +template [[host_name("kernel_mul_mm_q8_0_f32_mpp_direct_rhs")]] kernel mul_mm_mpp_direct_rhs_t kernel_mul_mm_mpp_direct_rhs<32, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>; +template [[host_name("kernel_mul_mm_q8_0_f32_mpp_direct_rhs_n64")]] kernel mul_mm_mpp_direct_rhs_q8_n64_t kernel_mul_mm_mpp_direct_rhs<64, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>; +#endif + // Tiled matrix-matrix kernel used for prompt batches larger than 8. DS4 uses // this to turn prefill into large simdgroup matrix operations; each block_q // contains 16*nl weights. @@ -1114,6 +1462,242 @@ kernel void kernel_mul_mm( } } +kernel void kernel_mul_mm_f16_f32_pair( + constant ds4_metal_args_mul_mm & args, + device const char * src0_a, + device const char * src0_b, + device const char * src1, + device char * dst_a, + device char * dst_b, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + threadgroup half * sa_a = (threadgroup half *)(shmem); + threadgroup half * sa_b = (threadgroup half *)(shmem + 4096); + threadgroup half * sb = (threadgroup half *)(shmem + 8192); + + constexpr int NR0 = 64; + constexpr int NR1 = 32; + constexpr int NK = 32; + constexpr int NL0 = NK/16; + constexpr int NL1 = NK/8; + + const int im = tgpig.z; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0; + const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1; + + const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; + const short lr1 = ((short)tiitg/NL1) < nr1 ? 
((short)tiitg/NL1) : nr1 - 1; + + const short il0 = (tiitg % NL0); + short il = il0; + + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il0; + + device const half4x4 * xa = (device const half4x4 *)(src0_a + args.nb01*(r0 + lr0) + offset0) + offset1; + device const half4x4 * xb = (device const half4x4 *)(src0_b + args.nb01*(r0 + lr0) + offset0) + offset1; + + const short iy = 8*(tiitg % NL1); + + device const float * y = (device const float *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*(r1 + lr1) + + args.nb10*iy); + + simdgroup_half8x8 ma[4]; + simdgroup_half8x8 mb[2]; + + simdgroup_float8x8 mc_a[8]; + simdgroup_float8x8 mc_b[8]; + + for (short i = 0; i < 8; i++) { + mc_a[i] = make_filled_simdgroup_matrix(0.f); + mc_b[i] = make_filled_simdgroup_matrix(0.f); + } + + for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { + half4x4 temp_a; + half4x4 temp_b; + dequantize_f16(xa, il, temp_a); + dequantize_f16(xb, il, temp_b); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + const short lx = (tiitg/NL0)%8; + const short ly = i%8; + + const short ib = 8*sx + sy; + + *(sa_a + 64*ib + 8*ly + lx) = temp_a[i/4][i%4]; + *(sa_b + 64*ib + 8*ly + lx) = temp_b[i/4][i%4]; + } + + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short lx = i; + const short ly = (tiitg/NL1)%8; + + const short ib = 4*sx + sy; + + *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (half) *((device float *) y + i) : 0; + } + } else { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short ly = (tiitg/NL1)%8; + + const short ib = 4*sx + sy; + + *(threadgroup half2x4 *)(sb + 64*ib + 8*ly) = (half2x4)(*((device float2x4 *) y)); + } + + il = (il + 2 < 1) ? il + 2 : il % 2; + xa = (il < 2) ? xa + 2 : xa; + xb = (il < 2) ? 
xb + 2 : xb; + + y += NK; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup const half * lsma_a = (sa_a + 4*64*(sgitg%2)); + threadgroup const half * lsma_b = (sa_b + 4*64*(sgitg%2)); + threadgroup const half * lsmb = (sb + 2*64*(sgitg/2)); + + FOR_UNROLL (short ik = 0; ik < NK/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma_a + 64*i, 8, 0, false); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 8; i++) { + simdgroup_multiply_accumulate(mc_a[i], mb[i/4], ma[i%4], mc_a[i]); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma_b + 64*i, 8, 0, false); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 8; i++) { + simdgroup_multiply_accumulate(mc_b[i], mb[i/4], ma[i%4], mc_b[i]); + } + + lsma_a += 8*64; + lsma_b += 8*64; + lsmb += 4*64; + } + } + + if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { + device float * C_a = (device float *) dst_a + + (r0 + 32*(sgitg & 1)) + + (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + device float * C_b = (device float *) dst_b + + (r0 + 32*(sgitg & 1)) + + (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc_a[i], C_a + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); + simdgroup_store(mc_b[i], C_b + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); + } + } else { + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup float * temp_str = (threadgroup float *) shmem; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc_a[i], + temp_str + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0 + 8*(i%4) + 8*NR0*(i/4), + NR0, + 0, + false); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + for (int j = tiitg; j < nr1; j += NR1) { + device float * D = (device float *) dst_a + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*NR0); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < nr0/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < nr0; i++) { + *(D + i) = *(C + i); + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc_b[i], + temp_str + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0 + 8*(i%4) + 8*NR0*(i/4), + NR0, + 0, + false); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + for (int j = tiitg; j < nr1; j += NR1) { + device float * D = (device float *) dst_b + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*NR0); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < nr0/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < nr0; i++) { + *(D + i) = *(C + i); + } + } + } + } +} + typedef decltype(kernel_mul_mm) mul_mm_t; // Host-visible prefill matmul variants for F16 and Q8_0 weights. diff --git a/metal/dsv4_hc.metal b/metal/dsv4_hc.metal index 89cf6c65..49636f54 100644 --- a/metal/dsv4_hc.metal +++ b/metal/dsv4_hc.metal @@ -77,6 +77,24 @@ struct ds4_metal_args_dsv4_hc_expand { int32_t has_add; }; +// Numerically stable sigmoid. 
The naive form 1/(1+exp(-z)) overflows for large +// negative z (exp(-z) blows up); replacing it with the 0.5*(tanh(z/2)+1) identity +// keeps the value bounded in [0, 1] across the entire float range. Gated by +// DS4_METAL_HC_STABLE so we can A/B vs the historical form on M5 Max where the +// faster ALU is more likely to push HC mixer inputs into the unstable regime. +#ifdef DS4_METAL_HC_STABLE +static inline float ds4_hc_sigmoid(float z) { return 0.5f * tanh(0.5f * z) + 0.5f; } +static inline float4 ds4_hc_sigmoid(float4 z) { return 0.5f * tanh(0.5f * z) + 0.5f; } +// 2 * sigmoid(z) == 1 + tanh(z/2). +static inline float ds4_hc_twice_sigmoid(float z) { return 1.0f + tanh(0.5f * z); } +static inline float4 ds4_hc_twice_sigmoid(float4 z) { return 1.0f + tanh(0.5f * z); } +#else +static inline float ds4_hc_sigmoid(float z) { return 1.0f / (1.0f + exp(-z)); } +static inline float4 ds4_hc_sigmoid(float4 z) { return 1.0f / (1.0f + exp(-z)); } +static inline float ds4_hc_twice_sigmoid(float z) { return 2.0f / (1.0f + exp(-z)); } +static inline float4 ds4_hc_twice_sigmoid(float4 z) { return 2.0f / (1.0f + exp(-z)); } +#endif + // Splits an HC mixer row into pre weights, post gates, and the HC-to-HC // combination matrix. The 4-channel path is specialized because DS4 Flash uses // HC=4 in normal inference, while the scalar fallback keeps diagnostics usable. @@ -109,12 +127,12 @@ kernel void kernel_dsv4_hc_split_sinkhorn( const float4 pre_z = *((device const float4 *) mix) * pre_scale + *((device const float4 *) base); - *((device float4 *) out) = 1.0f / (1.0f + exp(-pre_z)) + epsv; + *((device float4 *) out) = ds4_hc_sigmoid(pre_z) + epsv; const float4 post_z = *((device const float4 *) (mix + 4)) * post_scale + *((device const float4 *) (base + 4)); - *((device float4 *) (out + 4)) = 2.0f / (1.0f + exp(-post_z)); + *((device float4 *) (out + 4)) = ds4_hc_twice_sigmoid(post_z); float4 r0 = *((device const float4 *) (mix + 8)) * comb_scale + @@ -172,13 +190,13 @@ kernel void kernel_dsv4_hc_split_sinkhorn( for (int i = 0; i < HC; ++i) { const float z = mix[i] * pre_scale + base[i]; - out[i] = 1.0f / (1.0f + exp(-z)) + epsv; + out[i] = ds4_hc_sigmoid(z) + epsv; } for (int i = 0; i < HC; ++i) { const int off = HC + i; const float z = mix[off] * post_scale + base[off]; - out[off] = 2.0f / (1.0f + exp(-z)); + out[off] = ds4_hc_twice_sigmoid(z); } float c[HC_MAX*HC_MAX]; diff --git a/metal/dsv4_kv.metal b/metal/dsv4_kv.metal index 89bd7d3a..be760514 100644 --- a/metal/dsv4_kv.metal +++ b/metal/dsv4_kv.metal @@ -167,13 +167,25 @@ kernel void kernel_dsv4_kv_fp8_store_f32( if (off + (int)tid < n_nope) { const float q = dsv4_e4m3fn_dequant(clamp(v / fp8_scale, -448.0f, 448.0f)) * fp8_scale; kv[off + tid] = q; + // Diagnostic only: skip the FP16 round-trip that normally matches the + // half-typed FlashAttention KV buffer's precision. With this enabled the + // indexer will see higher-precision raw values than FlashAttention does, + // which is informative but not a production-ready setting. 
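+      // Compiled in only when DS4_METAL_KV_RAW_F32 is defined at shader build time; the default build keeps the FP16 round-trip below.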
+#ifdef DS4_METAL_KV_RAW_F32 + raw[off + tid] = q; +#else raw[off + tid] = (float)((half)q); +#endif } threadgroup_barrier(mem_flags::mem_threadgroup); } for (int i = n_nope + tid; i < head_dim; i += 64) { +#ifdef DS4_METAL_KV_RAW_F32 + raw[i] = kv[i]; +#else raw[i] = (float)((half)kv[i]); +#endif } } diff --git a/metal/dsv4_misc.metal b/metal/dsv4_misc.metal index b06d29d3..c9dc09c6 100644 --- a/metal/dsv4_misc.metal +++ b/metal/dsv4_misc.metal @@ -594,9 +594,7 @@ kernel void kernel_dsv4_indexed_mixed_attention_heads8( // Decode specialization of kernel_dsv4_indexed_mixed_attention_heads8. // Generation attends one token at a time, so the ratio-4 indexed path spends a // visible amount of time repeatedly staging the same K/V row for the eight -// heads in a group. This variant stages four selected rows at once and then -// consumes them sequentially, preserving the row order and online softmax math -// while cutting threadgroup barriers in the long top-k scan. +// heads in a group. This diagnostic variant stages four selected rows at once. kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb4( constant ds4_metal_args_dsv4_indexed_attention & args, device const char *q, @@ -720,6 +718,135 @@ kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb4( dst4[lane + 96] = o3 * inv_s; } +// Decode specialization of kernel_dsv4_indexed_mixed_attention_heads8. +// Generation attends one token at a time, so the ratio-4 indexed path spends a +// visible amount of time repeatedly staging the same K/V row for the eight +// heads in a group. This variant stages sixteen selected rows at once and then +// consumes them sequentially, preserving the row order and online softmax math +// while cutting threadgroup barriers in the long top-k scan. +kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb16( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared [[threadgroup(0)]], + uint2 tgpig [[threadgroup_position_in_grid]], + ushort tid [[thread_index_in_threadgroup]], + ushort lane [[thread_index_in_simdgroup]], + ushort sg [[simdgroup_index_in_threadgroup]]) { + const uint token = tgpig.x; + const uint head = tgpig.y * 8u + (uint)sg; + if (token >= args.n_tokens || head >= args.n_head) { + return; + } + + device const float4 *q4 = (device const float4 *)(q + + (uint64_t)token * args.q_token_stride + + (uint64_t)head * args.q_head_stride); + const half4 q0 = (half4)q4[lane + 0]; + const half4 q1 = (half4)q4[lane + 32]; + const half4 q2 = (half4)q4[lane + 64]; + const half4 q3 = (half4)q4[lane + 96]; + + float M = -FLT_MAX/2.0f; + float S = 0.0f; + float4 o0 = 0.0f; + float4 o1 = 0.0f; + float4 o2 = 0.0f; + float4 o3 = 0.0f; + + const uint qpos = args.pos0 + token; + const uint last_pos = args.pos0 + args.n_tokens - 1u; + const uint first_raw_pos = last_pos + 1u - args.n_raw; + const uint raw_last_pos = first_raw_pos + args.n_raw - 1u; + const uint window_first = (args.window != 0u && qpos + 1u > args.window) ? 
+ qpos + 1u - args.window : 0u; + uint first = max(first_raw_pos, window_first); + uint last = min(qpos, raw_last_pos); + + if (first <= last) { + for (uint pos0 = first; pos0 <= last; pos0 += 16u) { + const uint n_rows = min(16u, last - pos0 + 1u); + for (uint off = (uint)tid; off < n_rows * 128u; off += 256u) { + const uint r = off >> 7; + const uint c = off & 127u; + const uint logical = pos0 + r - first_raw_pos; + const uint row = (args.raw_start + logical) % args.raw_cap; + device const float4 *src = (device const float4 *)(raw_kv + + (uint64_t)row * args.raw_row_stride); + kv_shared[off] = src[c]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint r = 0; r < n_rows; r++) { + dsv4_attend_shared_f32_row_as_f16_at(kv_shared, + r, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + uint visible = (qpos + 1u) / args.ratio; + visible = min(visible, args.n_comp); + device const int32_t *row_topk = (device const int32_t *)(topk + + (uint64_t)token * args.topk_token_stride); + bool stop = false; + for (uint i = 0; i < args.top_k && !stop; i += 16u) { + uint rows[16]; + uint n_rows = 0; + for (uint j = 0; j < 16u && i + j < args.top_k; j++) { + const int32_t idx = row_topk[i + j]; + if (idx < 0) { + continue; + } + if ((uint)idx >= visible) { + stop = true; + break; + } + rows[n_rows++] = (uint)idx; + } + if (n_rows == 0) { + continue; + } + for (uint off = (uint)tid; off < n_rows * 128u; off += 256u) { + const uint r = off >> 7; + const uint c = off & 127u; + device const float4 *src = (device const float4 *)(comp_kv + + (uint64_t)rows[r] * args.comp_row_stride); + kv_shared[off] = src[c]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint r = 0; r < n_rows; r++) { + dsv4_attend_shared_f32_row_as_f16_at(kv_shared, + r, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + dsv4_attend_sink(((device const float *)sinks)[head], M, S, o0, o1, o2, o3); + + const float inv_s = S == 0.0f ? 0.0f : 1.0f/S; + device float4 *dst4 = (device float4 *)(dst + + (uint64_t)token * args.dst_token_stride + + (uint64_t)head * args.dst_head_stride); + dst4[lane + 0] = o0 * inv_s; + dst4[lane + 32] = o1 * inv_s; + dst4[lane + 64] = o2 * inv_s; + dst4[lane + 96] = o3 * inv_s; +} + static inline float dsv4_indexer_dot128_shared_q( float4 c0, float4 c1, diff --git a/metal/dsv4_rope.metal b/metal/dsv4_rope.metal index aaa6f3d9..b3207561 100644 --- a/metal/dsv4_rope.metal +++ b/metal/dsv4_rope.metal @@ -110,7 +110,13 @@ kernel void kernel_dsv4_rope_tail_f32( const int ic = r; const int rel_i0 = 2*ic; +#ifdef DS4_METAL_ROPE_EXP2_LOG2 + // Equivalent to pow(freq_base, k) but expressed through IEEE-754 + // primitives that have tighter precision guarantees than Metal's pow(). + const float theta = theta_base * exp2(inv_ndims * (float)rel_i0 * log2(args.freq_base)); +#else const float theta = theta_base * pow(args.freq_base, inv_ndims*rel_i0); +#endif const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; float cos_theta; @@ -133,7 +139,11 @@ kernel void kernel_dsv4_rope_tail_f32( } const int ic = r/2; +#ifdef DS4_METAL_ROPE_EXP2_LOG2 + const float theta = theta_base * exp2(inv_ndims * (float)r * log2(args.freq_base)); +#else const float theta = theta_base * pow(args.freq_base, inv_ndims*r); +#endif const float freq_factor = args.src2 ? 
((device const float *) src2)[ic] : 1.0f; float cos_theta; diff --git a/metal/moe.metal b/metal/moe.metal index 65074d7d..4619de28 100644 --- a/metal/moe.metal +++ b/metal/moe.metal @@ -87,6 +87,8 @@ static constant ulong ds4_metal_iq2xxs_grid[256] = { 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, }; +constant bool FC_mul_mm_id_mpp [[function_constant(FC_MUL_MM + 2)]]; + #define kmask_iq2xs ds4_metal_kmask_iq2xs #define ksigns_iq2xs ds4_metal_ksigns_iq2xs #define iq2xxs_grid ds4_metal_iq2xxs_grid @@ -121,6 +123,13 @@ struct ds4_metal_dsv4_moe_swiglu_weight_args { float clamp_value; }; +struct ds4_metal_dsv4_moe_sum6_args { + uint32_t width; + uint32_t tokens; + uint64_t src_token_stride; + uint64_t dst_token_stride; +}; + // Routed-MoE activation for the selected experts: // clamp(gate), clamp(up), silu(gate) * up * route_weight. Normal inference // does not consume gate/up after this point, so the fast path avoids writing the @@ -198,6 +207,31 @@ kernel void kernel_dsv4_moe_swiglu_weight_f16( } } +kernel void kernel_dsv4_moe_sum6_f32( + constant ds4_metal_dsv4_moe_sum6_args &args, + device const char *src, + device char *dst, + uint token[[threadgroup_position_in_grid]], + uint tid[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + if (token >= args.tokens) return; + + device const float *s = + (device const float *)(src + (uint64_t)token * args.src_token_stride); + device float *d = + (device float *)(dst + (uint64_t)token * args.dst_token_stride); + + for (uint col = tid; col < args.width; col += ntg) { + float v = s[col]; + v += s[args.width + col]; + v += s[2u * args.width + col]; + v += s[3u * args.width + col]; + v += s[4u * args.width + col]; + v += s[5u * args.width + col]; + d[col] = v; + } +} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -1515,7 +1549,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_ // Batched routed-expert matmul. It reads the expert-major map produced above, // loads selected expert weights, and writes results back to token-major slots // so the DS4 FFN can apply SwiGLU, weighting, and the down projection. -template +template kernel void kernel_mul_mm_id( constant ds4_metal_args_mul_mm_id & args, device const char * src0, @@ -1530,9 +1564,11 @@ kernel void kernel_mul_mm_id( ushort sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef DS4_METAL_HAS_TENSOR + threadgroup float *sc = (threadgroup float *)shmem; +#endif constexpr int NR0 = 64; - constexpr int NR1 = 32; constexpr int NK = 32; constexpr int NL0 = NK/16; @@ -1553,6 +1589,7 @@ kernel void kernel_mul_mm_id( const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0; const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1; + const bool full_mpp_tile = nr0 == NR0 && nr1 == NR1 && (args.ne00 % NK) == 0; const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; const short lr1 = ((short)tiitg/NL1) < nr1 ? 
((short)tiitg/NL1) : nr1 - 1; @@ -1588,6 +1625,24 @@ kernel void kernel_mul_mm_id( for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } +#ifdef DS4_METAL_HAS_TENSOR + auto tA = tensor(sa, dextents(NK, NR0)); + auto tB = tensor(sb, dextents(NK, NR1)); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, false, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) { + cT[i] = 0.0f; + } + } +#endif for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { if (is_same::value && FC_mul_mm_bc_inp) { @@ -1597,12 +1652,23 @@ kernel void kernel_mul_mm_id( const short sx = 2*il0 + i/8; const short sy = (tiitg/NL0)/8; +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = + full_mpp_tile || loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + } else +#endif + { const short lx = (tiitg/NL0)%8; const short ly = i%8; const short ib = 8*sx + sy; *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + } } } else { S0_4x4 temp_a; @@ -1614,16 +1680,52 @@ kernel void kernel_mul_mm_id( const short sx = 2*il0 + i/8; const short sy = (tiitg/NL0)/8; +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; + } else +#endif + { const short lx = (tiitg/NL0)%8; const short ly = i%8; const short ib = 8*sx + sy; *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4]; + } } } if (FC_mul_mm_bc_inp) { +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + for (short tile_row = 0; tile_row < NR1; tile_row += 32) { + const short t = (short)tiitg + tile_row*4; + const short row = t/NL1; + const short sx = t%NL1; + const short sy = row/8; + const short lx = 0; + const short ly = row%8; + const int idb = (full_mpp_tile || row < nr1) ? ids_i32[im*args.ne21 + r1 + row] : 0; + const short i11b = (idb % args.ne20) % args.ne11; + const short i12b = (idb / args.ne20); + device const T1 *yb = (device const T1 *)(src1 + + args.nb13*i13 + + args.nb12*i12b + + args.nb11*i11b + + args.nb10*(loop_k + 8*sx)); + + FOR_UNROLL (short i = 0; i < 8; ++i) { + *(sb + NK*(8*sy + ly) + 8*sx + lx + i) = + full_mpp_tile || (row < nr1 && loop_k + 8*sx + i < args.ne00) ? (S1) *(yb + i) : 0; + } + } + } else +#endif + { for (short i = 0; i < 8; ++i) { const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; @@ -1635,7 +1737,35 @@ kernel void kernel_mul_mm_id( *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; } + } } else { +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + for (short tile_row = 0; tile_row < NR1; tile_row += 32) { + const short t = (short)tiitg + tile_row*4; + const short row = t/NL1; + const short sx = t%NL1; + const short sy = row/8; + const short ly = row%8; + const int idb = (full_mpp_tile || row < nr1) ? 
ids_i32[im*args.ne21 + r1 + row] : 0; + const short i11b = (idb % args.ne20) % args.ne11; + const short i12b = (idb / args.ne20); + device const T1 *yb = (device const T1 *)(src1 + + args.nb13*i13 + + args.nb12*i12b + + args.nb11*i11b + + args.nb10*loop_k); + + if (full_mpp_tile || row < nr1) { + *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = + (S1_2x4)(*((device T1_2x4 *) yb + sx)); + } else { + *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(0); + } + } + } else +#endif + { const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; @@ -1644,6 +1774,7 @@ kernel void kernel_mul_mm_id( const short ib = 4*sx + sy; *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); + } } il = (il + 2 < nl) ? il + 2 : il % 2; @@ -1653,6 +1784,14 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + mm.run(sB, sA, cT); + } else +#endif + { threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); @@ -1678,15 +1817,24 @@ kernel void kernel_mul_mm_id( lsma += 8*64; lsmb += 4*64; } + } } threadgroup_barrier(mem_flags::mem_threadgroup); +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + auto tC = tensor(sc, dextents(NR0, NR1)); + cT.store(tC); + } else +#endif + { threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); } + } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -1714,18 +1862,611 @@ kernel void kernel_mul_mm_id( } } -typedef decltype(kernel_mul_mm_id) mul_mm_id; -typedef decltype(kernel_mul_mm_id) mul_mm_id_f16_rhs; +#ifdef DS4_METAL_HAS_TENSOR +template +kernel void kernel_mul_mm_id_pair_mpp( + constant ds4_metal_args_mul_mm_id & args, + device const char * src0_gate, + device const char * src0_up, + device const char * src1, + device const char * htpe, + device const char * hids, + device char * dst_gate, + device char * dst_up, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + threadgroup S0 * sa = (threadgroup S0 *)(shmem); + threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); + threadgroup float *sc = (threadgroup float *)shmem; + + constexpr int NR0 = 64; + constexpr int NR1 = 32; + constexpr int NK = 32; + constexpr int NL0 = NK/16; + constexpr int NL1 = NK/8; + + const int im = tgpig.z; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe); + device const int32_t * ids_i32 = (device const int32_t *) (hids); + const int32_t neh1 = tpe_u32[im]; + if (r1 >= neh1) { + return; + } + + const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0; + const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1; + const short lr0 = ((short)tiitg/NL0) < nr0 ? 
((short)tiitg/NL0) : nr0 - 1; + const short il0 = (tiitg % NL0); + short il = il0; + + const int i13 = 0; + const uint64_t offset0 = im*args.nb02 + i13*args.nb03; + const short offset1 = il0/nl; + device const block_q * x_gate = + (device const block_q *)(src0_gate + args.nb01*(r0 + lr0) + offset0) + offset1; + device const block_q * x_up = + (device const block_q *)(src0_up + args.nb01*(r0 + lr0) + offset0) + offset1; + + auto tA = tensor(sa, dextents(NK, NR0)); + auto tB = tensor(sb, dextents(NK, NR1)); + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, false, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cGate = mm.template get_destination_cooperative_tensor(); + auto cUp = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < cGate.get_capacity(); ++i) { + if (cGate.is_valid_element(i)) cGate[i] = 0.0f; + if (cUp.is_valid_element(i)) cUp[i] = 0.0f; + } + + for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { + S0_4x4 temp_gate; + dequantize_func(x_gate, il, temp_gate); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_gate[i/4][i%4]; + } + + const short row = ((short)tiitg)/NL1; + const short sx = ((short)tiitg)%NL1; + const short sy = row/8; + const short ly = row%8; + const int idb = row < nr1 ? ids_i32[im*args.ne21 + r1 + row] : 0; + const short i11b = (idb % args.ne20) % args.ne11; + const short i12b = (idb / args.ne20); + device const T1 *yb = (device const T1 *)(src1 + + args.nb13*i13 + + args.nb12*i12b + + args.nb11*i11b + + args.nb10*loop_k); + + if (row < nr1) { + *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = + (S1_2x4)(*((device T1_2x4 *) yb + sx)); + } else { + *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(0); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + mm.run(sB, sA, cGate); + + S0_4x4 temp_up; + dequantize_func(x_up, il, temp_up); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + const short ax = 2*il0 + i/8; + const short ay = (tiitg/NL0)/8; + const short lx = i%8; + const short ly2 = (tiitg/NL0)%8; + *(sa + NK*(8*ay + ly2) + 8*ax + lx) = temp_up[i/4][i%4]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sA = tA.slice(0, 0); + sB = tB.slice(0, 0); + mm.run(sB, sA, cUp); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x_gate = (il < 2) ? x_gate + (2 + nl - 1)/nl : x_gate; + x_up = (il < 2) ? 
x_up + (2 + nl - 1)/nl : x_up; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto tC = tensor(sc, dextents(NR0, NR1)); + cGate.store(tC); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short j = sgitg; j < nr1; j += 4) { + const int idj = ids_i32[im*args.ne21 + r1 + j]; + const short ide = idj % args.ne20; + const short idt = idj / args.ne20; + device float * D = (device float *) dst_gate + r0 + ide*args.ne0 + idt*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + threadgroup float * C = (threadgroup float *) shmem + j*NR0; + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = tiisg; + for (; i < nr0/4; i += 32) *(D4 + i) = *(C4 + i); + i = (4*(nr0/4)) + tiisg; + for (; i < nr0; i += 32) *(D + i) = *(C + i); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + cUp.store(tC); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short j = sgitg; j < nr1; j += 4) { + const int idj = ids_i32[im*args.ne21 + r1 + j]; + const short ide = idj % args.ne20; + const short idt = idj / args.ne20; + device float * D = (device float *) dst_up + r0 + ide*args.ne0 + idt*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + threadgroup float * C = (threadgroup float *) shmem + j*NR0; + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = tiisg; + for (; i < nr0/4; i += 32) *(D4 + i) = *(C4 + i); + i = (4*(nr0/4)) + tiisg; + for (; i < nr0; i += 32) *(D + i) = *(C + i); + } +} +#endif + +typedef decltype(kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>) mul_mm_id; +typedef decltype(kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>) mul_mm_id_n64; +typedef decltype(kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>) mul_mm_id_f16_rhs; +typedef decltype(kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>) mul_mm_id_f16_rhs_n64; + +#ifdef DS4_METAL_HAS_TENSOR +// Faster routed-MoE MPP tensor layout from the first Metal 4 PR. The host keeps +// it inside the active route windows that pass full-model checks. 
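+// Unlike the guarded FC_mul_mm_id_mpp branch in kernel_mul_mm_id above, this
+// variant is tensor-only: it keeps the original single per-thread ids lookup for
+// the RHS and stages both operands directly in the tile layout consumed by
+// matmul2d, with no legacy simdgroup fallback path.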
+template +kernel void kernel_mul_mm_id_mpp_fast_layout( + constant ds4_metal_args_mul_mm_id & args, + device const char * src0, + device const char * src1, + device const char * htpe, + device const char * hids, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + (void)sgitg; + + threadgroup S0 * sa = (threadgroup S0 *)(shmem); + threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); + threadgroup float *sc = (threadgroup float *)shmem; + + constexpr int NR0 = 64; + constexpr int NR1 = 32; + constexpr int NK = 32; + constexpr int NL0 = NK/16; + constexpr int NL1 = NK/8; + + const int im = tgpig.z; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe); + device const int32_t * ids_i32 = (device const int32_t *) (hids); + + const int32_t neh1 = tpe_u32[im]; + + if (r1 >= neh1) { + return; + } + + const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0; + const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1; + + const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; + const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; + + const short il0 = (tiitg % NL0); + short il = il0; + + const int id = ids_i32[im*args.ne21 + r1 + lr1]; + + const short i11 = (id % args.ne20) % args.ne11; + const short i12 = (id / args.ne20); + const short i13 = 0; + + const uint64_t offset0 = im*args.nb02 + i13*args.nb03; + const short offset1 = il0/nl; + + device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; + + const short iy = 8*(tiitg % NL1); + + device const T1 * y = (device const T1 *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*i11 + + args.nb10*iy); + + auto tA = tensor(sa, dextents(NK, NR0)); + auto tB = tensor(sb, dextents(NR1, NK)); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, false, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) { + cT[i] = 0.0f; + } + } + + for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = + loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; + } + } + + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + const short lx = i; + const short ly = (tiitg/NL1)%8; + + *(sb + NK*(8*sy + ly) + 8*sx + lx) = + loop_k + iy + i < args.ne00 ? 
(S1) *((device T1 *) y + i) : 0; + } + } else { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + const short ly = (tiitg/NL1)%8; + + *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = + (S1_2x4)(*((device T1_2x4 *) y)); + } + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2 + nl - 1)/nl : x; + + y += NK; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + mm.run(sB, sA, cT); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto tC = tensor(sc, dextents(NR0, NR1)); + cT.store(tC); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short j = tiitg/32; j < nr1; j += 4) { + const int idj = ids_i32[im*args.ne21 + r1 + j]; + + const short ide = idj % args.ne20; + const short idt = idj / args.ne20; + + device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = (threadgroup float *) shmem + j*NR0; + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = tiisg; + for (; i < nr0/4; i += 32) { + *(D4 + i) = *(C4 + i); + } + + i = (4*(nr0/4)) + tiisg; + for (; i < nr0; i += 32) { + *(D + i) = *(C + i); + } + } +} + +typedef decltype(kernel_mul_mm_id_mpp_fast_layout) mul_mm_id_fast_layout; +typedef decltype(kernel_mul_mm_id_mpp_fast_layout) mul_mm_id_fast_layout_f16_rhs; +typedef decltype(kernel_mul_mm_id_pair_mpp) mul_mm_id_pair_mpp_t; +#endif // Host-visible batched MoE matmul variants for the DS4 quant formats. -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, 
half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q8_0_f32_n64")]] kernel mul_mm_id_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q2_K_f32_n64")]] kernel mul_mm_id_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q4_K_f32_n64")]] kernel mul_mm_id_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_n64")]] kernel mul_mm_id_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q8_0_f16_n64")]] kernel mul_mm_id_f16_rhs_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q2_K_f16_n64")]] kernel mul_mm_id_f16_rhs_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q4_K_f16_n64")]] kernel mul_mm_id_f16_rhs_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n64")]] kernel mul_mm_id_f16_rhs_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>; +#ifdef DS4_METAL_HAS_TENSOR +template [[host_name("kernel_mul_mm_id_q8_0_f32_fast_mpp")]] kernel mul_mm_id_fast_layout kernel_mul_mm_id_mpp_fast_layout; +template [[host_name("kernel_mul_mm_id_q2_K_f32_fast_mpp")]] kernel mul_mm_id_fast_layout kernel_mul_mm_id_mpp_fast_layout; +template [[host_name("kernel_mul_mm_id_q4_K_f32_fast_mpp")]] kernel mul_mm_id_fast_layout kernel_mul_mm_id_mpp_fast_layout; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_fast_mpp")]] kernel mul_mm_id_fast_layout kernel_mul_mm_id_mpp_fast_layout; +template [[host_name("kernel_mul_mm_id_q8_0_f16_fast_mpp")]] kernel mul_mm_id_fast_layout_f16_rhs kernel_mul_mm_id_mpp_fast_layout; +template [[host_name("kernel_mul_mm_id_q2_K_f16_fast_mpp")]] kernel mul_mm_id_fast_layout_f16_rhs kernel_mul_mm_id_mpp_fast_layout; +template [[host_name("kernel_mul_mm_id_q4_K_f16_fast_mpp")]] kernel mul_mm_id_fast_layout_f16_rhs kernel_mul_mm_id_mpp_fast_layout; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_fast_mpp")]] kernel 
mul_mm_id_fast_layout_f16_rhs kernel_mul_mm_id_mpp_fast_layout; + +template [[host_name("kernel_mul_mm_id_q8_0_f32_pair_mpp")]] kernel mul_mm_id_pair_mpp_t kernel_mul_mm_id_pair_mpp; +template [[host_name("kernel_mul_mm_id_q2_K_f32_pair_mpp")]] kernel mul_mm_id_pair_mpp_t kernel_mul_mm_id_pair_mpp; +template [[host_name("kernel_mul_mm_id_q4_K_f32_pair_mpp")]] kernel mul_mm_id_pair_mpp_t kernel_mul_mm_id_pair_mpp; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_pair_mpp")]] kernel mul_mm_id_pair_mpp_t kernel_mul_mm_id_pair_mpp; +#endif + +#ifdef DS4_METAL_HAS_TENSOR +template +kernel void kernel_attn_out_low_q8_0_mpp( + constant ds4_metal_args_mul_mm_id & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + constexpr int NR0 = 64; + constexpr int NK = 32; + constexpr int NL = NK/16; + constexpr int NUM_THREADS = 128; + + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne21; + const int G = args.ne1; + const int group = tgpig.z; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + const bool full_tile = r0 + NR0 <= M && r1 + NR1 <= N && (K % NK) == 0; + + threadgroup half *sa = (threadgroup half *)shmem; + threadgroup half *sb = sa + NR0*NK; + auto tA = tensor(sa, dextents(NK, NR0)); + auto tB = tensor(sb, dextents(NK, NR1)); + + device const float *ptrB = (device const float *)(srcB + args.nb11*group); + const int strideB = args.nb12/sizeof(float); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, false, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) { + cT[i] = 0.0f; + } + } + + for (int loop_k = 0; loop_k < K; loop_k += NK) { + for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { + const int row = work/NL; + const int k_chunk = work%NL; + const int k_pos = loop_k + k_chunk*16; + const short k_base = k_chunk*16; + + if (full_tile || r0 + row < M) { + const int block_idx = k_pos/32; + const short il = (k_pos/16)%2; + device const block_q8_0 *row_ptr = + (device const block_q8_0 *)(srcA + args.nb01*(r0 + row) + group*args.nb02); + + half4x4 temp_a; + dequantize_q8_0(row_ptr + block_idx, il, temp_a); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (full_tile || k_pos + i < K) ? 
temp_a[i/4][i%4] : (half)0; + } + } else { + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (half)0; + } + } + } + for (int work = tiitg; work < NK*NR1; work += NUM_THREADS) { + const int col = work/NK; + const int k = work%NK; + if (full_tile || (r1 + col < N && loop_k + k < K)) { + sb[col*NK + k] = (half)ptrB[(uint64_t)(r1 + col)*strideB + loop_k + k]; + } else { + sb[col*NK + k] = (half)0; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto mA = tA.slice(0, 0); + auto mB = tB.slice(0, 0); + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float *dst_group = (device float *)dst + group*M; + if (full_tile) { + device float *dst_tile = dst_group + r0 + (uint64_t)r1*G*M; + auto tD = tensor(dst_tile, dextents(NR0, NR1), array({1, G*M})); + cT.store(tD); + } else { + auto tD = tensor(dst_group, dextents(M, N), array({1, G*M})); + auto mD = tD.slice(r0, r1); + cT.store(mD); + } +} + +typedef decltype(kernel_attn_out_low_q8_0_mpp<32>) attn_out_low_q8_0_mpp_t; + +template [[host_name("kernel_attn_out_low_q8_0_mpp")]] kernel attn_out_low_q8_0_mpp_t kernel_attn_out_low_q8_0_mpp<32>; +template [[host_name("kernel_attn_out_low_q8_0_mpp_n64")]] kernel attn_out_low_q8_0_mpp_t kernel_attn_out_low_q8_0_mpp<64>; + +template +kernel void kernel_attn_out_low_q8_0_mpp_direct_rhs( + constant ds4_metal_args_mul_mm_id & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + constexpr int NR0 = 64; + constexpr int NK = 32; + constexpr int NL = NK/16; + constexpr int NUM_THREADS = 128; + + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne21; + const int G = args.ne1; + const int group = tgpig.z; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + const bool full_tile = r0 + NR0 <= M && r1 + NR1 <= N && (K % NK) == 0; + + threadgroup half *sa = (threadgroup half *)shmem; + auto tA = tensor(sa, dextents(NK, NR0)); + + device float *ptrB = (device float *)(srcB + args.nb11*group); + const int strideB = args.nb12/sizeof(float); + auto tB = tensor(ptrB, dextents(K, N), array({1, strideB})); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, true, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) { + cT[i] = 0.0f; + } + } + + for (int loop_k = 0; loop_k < K; loop_k += NK) { + for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { + const int row = work/NL; + const int k_chunk = work%NL; + const int k_pos = loop_k + k_chunk*16; + const short k_base = k_chunk*16; + + if (full_tile || r0 + row < M) { + const int block_idx = k_pos/32; + const short il = (k_pos/16)%2; + device const block_q8_0 *row_ptr = + (device const block_q8_0 *)(srcA + args.nb01*(r0 + row) + group*args.nb02); + + half4x4 temp_a; + dequantize_q8_0(row_ptr + block_idx, il, temp_a); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (full_tile || k_pos + i < K) ? 
temp_a[i/4][i%4] : (half)0; + } + } else { + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (half)0; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto mA = tA.slice(0, 0); + auto mB = tB.slice(loop_k, r1); + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float *dst_group = (device float *)dst + group*M; + if (full_tile) { + device float *dst_tile = dst_group + r0 + (uint64_t)r1*G*M; + auto tD = tensor(dst_tile, dextents(NR0, NR1), array({1, G*M})); + cT.store(tD); + } else { + auto tD = tensor(dst_group, dextents(M, N), array({1, G*M})); + auto mD = tD.slice(r0, r1); + cT.store(mD); + } +} + +typedef decltype(kernel_attn_out_low_q8_0_mpp_direct_rhs<32>) attn_out_low_q8_0_mpp_direct_rhs_t; +typedef decltype(kernel_attn_out_low_q8_0_mpp_direct_rhs<64>) attn_out_low_q8_0_mpp_direct_rhs_n64_t; + +template [[host_name("kernel_attn_out_low_q8_0_mpp_direct_rhs")]] kernel attn_out_low_q8_0_mpp_direct_rhs_t kernel_attn_out_low_q8_0_mpp_direct_rhs<32>; +template [[host_name("kernel_attn_out_low_q8_0_mpp_direct_rhs_n64")]] kernel attn_out_low_q8_0_mpp_direct_rhs_n64_t kernel_attn_out_low_q8_0_mpp_direct_rhs<64>; + +#endif #undef QK_NL #undef kmask_iq2xs diff --git a/metal/norm.metal b/metal/norm.metal index 5bc97179..89206704 100644 --- a/metal/norm.metal +++ b/metal/norm.metal @@ -145,7 +145,14 @@ kernel void kernel_dsv4_qkv_rms_norm_f32_4( sumf = shmem_f32[tiisg]; sumf = simd_sum(sumf); +#ifdef DS4_METAL_NORM_RSQRT_DISABLE + // Match the formula used by kernel_rms_norm_fuse_impl above so both RMSNorm + // entry points produce bit-identical scales. Hardware rsqrt() and 1.0f/sqrt() + // can differ by ~1 ULP and that difference compounds across 43 layers. + const float scale = 1.0f / sqrt(sumf / float(n) + args.eps); +#else const float scale = rsqrt(sumf / float(n) + args.eps); +#endif for (int i = tpitg.x; i < n4; i += ntg.x) { y[i] = (x[i] * scale) * w[i]; diff --git a/speed-bench/compare_bench.py b/speed-bench/compare_bench.py new file mode 100755 index 00000000..034ab193 --- /dev/null +++ b/speed-bench/compare_bench.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +"""Plot two or more ds4-bench CSV runs as a speed comparison chart.""" + +from __future__ import annotations + +import argparse +import csv +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + + +REQUIRED_COLUMNS = { + "ctx_tokens", + "prefill_tps", + "gen_tps", +} + + +def read_run(path: Path) -> dict[int, dict[str, float]]: + with path.open(newline="") as fp: + reader = csv.DictReader(fp) + if reader.fieldnames is None: + raise SystemExit(f"{path}: empty CSV") + missing = REQUIRED_COLUMNS - set(reader.fieldnames) + if missing: + raise SystemExit(f"{path}: missing columns: {', '.join(sorted(missing))}") + + rows: dict[int, dict[str, float]] = {} + for row in reader: + ctx = int(row["ctx_tokens"]) + rows[ctx] = { + "prefill_tps": float(row["prefill_tps"]), + "gen_tps": float(row["gen_tps"]), + } + if not rows: + raise SystemExit(f"{path}: no data rows") + return rows + + +def context_label(ctx: int) -> str: + if ctx < 1024: + return f"{ctx / 1024:g}k" + rounded_k = round(ctx / 1024) + if abs(ctx - rounded_k * 1024) <= max(4, ctx * 0.001): + return f"{rounded_k}k" + return f"{ctx / 1024:.1f}k" + + +def annotate_points(ax, xs: list[int], ys: list[float], color: str, dy: float) -> None: + for x, y in zip(xs, ys): + ax.annotate( + f"{y:.1f}", + (x, y), + textcoords="offset points", + xytext=(0, dy), + 
ha="center", + va="bottom" if dy >= 0 else "top", + fontsize=8, + color=color, + fontweight="medium", + ) + + +def plot_metric( + ax, + xs: list[int], + labels: list[str], + series: list[list[float]], + metric_title: str, + run_labels: list[str], + annotate: bool, +) -> None: + colors = ["#2563eb", "#64748b", "#ea580c", "#16a34a", "#9333ea", "#dc2626"] + markers = ["o", "s", "^", "D", "P", "X"] + + for i, (values, label) in enumerate(zip(series, run_labels)): + color = colors[i % len(colors)] + ax.plot( + xs, + values, + marker=markers[i % len(markers)], + markersize=7, + linewidth=2.4, + color=color, + label=label, + ) + + if len(series) == 2: + ax.fill_between(xs, series[0], series[1], color=colors[1], alpha=0.08) + + ax.set_title(metric_title, fontsize=15, fontweight="bold", pad=12) + ax.set_xlabel("Context Size") + ax.set_ylabel("Tokens/sec") + ax.set_xticks(xs, labels) + ax.grid(True, color="#d1d5db", linewidth=0.9, alpha=0.65) + ax.set_axisbelow(True) + ax.margins(x=0.05, y=0.18) + + for spine in ("top", "right"): + ax.spines[spine].set_visible(False) + ax.spines["left"].set_color("#9ca3af") + ax.spines["bottom"].set_color("#9ca3af") + + if len(series) == 2: + gain_color = "#14532d" + ymin, ymax = ax.get_ylim() + label_y = ymin + (ymax - ymin) * 0.05 + for x, b, a in zip(xs, series[0], series[1]): + gain = ((a / b) - 1.0) * 100.0 if b else 0.0 + ax.annotate( + f"{gain:+.0f}%", + (x, label_y), + ha="center", + va="center", + fontsize=8, + color=gain_color if gain >= 0 else "#991b1b", + bbox={ + "boxstyle": "round,pad=0.24", + "facecolor": "#ecfdf5" if gain >= 0 else "#fef2f2", + "edgecolor": "#bbf7d0" if gain >= 0 else "#fecaca", + "linewidth": 0.8, + }, + ) + + if annotate: + offsets = [-16, 8, 22, 36, 50, 64] + for i, values in enumerate(series): + annotate_points(ax, xs, values, colors[i % len(colors)], offsets[i % len(offsets)]) + + +def default_run_labels(paths: list[Path], args: argparse.Namespace) -> list[str]: + if len(paths) == 2 and not args.labels: + return [args.before_label, args.after_label] + if args.labels: + if len(args.labels) != len(paths): + raise SystemExit("--labels count must match the number of CSV runs") + return args.labels + return [path.stem for path in paths] + + +def build_chart(args: argparse.Namespace) -> None: + if len(args.runs) < 2: + raise SystemExit("provide at least two ds4-bench CSV files") + runs = [read_run(path) for path in args.runs] + run_labels = default_run_labels(args.runs, args) + contexts = sorted(set.intersection(*(set(run) for run in runs))) + if not contexts: + raise SystemExit("the CSV files have no shared ctx_tokens values") + + x_positions = list(range(len(contexts))) + labels = [context_label(ctx) for ctx in contexts] + prefill_series = [[run[ctx]["prefill_tps"] for ctx in contexts] for run in runs] + gen_series = [[run[ctx]["gen_tps"] for ctx in contexts] for run in runs] + + plt.rcParams.update( + { + "figure.facecolor": "#f8fafc", + "axes.facecolor": "#ffffff", + "axes.edgecolor": "#cbd5e1", + "axes.labelcolor": "#111827", + "xtick.color": "#111827", + "ytick.color": "#111827", + "font.family": "DejaVu Sans", + } + ) + + fig, axes = plt.subplots(1, 2, figsize=(15.5, 7), constrained_layout=True) + fig.suptitle(args.title, fontsize=22, fontweight="bold", y=1.04) + + plot_metric( + axes[0], + x_positions, + labels, + prefill_series, + "Prompt Processing Speed", + run_labels, + not args.no_values, + ) + plot_metric( + axes[1], + x_positions, + labels, + gen_series, + "Text Generation Speed", + run_labels, + not args.no_values, + ) 
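+    # The two panels share a single figure-level legend, assembled below from
+    # the handles of the left axes.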
+ + handles, legend_labels = axes[0].get_legend_handles_labels() + fig.legend( + handles, + legend_labels, + loc="upper center", + bbox_to_anchor=(0.5, 0.98), + ncol=min(len(run_labels), 4), + frameon=True, + fancybox=True, + shadow=False, + facecolor="#ffffff", + edgecolor="#cbd5e1", + ) + + output = args.output + if output.suffix.lower() != ".png": + raise SystemExit(f"{output}: output must be a .png file") + output.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output, dpi=180, bbox_inches="tight", format="png") + plt.close(fig) + + print(f"Wrote {output}") + header = ["ctx"] + for label in run_labels: + safe = label.lower().replace(" ", "_") + header.extend([f"prefill_{safe}", f"gen_{safe}"]) + for label in run_labels[1:]: + safe = label.lower().replace(" ", "_") + base = run_labels[0].lower().replace(" ", "_") + header.extend([f"prefill_gain_{safe}_vs_{base}", f"gen_gain_{safe}_vs_{base}"]) + print(",".join(header)) + for idx, ctx in enumerate(contexts): + row = [str(ctx)] + base_prefill = prefill_series[0][idx] + base_gen = gen_series[0][idx] + for prefill, gen in zip(prefill_series, gen_series): + row.extend([f"{prefill[idx]:.2f}", f"{gen[idx]:.2f}"]) + for prefill, gen in zip(prefill_series[1:], gen_series[1:]): + prefill_gain = ((prefill[idx] / base_prefill) - 1.0) * 100.0 if base_prefill else 0.0 + gen_gain = ((gen[idx] / base_gen) - 1.0) * 100.0 if base_gen else 0.0 + row.extend([f"{prefill_gain:.1f}", f"{gen_gain:.1f}"]) + print(",".join(row)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Create a two-panel comparison chart from ds4-bench CSV files." + ) + parser.add_argument("runs", nargs="+", type=Path, help="ds4-bench CSV files; first is the baseline") + parser.add_argument( + "-o", + "--output", + type=Path, + default=Path("/tmp/ds4-bench-compare.png"), + help="output chart path; must end in .png", + ) + parser.add_argument("--before-label", default="standard kernel") + parser.add_argument("--after-label", default="Metal Tensor") + parser.add_argument("--labels", nargs="+", help="Labels for each CSV run.") + parser.add_argument("--title", default="ds4-bench Speed Comparison") + parser.add_argument("--no-values", action="store_true", help="hide per-point value labels") + return parser.parse_args() + + +if __name__ == "__main__": + build_chart(parse_args()) diff --git a/speed-bench/compare_logit_drift.py b/speed-bench/compare_logit_drift.py new file mode 100644 index 00000000..140d68ee --- /dev/null +++ b/speed-bench/compare_logit_drift.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +"""Compare full-logit dumps produced by ./ds4 --dump-logits. 
+ +Example: + ./ds4 -m q2.gguf --metal -mt off --dump-logits /tmp/q2-off.json \ + --nothink --prompt-file prompt.txt + ./ds4 -m q2.gguf --metal -mt auto --dump-logits /tmp/q2-mt.json \ + --nothink --prompt-file prompt.txt + ./ds4 -m q4.gguf --metal -mt off --dump-logits /tmp/q4-off.json \ + --nothink --prompt-file prompt.txt + python3 speed-bench/compare_logit_drift.py /tmp/q2-off.json \ + /tmp/q2-mt.json /tmp/q4-off.json --labels q2_mt q4_off +""" + +from __future__ import annotations + +import argparse +import json +import math +from heapq import nlargest +from pathlib import Path +from typing import Any + + +def load_dump(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as fp: + data = json.load(fp) + logits_raw = data.get("logits") + if not isinstance(logits_raw, list) or not logits_raw: + raise SystemExit(f"{path}: missing non-empty logits array") + logits = [float("nan") if v is None else float(v) for v in logits_raw] + vocab = int(data.get("vocab", len(logits))) + if vocab != len(logits): + raise SystemExit(f"{path}: vocab={vocab} does not match logits={len(logits)}") + data["logits"] = logits + data["_path"] = str(path) + return data + + +def dump_label(data: dict[str, Any]) -> str: + model = Path(str(data.get("model", data.get("_path", "dump")))).name + quant = data.get("quant_bits", "?") + mt = data.get("mt", "?") + return f"{model}:q{quant}:mt={mt}" + + +def finite_indices(logits: list[float]) -> list[int]: + return [i for i, v in enumerate(logits) if math.isfinite(v)] + + +def topk(logits: list[float], k: int) -> list[int]: + # Match the C test's tie behavior: higher logit first, lower token id first. + return nlargest(k, finite_indices(logits), key=lambda i: (logits[i], -i)) + + +def overlap(a: list[int], b: list[int], k: int) -> int: + return len(set(a[:k]) & set(b[:k])) + + +def rank_delta(ref_top: list[int], cand_top: list[int]) -> int: + cand_rank = {token: i for i, token in enumerate(cand_top)} + worst = 0 + for i, token in enumerate(ref_top): + if token in cand_rank: + worst = max(worst, abs(cand_rank[token] - i)) + return worst + + +def top_union_max_abs( + ref: list[float], + cand: list[float], + ref_top: list[int], + cand_top: list[int], + k: int, +) -> float: + ids = set(ref_top[:k]) | set(cand_top[:k]) + worst = 0.0 + for token in ids: + if math.isfinite(ref[token]) and math.isfinite(cand[token]): + worst = max(worst, abs(cand[token] - ref[token])) + return worst + + +def compare(ref_dump: dict[str, Any], cand_dump: dict[str, Any], top_k: int) -> dict[str, Any]: + ref = ref_dump["logits"] + cand = cand_dump["logits"] + if len(ref) != len(cand): + raise SystemExit( + f"vocab mismatch: {ref_dump['_path']} has {len(ref)}, " + f"{cand_dump['_path']} has {len(cand)}" + ) + + ref_top = topk(ref, top_k) + cand_top = topk(cand, top_k) + sumsq = 0.0 + max_abs = 0.0 + nonfinite = 0 + largest: list[tuple[float, int, float, float]] = [] + for token, (rv, cv) in enumerate(zip(ref, cand)): + if not math.isfinite(rv) or not math.isfinite(cv): + nonfinite += 1 + continue + delta = cv - rv + abs_delta = abs(delta) + sumsq += delta * delta + max_abs = max(max_abs, abs_delta) + if len(largest) < 5: + largest.append((abs_delta, token, rv, cv)) + largest.sort(reverse=True) + elif abs_delta > largest[-1][0]: + largest[-1] = (abs_delta, token, rv, cv) + largest.sort(reverse=True) + + return { + "same_top1": bool(ref_top and cand_top and ref_top[0] == cand_top[0]), + "ref_top1": ref_top[0] if ref_top else None, + "cand_top1": cand_top[0] if cand_top else None, + 
"top5_overlap": overlap(ref_top, cand_top, min(5, top_k)), + "top20_overlap": overlap(ref_top, cand_top, min(20, top_k)), + "top_k": top_k, + "max_rank_delta": rank_delta(ref_top, cand_top), + "rms": math.sqrt(sumsq / len(ref)), + "max_abs": max_abs, + "top20_max_abs": top_union_max_abs(ref, cand, ref_top, cand_top, min(20, top_k)), + "nonfinite": nonfinite, + "largest_deltas": [ + {"token": token, "ref": rv, "cand": cv, "abs": abs_delta} + for abs_delta, token, rv, cv in largest + ], + } + + +def print_table(rows: list[dict[str, Any]]) -> None: + headers = [ + "candidate", + "same_top1", + "top5", + "top20", + "rank", + "rms", + "max_abs", + "top20_abs", + "nonfinite", + ] + print(" | ".join(headers)) + print(" | ".join("-" * len(h) for h in headers)) + for row in rows: + print( + " | ".join( + [ + row["label"], + "yes" if row["same_top1"] else "no", + f"{row['top5_overlap']}/5", + f"{row['top20_overlap']}/20", + str(row["max_rank_delta"]), + f"{row['rms']:.6g}", + f"{row['max_abs']:.6g}", + f"{row['top20_max_abs']:.6g}", + str(row["nonfinite"]), + ] + ) + ) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Compare ds4 full-logit JSON dumps from --dump-logits." + ) + parser.add_argument("reference", type=Path) + parser.add_argument("candidates", nargs="+", type=Path) + parser.add_argument("--labels", nargs="+", help="Labels for candidate dumps.") + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--json-output", type=Path) + args = parser.parse_args() + + if args.top_k < 20: + raise SystemExit("--top-k must be at least 20") + if args.labels and len(args.labels) != len(args.candidates): + raise SystemExit("--labels count must match candidate count") + + ref = load_dump(args.reference) + candidates = [load_dump(path) for path in args.candidates] + labels = args.labels or [dump_label(data) for data in candidates] + + print(f"reference: {dump_label(ref)}") + print( + "prompt_tokens: " + f"{ref.get('prompt_tokens', '?')} ctx: {ref.get('ctx', '?')} " + f"vocab: {ref.get('vocab', len(ref['logits']))}" + ) + rows = [] + for label, candidate in zip(labels, candidates): + if candidate.get("prompt_tokens") != ref.get("prompt_tokens"): + print( + f"warning: prompt token mismatch for {label}: " + f"ref={ref.get('prompt_tokens')} cand={candidate.get('prompt_tokens')}" + ) + metrics = compare(ref, candidate, args.top_k) + metrics["label"] = label + metrics["path"] = candidate["_path"] + rows.append(metrics) + + print_table(rows) + for row in rows: + print(f"\n{row['label']} largest deltas:") + for delta in row["largest_deltas"]: + print( + " token={token} ref={ref:.9g} cand={cand:.9g} abs={abs:.9g}".format( + **delta + ) + ) + + if args.json_output: + payload = { + "reference": {"path": ref["_path"], "label": dump_label(ref)}, + "rows": rows, + } + with args.json_output.open("w", encoding="utf-8") as fp: + json.dump(payload, fp, indent=2) + fp.write("\n") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/speed-bench/run_metal_tensor_bench.sh b/speed-bench/run_metal_tensor_bench.sh new file mode 100755 index 00000000..2541178f --- /dev/null +++ b/speed-bench/run_metal_tensor_bench.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd "$(dirname "${BASH_SOURCE[0]}")/.." 
+ +PROMPT_FILE="${PROMPT_FILE:-speed-bench/promessi_sposi.txt}" +CTX_START="${CTX_START:-512}" +CTX_MAX="${CTX_MAX:-8192}" +STEP_MUL="${STEP_MUL:-2}" +GEN_TOKENS="${GEN_TOKENS:-128}" +OUT_DIR="${OUT_DIR:-/tmp}" +PYTHON="${PYTHON:-python3}" +OPEN_CHART="${OPEN_CHART:-1}" + +mkdir -p "$OUT_DIR" + +QUALITY_CSV="$OUT_DIR/ds4_bench_quality_${GEN_TOKENS}.csv" +STANDARD_CSV="$OUT_DIR/ds4_bench_standard_metal_${GEN_TOKENS}.csv" +TENSOR_CSV="$OUT_DIR/ds4_bench_tensor_metal_${GEN_TOKENS}.csv" +CHART="$OUT_DIR/ds4_bench_standard_quality_tensor_${GEN_TOKENS}.png" + +COMMON_ARGS=( + --prompt-file "$PROMPT_FILE" + --ctx-start "$CTX_START" + --ctx-max "$CTX_MAX" + --step-mul "$STEP_MUL" + --gen-tokens "$GEN_TOKENS" +) + +echo "1/3 Quality Metal -> $QUALITY_CSV" +./ds4-bench --quality "${COMMON_ARGS[@]}" --csv "$QUALITY_CSV" + +echo "2/3 Standard Metal -> $STANDARD_CSV" +DS4_METAL_MPP_DISABLE=1 ./ds4-bench "${COMMON_ARGS[@]}" --csv "$STANDARD_CSV" + +echo "3/3 Tensor Metal -> $TENSOR_CSV" +./ds4-bench "${COMMON_ARGS[@]}" --csv "$TENSOR_CSV" + +echo "Comparing runs -> $CHART" +"$PYTHON" speed-bench/compare_bench.py \ + "$STANDARD_CSV" \ + "$QUALITY_CSV" \ + "$TENSOR_CSV" \ + --labels "Standard Metal" "Quality Metal" "Tensor Metal" \ + --title "ds4-bench: Standard vs Quality vs Tensor (${GEN_TOKENS} generated tokens)" \ + -o "$CHART" + +echo +echo "Wrote:" +echo " $QUALITY_CSV" +echo " $STANDARD_CSV" +echo " $TENSOR_CSV" +echo " $CHART" + +if [[ "$OPEN_CHART" != "0" ]]; then + if command -v open >/dev/null 2>&1; then + open "$CHART" + elif command -v xdg-open >/dev/null 2>&1; then + xdg-open "$CHART" >/dev/null 2>&1 & + else + echo "No opener found; set OPEN_CHART=0 to skip this step." + fi +fi diff --git a/tests/ds4_test.c b/tests/ds4_test.c index 959367c2..23b90563 100644 --- a/tests/ds4_test.c +++ b/tests/ds4_test.c @@ -150,6 +150,145 @@ static void test_metal_f16_matvec_fast_nr0_4(void) { free(weights_raw); } +static void test_metal_q8_0_mpp_matmul_case(const char *label, + uint32_t in_dim, + uint32_t out_dim, + uint32_t n_tok) { + const uint64_t blocks = in_dim / 32; + const uint64_t row_bytes = blocks * 34; + const uint64_t weight_bytes = (uint64_t)out_dim * row_bytes; + const uint64_t weight_alloc = test_round_up_u64(weight_bytes, (uint64_t)getpagesize()); + + void *weights_raw = NULL; + TEST_ASSERT(posix_memalign(&weights_raw, (size_t)getpagesize(), (size_t)weight_alloc) == 0); + if (!weights_raw) return; + + uint8_t *weights = weights_raw; + memset(weights, 0, (size_t)weight_alloc); + for (uint32_t o = 0; o < out_dim; o++) { + for (uint32_t b = 0; b < blocks; b++) { + uint8_t *block = weights + (uint64_t)o * row_bytes + (uint64_t)b * 34u; + uint16_t d = test_float_to_f16((float)((o + b) % 5u + 1u) / 128.0f); + memcpy(block, &d, sizeof(d)); + int8_t *qs = (int8_t *)(block + 2); + for (uint32_t i = 0; i < 32; i++) { + qs[i] = (int8_t)((int)((o * 5u + b * 7u + i * 3u) % 63u) - 31); + } + } + } + + const uint64_t x_bytes = (uint64_t)n_tok * in_dim * sizeof(float); + const uint64_t out_bytes = (uint64_t)n_tok * out_dim * sizeof(float); + ds4_gpu_tensor *x = ds4_gpu_tensor_alloc(x_bytes); + ds4_gpu_tensor *out_ref = ds4_gpu_tensor_alloc(out_bytes); + ds4_gpu_tensor *out_mpp = ds4_gpu_tensor_alloc(out_bytes); + TEST_ASSERT(x != NULL); + TEST_ASSERT(out_ref != NULL); + TEST_ASSERT(out_mpp != NULL); + if (!x || !out_ref || !out_mpp) { + ds4_gpu_tensor_free(x); + ds4_gpu_tensor_free(out_ref); + ds4_gpu_tensor_free(out_mpp); + free(weights_raw); + return; + } + + float *x_host = malloc((size_t)x_bytes); + 
float *ref_host = malloc((size_t)out_bytes); + float *mpp_host = malloc((size_t)out_bytes); + TEST_ASSERT(x_host != NULL); + TEST_ASSERT(ref_host != NULL); + TEST_ASSERT(mpp_host != NULL); + if (!x_host || !ref_host || !mpp_host) { + free(x_host); + free(ref_host); + free(mpp_host); + ds4_gpu_tensor_free(x); + ds4_gpu_tensor_free(out_ref); + ds4_gpu_tensor_free(out_mpp); + free(weights_raw); + return; + } + + for (uint32_t t = 0; t < n_tok; t++) { + for (uint32_t i = 0; i < in_dim; i++) { + x_host[(uint64_t)t * in_dim + i] = + (float)((int)((t * 19u + i * 23u) % 53u) - 26) / 80.0f; + } + } + + TEST_ASSERT(ds4_gpu_tensor_write(x, 0, x_host, x_bytes) != 0); + TEST_ASSERT(ds4_gpu_set_model_map(weights_raw, weight_alloc) != 0); + // Force quality mode ON so the reference dispatcher takes the legacy + // simdgroup path; otherwise ds4_gpu_matmul_q8_0_tensor() routes to the + // MPP variant on M5+ and the test compares two MPP outputs to each other. + ds4_gpu_set_quality(true); + TEST_ASSERT(ds4_gpu_matmul_q8_0_tensor(out_ref, weights_raw, weight_alloc, 0, + in_dim, out_dim, x, n_tok) != 0); + ds4_gpu_set_quality(false); + + int have_mpp = ds4_gpu_matmul_q8_0_mpp_tensor( + out_mpp, weights_raw, weight_alloc, 0, in_dim, out_dim, x, n_tok); + if (!have_mpp) { + fprintf(stderr, "ds4-test: skipping Tensor Q8_0 matmul %s; Metal 4 tensor API unavailable\n", + label); + free(x_host); + free(ref_host); + free(mpp_host); + ds4_gpu_tensor_free(x); + ds4_gpu_tensor_free(out_ref); + ds4_gpu_tensor_free(out_mpp); + free(weights_raw); + return; + } + + TEST_ASSERT(ds4_gpu_tensor_read(out_ref, 0, ref_host, out_bytes) != 0); + TEST_ASSERT(ds4_gpu_tensor_read(out_mpp, 0, mpp_host, out_bytes) != 0); + + float max_abs = 0.0f; + double sumsq = 0.0; + uint64_t max_index = 0; + for (uint64_t i = 0; i < (uint64_t)n_tok * out_dim; i++) { + const float err = fabsf(mpp_host[i] - ref_host[i]); + sumsq += (double)err * (double)err; + if (err > max_abs) { + max_abs = err; + max_index = i; + } + } + const float rms = (float)sqrt(sumsq / (double)((uint64_t)n_tok * out_dim)); + if (max_abs >= 0.10f) { + fprintf(stderr, + "ds4-test: Tensor Q8_0 matmul %s in=%u out=%u tok=%u max_abs=%f rms=%f at token=%llu out=%llu ref=%f tensor=%f\n", + label, in_dim, out_dim, n_tok, max_abs, rms, + (unsigned long long)(max_index / out_dim), + (unsigned long long)(max_index % out_dim), + ref_host[max_index], + mpp_host[max_index]); + } + TEST_ASSERT(max_abs < 0.10f); + + free(x_host); + free(ref_host); + free(mpp_host); + ds4_gpu_tensor_free(x); + ds4_gpu_tensor_free(out_ref); + ds4_gpu_tensor_free(out_mpp); + free(weights_raw); +} + +static void test_metal_q8_0_mpp_matmul(void) { + test_metal_q8_0_mpp_matmul_case("small_partial48", 128, 96, 48); + test_metal_q8_0_mpp_matmul_case("medium_partial48", 512, 256, 48); + test_metal_q8_0_mpp_matmul_case("modelish_full32", 4096, 256, 32); + test_metal_q8_0_mpp_matmul_case("modelish_partial48", 4096, 256, 48); +} + +static void test_metal_kernel_group(void) { + test_metal_f16_matvec_fast_nr0_4(); + test_metal_q8_0_mpp_matmul(); +} + static char *test_read_file(const char *path) { FILE *fp = fopen(path, "rb"); if (!fp) return NULL; @@ -546,6 +685,563 @@ static void test_official_logprob_vectors(void) { fclose(fp); } +#define TEST_MPP_EQ_MAX_CASES 8 +#define TEST_MPP_EQ_TOPK 20 +#define TEST_MPP_EQ_TOP5 5 +#define TEST_MPP_EQ_DELTAS 5 + +typedef struct { + char id[96]; + int ctx; + int vocab_size; + int gen_steps; + ds4_tokens prompt; + float *ref_logits; + int ref_gen[TEST_VEC_MAX_STEPS]; + int 
ref_gen_len;
+} test_mpp_eq_case;
+
+typedef struct {
+    int ref_top1;
+    int cand_top1;
+    int overlap;
+    int top5_overlap;
+    int max_rank_delta;
+    int nonfinite;
+    float rms;
+    float max_abs;
+    float top20_max_abs;
+    bool same_top1;
+    bool pass;
+} test_mpp_eq_result;
+
+typedef struct {
+    const char *label;
+    int cases;
+    int capture_failures;
+    int logits_failures;
+    int greedy_failures;
+    int top1_mismatches;
+    int min_overlap;
+    int min_top5_overlap;
+    int worst_rank_delta;
+    float worst_rms;
+    float worst_max_abs;
+    float worst_top20_max_abs;
+} test_mpp_eq_summary;
+
+static void test_mpp_eq_case_free(test_mpp_eq_case *tc) {
+    if (!tc) return;
+    ds4_tokens_free(&tc->prompt);
+    free(tc->ref_logits);
+    memset(tc, 0, sizeof(*tc));
+}
+
+static void test_logits_topk(const float *logits, int n, int *out, int k) {
+    for (int i = 0; i < k; i++) out[i] = -1;
+    for (int id = 0; id < n; id++) {
+        const float v = logits[id];
+        if (!isfinite(v)) continue;
+        for (int j = 0; j < k; j++) {
+            if (out[j] < 0 || v > logits[out[j]]) {
+                for (int l = k - 1; l > j; l--) out[l] = out[l - 1];
+                out[j] = id;
+                break;
+            }
+        }
+    }
+}
+
+static bool test_topk_contains(const int *top, int k, int id) {
+    for (int i = 0; i < k; i++) {
+        if (top[i] == id) return true;
+    }
+    return false;
+}
+
+static int test_topk_rank(const int *top, int k, int id) {
+    for (int i = 0; i < k; i++) {
+        if (top[i] == id) return i;
+    }
+    return -1;
+}
+
+static void test_note_delta(int *ids, float *ref_vals, float *cand_vals,
+                            float *abs_vals, int id, float ref, float cand) {
+    const float abs_delta = fabsf(cand - ref);
+    for (int i = 0; i < TEST_MPP_EQ_DELTAS; i++) {
+        if (ids[i] < 0 || abs_delta > abs_vals[i]) {
+            for (int j = TEST_MPP_EQ_DELTAS - 1; j > i; j--) {
+                ids[j] = ids[j - 1];
+                ref_vals[j] = ref_vals[j - 1];
+                cand_vals[j] = cand_vals[j - 1];
+                abs_vals[j] = abs_vals[j - 1];
+            }
+            ids[i] = id;
+            ref_vals[i] = ref;
+            cand_vals[i] = cand;
+            abs_vals[i] = abs_delta;
+            return;
+        }
+    }
+}
+
+static float test_top_union_max_abs(const float *ref, const float *cand,
+                                    const int *ref_top, const int *cand_top, int k) {
+    float max_abs = 0.0f;
+    for (int i = 0; i < k; i++) {
+        if (ref_top[i] >= 0) {
+            const float d = fabsf(cand[ref_top[i]] - ref[ref_top[i]]);
+            if (d > max_abs) max_abs = d;
+        }
+        if (cand_top[i] >= 0 && !test_topk_contains(ref_top, k, cand_top[i])) {
+            const float d = fabsf(cand[cand_top[i]] - ref[cand_top[i]]);
+            if (d > max_abs) max_abs = d;
+        }
+    }
+    return max_abs;
+}
+
+static test_mpp_eq_result test_compare_mpp_logits(const test_mpp_eq_case *tc,
+                                                  const float *cand_logits,
+                                                  bool assert_thresholds) {
+    int ref_top[TEST_MPP_EQ_TOPK];
+    int cand_top[TEST_MPP_EQ_TOPK];
+    test_logits_topk(tc->ref_logits, tc->vocab_size, ref_top, TEST_MPP_EQ_TOPK);
+    test_logits_topk(cand_logits, tc->vocab_size, cand_top, TEST_MPP_EQ_TOPK);
+
+    int overlap = 0;
+    int top5_overlap = 0;
+    int max_rank_delta = 0;
+    for (int i = 0; i < TEST_MPP_EQ_TOPK; i++) {
+        const int cand_rank = test_topk_rank(cand_top, TEST_MPP_EQ_TOPK, ref_top[i]);
+        if (ref_top[i] >= 0 && cand_rank >= 0) {
+            overlap++;
+            const int rank_delta = abs(cand_rank - i);
+            if (rank_delta > max_rank_delta) max_rank_delta = rank_delta;
+        }
+        if (i < TEST_MPP_EQ_TOP5 &&
+            ref_top[i] >= 0 &&
+            test_topk_contains(cand_top, TEST_MPP_EQ_TOP5, ref_top[i])) {
+            top5_overlap++;
+        }
+    }
+
+    double sumsq = 0.0;
+    float max_abs = 0.0f;
+    int nonfinite = 0;
+    int delta_ids[TEST_MPP_EQ_DELTAS];
+    float delta_ref[TEST_MPP_EQ_DELTAS];
+    float delta_cand[TEST_MPP_EQ_DELTAS];
+    float delta_abs[TEST_MPP_EQ_DELTAS];
+    for (int i = 0; i < TEST_MPP_EQ_DELTAS; i++) {
+        delta_ids[i] = -1;
+        delta_ref[i] = 0.0f;
+        delta_cand[i] = 0.0f;
+        delta_abs[i] = 0.0f;
+    }
+
+    for (int i = 0; i < tc->vocab_size; i++) {
+        if (!isfinite(tc->ref_logits[i]) || !isfinite(cand_logits[i])) {
+            nonfinite++;
+            continue;
+        }
+        const float delta = cand_logits[i] - tc->ref_logits[i];
+        const float abs_delta = fabsf(delta);
+        if (abs_delta > max_abs) max_abs = abs_delta;
+        sumsq += (double)delta * (double)delta;
+        test_note_delta(delta_ids, delta_ref, delta_cand, delta_abs,
+                        (int)i, tc->ref_logits[i], cand_logits[i]);
+    }
+
+    const float rms = (float)sqrt(sumsq / (double)tc->vocab_size);
+    const float top_abs = test_top_union_max_abs(tc->ref_logits, cand_logits,
+                                                 ref_top, cand_top, TEST_MPP_EQ_TOPK);
+    const bool same_top1 = ref_top[0] >= 0 && ref_top[0] == cand_top[0];
+    test_mpp_eq_result result = {
+        .ref_top1 = ref_top[0],
+        .cand_top1 = cand_top[0],
+        .overlap = overlap,
+        .top5_overlap = top5_overlap,
+        .max_rank_delta = max_rank_delta,
+        .nonfinite = nonfinite,
+        .rms = rms,
+        .max_abs = max_abs,
+        .top20_max_abs = top_abs,
+        .same_top1 = same_top1,
+        .pass = nonfinite == 0 && same_top1,
+    };
+
+    fprintf(stderr,
+            "ds4-test: Tensor equivalence %s top1 ref=%d cand=%d top5_overlap=%d/%d overlap=%d/%d max_rank_delta=%d rms=%g max_abs=%g top20_max_abs=%g\n",
+            tc->id, ref_top[0], cand_top[0],
+            top5_overlap, TEST_MPP_EQ_TOP5,
+            overlap, TEST_MPP_EQ_TOPK,
+            max_rank_delta, rms, max_abs, top_abs);
+    fprintf(stderr, "ds4-test: Tensor equivalence %s largest deltas:", tc->id);
+    for (int i = 0; i < TEST_MPP_EQ_DELTAS && delta_ids[i] >= 0; i++) {
+        fprintf(stderr, " id=%d ref=%g cand=%g abs=%g",
+                delta_ids[i], delta_ref[i], delta_cand[i], delta_abs[i]);
+    }
+    fputc('\n', stderr);
+
+    if (assert_thresholds) {
+        TEST_ASSERT(nonfinite == 0);
+        TEST_ASSERT(same_top1);
+    }
+    return result;
+}
+
+static bool test_mpp_capture(ds4_engine *engine, const test_mpp_eq_case *tc,
+                             float *logits, int *gen, int *gen_len) {
+    ds4_session *session = NULL;
+    TEST_ASSERT(ds4_session_create(&session, engine, tc->ctx) == 0);
+    if (!session) return false;
+
+    char err[160];
+    bool ok = ds4_session_sync(session, &tc->prompt, err, sizeof(err)) == 0;
+    TEST_ASSERT(ok);
+    if (ok) {
+        ok = ds4_session_copy_logits(session, logits, tc->vocab_size) == tc->vocab_size;
+        TEST_ASSERT(ok);
+    }
+
+    int n = 0;
+    while (ok && n < tc->gen_steps) {
+        const int token = ds4_session_argmax(session);
+        gen[n++] = token;
+        if (n < tc->gen_steps && ds4_session_eval(session, token, err, sizeof(err)) != 0) {
+            ok = false;
+            TEST_ASSERT(false);
+        }
+    }
+    *gen_len = n;
+
+    ds4_session_free(session);
+    return ok;
+}
+
+static bool test_mpp_eq_case_selected(const char *id) {
+    const char *filter = getenv("DS4_TEST_MPP_EQ_CASE");
+    if (!filter || !filter[0]) return true;
+
+    char buf[256];
+    snprintf(buf, sizeof(buf), "%s", filter);
+    for (char *tok = strtok(buf, ","); tok; tok = strtok(NULL, ",")) {
+        tok = test_trim_line(tok);
+        if (tok[0] && strstr(id, tok)) return true;
+    }
+    return false;
+}
+
+static int test_load_mpp_cases(ds4_engine *engine, test_mpp_eq_case *cases, int cap) {
+    const char *path = getenv("DS4_TEST_VECTOR_FILE");
+    if (!path || !path[0]) path = "tests/test-vectors/official.vec";
+    FILE *fp = fopen(path, "rb");
+    TEST_ASSERT(fp != NULL);
+    if (!fp) return 0;
+
+    int ncase = 0;
+    test_vec_case vc;
+    while (ncase < cap && test_read_vector_case(fp, &vc)) {
+        if (!test_fill_vector_case(fp, &vc)) break;
+        if (!test_mpp_eq_case_selected(vc.id)) continue;
+        char *prompt_text = test_read_file(vc.prompt_path);
+        TEST_ASSERT(prompt_text != NULL);
+        if (!prompt_text) continue;
+
+        test_mpp_eq_case *tc = &cases[ncase++];
+        snprintf(tc->id, sizeof(tc->id), "%s", vc.id);
+        tc->ctx = vc.ctx;
+        tc->vocab_size = ds4_engine_vocab_size(engine);
+        tc->gen_steps = vc.nsteps < TEST_VEC_MAX_STEPS ? vc.nsteps : TEST_VEC_MAX_STEPS;
+        ds4_encode_chat_prompt(engine, "", prompt_text, DS4_THINK_NONE, &tc->prompt);
+        free(prompt_text);
+        TEST_ASSERT(tc->prompt.len > 0);
+    }
+    fclose(fp);
+    return ncase;
+}
+
+static ds4_engine *test_open_mpp_engine(ds4_mpp_mode mode) {
+    ds4_engine *engine = NULL;
+    ds4_engine_options opt = {
+        .model_path = test_model_path(),
+        .backend = DS4_BACKEND_METAL,
+        .mpp_mode = mode,
+    };
+    TEST_ASSERT(ds4_engine_open(&engine, &opt) == 0);
+    return engine;
+}
+
+static void test_mpp_summary_init(test_mpp_eq_summary *summary, const char *label) {
+    memset(summary, 0, sizeof(*summary));
+    summary->label = label;
+    summary->min_overlap = TEST_MPP_EQ_TOPK;
+    summary->min_top5_overlap = TEST_MPP_EQ_TOP5;
+}
+
+static void test_mpp_summary_note_logits(test_mpp_eq_summary *summary,
+                                         const test_mpp_eq_result *result) {
+    if (!result->pass) summary->logits_failures++;
+    if (!result->same_top1) summary->top1_mismatches++;
+    if (result->overlap < summary->min_overlap) summary->min_overlap = result->overlap;
+    if (result->top5_overlap < summary->min_top5_overlap) {
+        summary->min_top5_overlap = result->top5_overlap;
+    }
+    if (result->max_rank_delta > summary->worst_rank_delta) {
+        summary->worst_rank_delta = result->max_rank_delta;
+    }
+    if (result->rms > summary->worst_rms) summary->worst_rms = result->rms;
+    if (result->max_abs > summary->worst_max_abs) summary->worst_max_abs = result->max_abs;
+    if (result->top20_max_abs > summary->worst_top20_max_abs) {
+        summary->worst_top20_max_abs = result->top20_max_abs;
+    }
+}
+
+static void test_mpp_summary_print(const test_mpp_eq_summary *summary) {
+    fprintf(stderr,
+            "ds4-test: Tensor summary route=%s cases=%d capture_fail=%d logits_fail=%d greedy_fail=%d top1_mismatch=%d min_top5_overlap=%d/%d min_overlap=%d/%d worst_rank_delta=%d worst_rms=%g worst_max_abs=%g worst_top20_max_abs=%g\n",
+            summary->label,
+            summary->cases,
+            summary->capture_failures,
+            summary->logits_failures,
+            summary->greedy_failures,
+            summary->top1_mismatches,
+            summary->min_top5_overlap,
+            TEST_MPP_EQ_TOP5,
+            summary->min_overlap,
+            TEST_MPP_EQ_TOPK,
+            summary->worst_rank_delta,
+            summary->worst_rms,
+            summary->worst_max_abs,
+            summary->worst_top20_max_abs);
+}
+
+static void test_run_mpp_candidate(const char *label,
+                                   ds4_mpp_mode mode,
+                                   test_mpp_eq_case *cases,
+                                   int ncase) {
+    fprintf(stderr, "ds4-test: Tensor equivalence candidate route=%s mode=%s\n",
+            label, ds4_mpp_mode_name(mode));
+    test_mpp_eq_summary summary;
+    test_mpp_summary_init(&summary, label);
+    ds4_engine *cand_engine = test_open_mpp_engine(mode);
+    if (cand_engine) {
+        const int vocab_size = ncase > 0 ? cases[0].vocab_size : 0;
+        float *cand_logits = malloc((size_t)vocab_size * sizeof(cand_logits[0]));
+        TEST_ASSERT(cand_logits != NULL);
+        if (cand_logits) {
+            for (int i = 0; i < ncase; i++) {
+                test_mpp_eq_case *tc = &cases[i];
+                if (!tc->ref_logits) continue;
+                int cand_gen[TEST_VEC_MAX_STEPS] = {0};
+                int cand_gen_len = 0;
+                if (!test_mpp_capture(cand_engine, tc, cand_logits, cand_gen, &cand_gen_len)) {
+                    summary.capture_failures++;
+                    continue;
+                }
+                summary.cases++;
+                test_mpp_eq_result result = test_compare_mpp_logits(tc, cand_logits, true);
+                test_mpp_summary_note_logits(&summary, &result);
+                TEST_ASSERT(cand_gen_len == tc->ref_gen_len);
+                if (cand_gen_len != tc->ref_gen_len) summary.greedy_failures++;
+                for (int j = 0; j < tc->ref_gen_len && j < cand_gen_len; j++) {
+                    if (cand_gen[j] != tc->ref_gen[j]) {
+                        fprintf(stderr,
+                                "ds4-test: Tensor equivalence %s greedy token mismatch step=%d ref=%d cand=%d\n",
+                                tc->id, j, tc->ref_gen[j], cand_gen[j]);
+                        summary.greedy_failures++;
+                    }
+                    TEST_ASSERT(cand_gen[j] == tc->ref_gen[j]);
+                }
+            }
+            free(cand_logits);
+        }
+        ds4_engine_close(cand_engine);
+    }
+    test_mpp_summary_print(&summary);
+}
+
+static const char *const test_mpp_route_envs[] = {
+    "DS4_METAL_MPP_ENABLE",
+    "DS4_METAL_MPP_DISABLE",
+    "DS4_METAL_MPP_FAST",
+    "DS4_METAL_MPP_DIRECT_RHS",
+    "DS4_METAL_MPP_Q8_0_ENABLE",
+    "DS4_METAL_MPP_Q8_0_DISABLE",
+    "DS4_METAL_MPP_Q8_0_DIRECT_RHS",
+    "DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE",
+    "DS4_METAL_MPP_Q8_0_FILTER",
+    "DS4_METAL_MPP_Q8_0_TILE_N",
+    "DS4_METAL_MPP_F16_ENABLE",
+    "DS4_METAL_MPP_F16_DISABLE",
+    "DS4_METAL_MPP_F16_DIRECT_RHS",
+    "DS4_METAL_MPP_F16_WIDE",
+    "DS4_METAL_MPP_F16_PAIR",
+    "DS4_METAL_MPP_ATTN_OUT_ENABLE",
+    "DS4_METAL_MPP_ATTN_OUT_DISABLE",
+    "DS4_METAL_MPP_ATTN_OUT_DIRECT_RHS",
+    "DS4_METAL_MPP_ATTN_OUT_FILTER",
+    "DS4_METAL_MPP_ATTN_OUT_TILE_N",
+    "DS4_METAL_MPP_MOE_ENABLE",
+    "DS4_METAL_MPP_MOE_DISABLE",
+    "DS4_METAL_MPP_MOE_FILTER",
+    "DS4_METAL_MPP_MOE_TILE_N",
+    "DS4_METAL_MPP_MOE_FAST_LAYOUT",
+    "DS4_METAL_MPP_MOE_PAIR_GATE_UP",
+    "DS4_METAL_MPP_MOE_START_LAYER",
+    "DS4_METAL_MPP_MOE_GATE_ENABLE",
+    "DS4_METAL_MPP_MOE_GATE_DISABLE",
+    "DS4_METAL_MPP_MOE_GATE_FILTER",
+    "DS4_METAL_MPP_MOE_GATE_START_LAYER",
+    "DS4_METAL_MPP_MOE_UP_ENABLE",
+    "DS4_METAL_MPP_MOE_UP_DISABLE",
+    "DS4_METAL_MPP_MOE_UP_FILTER",
+    "DS4_METAL_MPP_MOE_UP_START_LAYER",
+    "DS4_METAL_MPP_MOE_DOWN_ENABLE",
+    "DS4_METAL_MPP_MOE_DOWN_DISABLE",
+    "DS4_METAL_MPP_MOE_DOWN_FILTER",
+    "DS4_METAL_MPP_MOE_DOWN_START_LAYER",
+};
+
+typedef struct {
+    const char *name;
+    char *value;
+    bool had_value;
+} test_mpp_saved_env;
+
+static void test_mpp_save_envs(test_mpp_saved_env *saved, int n) {
+    for (int i = 0; i < n; i++) {
+        saved[i].name = test_mpp_route_envs[i];
+        const char *v = getenv(saved[i].name);
+        saved[i].had_value = v != NULL;
+        saved[i].value = v ? strdup(v) : NULL;
+    }
+}
+
+static void test_mpp_restore_envs(test_mpp_saved_env *saved, int n) {
+    for (int i = 0; i < n; i++) {
+        if (saved[i].had_value) {
+            setenv(saved[i].name, saved[i].value ? saved[i].value : "", 1);
+        } else {
+            unsetenv(saved[i].name);
+        }
+        free(saved[i].value);
+        saved[i].value = NULL;
+    }
+}
+
+static void test_mpp_clear_route_envs(void) {
+    for (size_t i = 0; i < sizeof(test_mpp_route_envs) / sizeof(test_mpp_route_envs[0]); i++) {
+        unsetenv(test_mpp_route_envs[i]);
+    }
+}
+
+typedef struct {
+    const char *label;
+    ds4_mpp_mode mode;
+    const char *set_envs[8];
+} test_mpp_matrix_config;
+
+static void test_mpp_apply_matrix_config(const test_mpp_matrix_config *cfg) {
+    test_mpp_clear_route_envs();
+    for (int i = 0; cfg->set_envs[i]; i++) {
+        setenv(cfg->set_envs[i], "1", 1);
+    }
+}
+
+static void test_run_mpp_matrix(test_mpp_eq_case *cases, int ncase) {
+    const test_mpp_matrix_config configs[] = {
+        { "auto", DS4_MPP_AUTO, { NULL } },
+        { "fast_profile", DS4_MPP_AUTO, {
+            "DS4_METAL_MPP_FAST",
+            NULL
+        } },
+        { "q8_only", DS4_MPP_ON, {
+            "DS4_METAL_MPP_F16_DISABLE",
+            "DS4_METAL_MPP_ATTN_OUT_DISABLE",
+            "DS4_METAL_MPP_MOE_DISABLE",
+            NULL
+        } },
+        { "attn_out_only", DS4_MPP_ON, {
+            "DS4_METAL_MPP_Q8_0_DISABLE",
+            "DS4_METAL_MPP_F16_DISABLE",
+            "DS4_METAL_MPP_MOE_DISABLE",
+            NULL
+        } },
+        { "moe_gate_only", DS4_MPP_ON, {
+            "DS4_METAL_MPP_Q8_0_DISABLE",
+            "DS4_METAL_MPP_F16_DISABLE",
+            "DS4_METAL_MPP_ATTN_OUT_DISABLE",
+            "DS4_METAL_MPP_MOE_UP_DISABLE",
+            "DS4_METAL_MPP_MOE_DOWN_DISABLE",
+            NULL
+        } },
+        { "moe_up_only", DS4_MPP_ON, {
+            "DS4_METAL_MPP_Q8_0_DISABLE",
+            "DS4_METAL_MPP_F16_DISABLE",
+            "DS4_METAL_MPP_ATTN_OUT_DISABLE",
+            "DS4_METAL_MPP_MOE_GATE_DISABLE",
+            "DS4_METAL_MPP_MOE_DOWN_DISABLE",
+            NULL
+        } },
+        { "moe_down_only", DS4_MPP_ON, {
+            "DS4_METAL_MPP_Q8_0_DISABLE",
+            "DS4_METAL_MPP_F16_DISABLE",
+            "DS4_METAL_MPP_ATTN_OUT_DISABLE",
+            "DS4_METAL_MPP_MOE_GATE_DISABLE",
+            "DS4_METAL_MPP_MOE_UP_DISABLE",
+            NULL
+        } },
+        { "full_forced", DS4_MPP_ON, { NULL } },
+    };
+
+    test_mpp_saved_env saved[sizeof(test_mpp_route_envs) / sizeof(test_mpp_route_envs[0])];
+    test_mpp_save_envs(saved, (int)(sizeof(saved) / sizeof(saved[0])));
+    for (size_t i = 0; i < sizeof(configs) / sizeof(configs[0]); i++) {
+        test_mpp_apply_matrix_config(&configs[i]);
+        test_run_mpp_candidate(configs[i].label, configs[i].mode, cases, ncase);
+    }
+    test_mpp_restore_envs(saved, (int)(sizeof(saved) / sizeof(saved[0])));
+}
+
+static void test_metal_mpp_equivalence(void) {
+    test_close_engines();
+
+    test_mpp_eq_case cases[TEST_MPP_EQ_MAX_CASES];
+    memset(cases, 0, sizeof(cases));
+
+    ds4_engine *ref_engine = test_open_mpp_engine(DS4_MPP_OFF);
+    if (!ref_engine) return;
+
+    const int ncase = test_load_mpp_cases(ref_engine, cases, TEST_MPP_EQ_MAX_CASES);
+    TEST_ASSERT(ncase > 0);
+    for (int i = 0; i < ncase; i++) {
+        test_mpp_eq_case *tc = &cases[i];
+        tc->ref_logits = malloc((size_t)tc->vocab_size * sizeof(tc->ref_logits[0]));
+        TEST_ASSERT(tc->ref_logits != NULL);
+        if (!tc->ref_logits) continue;
+        TEST_ASSERT(test_mpp_capture(ref_engine, tc,
+                                     tc->ref_logits,
+                                     tc->ref_gen,
+                                     &tc->ref_gen_len));
+    }
+    ds4_engine_close(ref_engine);
+
+    if (getenv("DS4_TEST_MPP_EQ_MATRIX") != NULL) {
+        test_run_mpp_matrix(cases, ncase);
+    } else {
+        const bool force_on = getenv("DS4_TEST_MPP_EQ_FORCE_ON") != NULL;
+        test_run_mpp_candidate(force_on ? "forced" : "auto",
+                               force_on ? DS4_MPP_ON : DS4_MPP_AUTO,
+                               cases,
+                               ncase);
+    }
+
+    for (int i = 0; i < ncase; i++) test_mpp_eq_case_free(&cases[i]);
+}
+
 static const char *test_tool_call_request_json(void) {
     return "{"
@@ -650,7 +1346,8 @@ static const ds4_test_entry test_entries[] = {
     {"--long-context", "long-context", "long-context story fact-recall regression", test_long_story_fact_recall},
     {"--tool-call-quality", "tool-call-quality", "model emits valid DSML tool calls", test_tool_call_quality},
     {"--logprob-vectors", "logprob-vectors", "official API top-logprob vector comparison", test_official_logprob_vectors},
-    {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_f16_matvec_fast_nr0_4},
+    {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_kernel_group},
+    {"--metal-mpp-equivalence", "metal-mpp-equivalence", "Metal Tensor off/on prompt-logit and greedy equivalence", test_metal_mpp_equivalence},
 #endif
     {"--server", "server", "server parser/rendering/cache unit tests", test_server_unit_group},
 };
@@ -671,6 +1368,9 @@ static void test_print_help(const char *prog) {
     puts(" DS4_TEST_MODEL=FILE Model path. Default: ds4flash.gguf");
     puts(" DS4_TEST_LONG_PROMPT=FILE Rendered long-context story fact prompt.");
     puts(" DS4_TEST_VECTOR_FILE=FILE Simple official-vector fixture.");
+    puts(" DS4_TEST_MPP_EQ_CASE=NAME Run only Tensor equivalence cases whose id contains NAME.");
+    puts(" DS4_TEST_MPP_EQ_FORCE_ON=1 Compare -mt off against forced -mt on instead of auto.");
+    puts(" DS4_TEST_MPP_EQ_MATRIX=1 Run auto and isolated forced Tensor route rows.");
 }
 
 static const ds4_test_entry *test_find_entry(const char *arg) {