From 24022c209c4ec316cb743375b33c4c57ea9cbcc9 Mon Sep 17 00:00:00 2001 From: Audrey Tang Date: Sun, 10 May 2026 05:06:31 -0400 Subject: [PATCH 1/4] feat(loader): support stock-recipe (Q8_0/F32) GGUFs end-to-end on Metal DeepSeek-V4-Flash GGUFs produced by the upstream llama.cpp converter without per-tensor type overrides ship most of the small projections at Q8_0 (and routed-expert router weights at F32) where the antirez recipe keeps them at F16. Examples include the cyberneurova abliterated GGUFs. On stock ds4 main these load fails loudly at the first F16-strict validator (token_embd, then output_hc_fn, then hc_attn_fn, ...), and even after the validators are relaxed, several Metal kernel paths read weight bytes directly via offset arithmetic that hard-codes F16/F32 strides. This change makes the embed/HC/compressor/indexer/router validators *and* the corresponding Metal kernel paths polymorphic, so the same GGUF loads and runs with no harmonizer step. Validators (ds4.c): * New tensor_expect_dispatch_layout helper accepts F16, F32, or Q8_0 and is applied to every projection that flows through a type-dispatching matvec/matmul: output_hc_fn, hc_attn_fn, hc_ffn_fn, attn_compressor_{ape,gate,kv}, indexer.{attn_q_b,proj}, indexer_compressor_{ape,gate,kv}, ffn_gate_inp. * token_embd keeps its own inline F16/Q8_0 check because its CPU embed kernel doesn't go through matvec_any. * Two compressor decode-time guards (attn_compressor and indexer_compressor pair-projection paths) relaxed from "F16 only" to "F16 or Q8_0, paired type must match". CPU paths (ds4.c): * Refactor embed_token_f16 into an embed_token dispatcher; add embed_token_q8_0 (block-wise dequant of block_q8_0). * Replace the remaining direct matvec_f16 / matvec_f16_serial callers (HC fn, output_hc_fn, ffn_gate_inp) with the existing matvec_any dispatcher; add matvec_any_serial for the HC pre/post path. * Polymorphic Metal-side dispatch helpers metal_graph_matmul_plain_tensor and metal_graph_matmul_pair_plain_tensor (extended for Q8_0; the pair fuses with the existing F16-pair kernel when both tensors are F16, otherwise dispatches to two single matmuls). All 22 hardcoded ds4_metal_matmul_f16{,_pair}_tensor call sites in ds4.c (HC mix, attn/indexer compressors, indexer projections, output head, router) converted to use these wrappers. Metal kernels: * metal/get_rows.metal: kernel_get_rows_q8_0 (one float per thread, dequantizes its source block on the fly). * metal/dense.metal: kernel_mul_mm_f32_f32 template instantiation for the multi-token F32 weight matmul that the F32 router path needs in prefill (mirrors the existing F16/Q8_0 mul_mm_t instantiations). * metal/cpy.metal: kernel_cpy_q8_0_f32 (dequantizing 1D copy used by the compressor APE byte-strided reader). Metal wiring (ds4_metal.m): * Register g_get_rows_q8_0_pipeline and g_cpy_q8_0_f32_pipeline at init; clear them at cleanup. * Both ds4_metal_embed_{token,tokens}_hc_tensor and the shared ds4_metal_encode_get_rows helper take a new weight_type parameter (GGUF type code: 1=F16, 8=Q8_0). 8 callers in ds4.c forward weights->token_embd->type unchanged. ds4_metal_embed_row_layout picks the right per-row stride and pipeline. * ds4_metal_matmul_f32_tensor extended with a multi-token branch that dispatches to kernel_mul_mm_f32_f32 (n_tok > 1); existing n_tok = 1 path unchanged. * ds4_metal_encode_compressor_score_with_ape and the equivalent loop in ds4_metal_compressor_prefill_tensor add a Q8_0 branch (ds4_metal_encode_cpy_q8_0_f32_1d) and use a per-row stride that accounts for the block_q8_0 layout. * Six ape_type validators relaxed to also accept 8 (Q8_0). * Six ape_bytes calculations centralized through a new ds4_metal_ape_bytes(ape_type, n_elems) helper that returns the correct stride for F16/F32/Q8_0. * metal_graph_matmul_plain_tensor extended with a Q8_0 branch. Tested on macOS / M-series / Metal: * make ds4-server clean (no new warnings). * Cyberneurova Q2_K GGUF entirely unmodified: loads, prefill + decode through to coherent generation ("PASS" returned for the "reply with the single word PASS" prompt). * Pre-harmonized variant (token_embd / hc / compressor / indexer all F16, ffn_gate_inp F16): still works byte-for-byte the same as before this change, no F16 path regressions. Caveat for reviewers running ivanfioravanti's M5 PR (#15) on top of this: the unmodified cyberneurova file generates garbage (BOS spam) when MPP F16 prefill is engaged, but produces coherent output with DS4_METAL_MPP_F16_DISABLE=1. The garbage is reproducible from #15's MPP path alone and is independent of the changes here; it surfaces only because this PR makes the Q8_0 file loadable in the first place. --- ds4.c | 411 ++++++++++++++++++++++--------------------- ds4_metal.h | 7 +- ds4_metal.m | 271 +++++++++++++++++++++++----- metal/cpy.metal | 24 +++ metal/dense.metal | 3 +- metal/get_rows.metal | 34 ++++ 6 files changed, 496 insertions(+), 254 deletions(-) diff --git a/ds4.c b/ds4.c index 3142bf89..0349b455 100644 --- a/ds4.c +++ b/ds4.c @@ -2074,6 +2074,29 @@ static void tensor_expect_plain_layout( tensor_expect_layout(t, t->type, ndim, d0, d1, d2); } +/* Tensors that flow through the polymorphic dense matvec/matmul dispatch + * (matvec_any / ds4_metal_matmul_dispatch_tensor) accept any of F16, F32, or + * Q8_0. Used for the small projections antirez ships at F16 but stock + * llama.cpp converters emit at Q8_0 (HC fn weights, compressors, indexer + * projections) or F32 (router gate). */ +static void tensor_expect_dispatch_layout( + const ds4_tensor *t, + uint32_t ndim, + uint64_t d0, + uint64_t d1, + uint64_t d2) { + if (!t) ds4_die("internal error: missing tensor while validating layout"); + if (t->type != DS4_TENSOR_F16 && t->type != DS4_TENSOR_F32 && t->type != DS4_TENSOR_Q8_0) { + fprintf(stderr, + "ds4: tensor %.*s has type %s, expected F16, F32, or Q8_0\n", + (int)t->name.len, + t->name.ptr, + tensor_type_name(t->type)); + exit(1); + } + tensor_expect_layout(t, t->type, ndim, d0, d1, d2); +} + static bool tensor_is_routed_expert_type(uint32_t type) { return type == DS4_TENSOR_IQ2_XXS || type == DS4_TENSOR_Q2_K || @@ -2142,9 +2165,17 @@ static void weights_validate_layout(const ds4_weights *w) { const uint64_t q_dim = (uint64_t)DS4_N_HEAD * DS4_N_HEAD_DIM; const uint64_t out_low_dim = (uint64_t)DS4_N_OUT_GROUP * DS4_N_LORA_O; - tensor_expect_layout(w->token_embd, DS4_TENSOR_F16, 2, DS4_N_EMBD, DS4_N_VOCAB, 0); + /* token_embd may be F16 (antirez recipe) or Q8_0 (stock llama.cpp recipe used + * by e.g. cyberneurova converts). Both have CPU and Metal kernels above. */ + if (w->token_embd->type != DS4_TENSOR_F16 && w->token_embd->type != DS4_TENSOR_Q8_0) { + fprintf(stderr, + "ds4: tensor token_embd.weight has type %s, expected F16 or Q8_0\n", + tensor_type_name(w->token_embd->type)); + exit(1); + } + tensor_expect_layout(w->token_embd, w->token_embd->type, 2, DS4_N_EMBD, DS4_N_VOCAB, 0); tensor_expect_layout(w->output_hc_base, DS4_TENSOR_F32, 1, DS4_N_HC, 0, 0); - tensor_expect_layout(w->output_hc_fn, DS4_TENSOR_F16, 2, hc_dim, DS4_N_HC, 0); + tensor_expect_dispatch_layout(w->output_hc_fn, 2, hc_dim, DS4_N_HC, 0); tensor_expect_layout(w->output_hc_scale, DS4_TENSOR_F32, 1, 1, 0, 0); tensor_expect_layout(w->output_norm, DS4_TENSOR_F32, 1, DS4_N_EMBD, 0, 0); tensor_expect_layout(w->output, DS4_TENSOR_Q8_0, 2, DS4_N_EMBD, DS4_N_VOCAB, 0); @@ -2153,7 +2184,7 @@ static void weights_validate_layout(const ds4_weights *w) { const ds4_layer_weights *l = &w->layer[il]; const uint32_t ratio = ds4_layer_compress_ratio(il); - tensor_expect_layout(l->hc_attn_fn, DS4_TENSOR_F16, 2, hc_dim, hc_mix_dim, 0); + tensor_expect_dispatch_layout(l->hc_attn_fn, 2, hc_dim, hc_mix_dim, 0); tensor_expect_layout(l->hc_attn_scale, DS4_TENSOR_F32, 1, 3, 0, 0); tensor_expect_layout(l->hc_attn_base, DS4_TENSOR_F32, 1, hc_mix_dim, 0, 0); tensor_expect_layout(l->attn_norm, DS4_TENSOR_F32, 1, DS4_N_EMBD, 0, 0); @@ -2169,27 +2200,27 @@ static void weights_validate_layout(const ds4_weights *w) { if (ratio != 0) { const uint32_t coff = ratio == 4 ? 2u : 1u; const uint64_t comp_width = (uint64_t)coff * DS4_N_HEAD_DIM; - tensor_expect_layout(l->attn_compressor_ape, DS4_TENSOR_F16, 2, comp_width, ratio, 0); - tensor_expect_layout(l->attn_compressor_kv, DS4_TENSOR_F16, 2, DS4_N_EMBD, comp_width, 0); - tensor_expect_layout(l->attn_compressor_gate, DS4_TENSOR_F16, 2, DS4_N_EMBD, comp_width, 0); + tensor_expect_dispatch_layout(l->attn_compressor_ape, 2, comp_width, ratio, 0); + tensor_expect_dispatch_layout(l->attn_compressor_kv, 2, DS4_N_EMBD, comp_width, 0); + tensor_expect_dispatch_layout(l->attn_compressor_gate, 2, DS4_N_EMBD, comp_width, 0); tensor_expect_layout(l->attn_compressor_norm, DS4_TENSOR_F32, 1, DS4_N_HEAD_DIM, 0, 0); } if (ratio == 4) { const uint64_t index_q_dim = (uint64_t)DS4_N_INDEXER_HEAD * DS4_N_INDEXER_HEAD_DIM; const uint64_t index_width = 2u * DS4_N_INDEXER_HEAD_DIM; - tensor_expect_layout(l->indexer_attn_q_b, DS4_TENSOR_F16, 2, DS4_N_LORA_Q, index_q_dim, 0); - tensor_expect_layout(l->indexer_proj, DS4_TENSOR_F16, 2, DS4_N_EMBD, DS4_N_INDEXER_HEAD, 0); - tensor_expect_layout(l->indexer_compressor_ape, DS4_TENSOR_F16, 2, index_width, ratio, 0); - tensor_expect_layout(l->indexer_compressor_kv, DS4_TENSOR_F16, 2, DS4_N_EMBD, index_width, 0); - tensor_expect_layout(l->indexer_compressor_gate, DS4_TENSOR_F16, 2, DS4_N_EMBD, index_width, 0); + tensor_expect_dispatch_layout(l->indexer_attn_q_b, 2, DS4_N_LORA_Q, index_q_dim, 0); + tensor_expect_dispatch_layout(l->indexer_proj, 2, DS4_N_EMBD, DS4_N_INDEXER_HEAD, 0); + tensor_expect_dispatch_layout(l->indexer_compressor_ape, 2, index_width, ratio, 0); + tensor_expect_dispatch_layout(l->indexer_compressor_kv, 2, DS4_N_EMBD, index_width, 0); + tensor_expect_dispatch_layout(l->indexer_compressor_gate, 2, DS4_N_EMBD, index_width, 0); tensor_expect_layout(l->indexer_compressor_norm, DS4_TENSOR_F32, 1, DS4_N_INDEXER_HEAD_DIM, 0, 0); } - tensor_expect_layout(l->hc_ffn_fn, DS4_TENSOR_F16, 2, hc_dim, hc_mix_dim, 0); + tensor_expect_dispatch_layout(l->hc_ffn_fn, 2, hc_dim, hc_mix_dim, 0); tensor_expect_layout(l->hc_ffn_scale, DS4_TENSOR_F32, 1, 3, 0, 0); tensor_expect_layout(l->hc_ffn_base, DS4_TENSOR_F32, 1, hc_mix_dim, 0, 0); tensor_expect_layout(l->ffn_norm, DS4_TENSOR_F32, 1, DS4_N_EMBD, 0, 0); - tensor_expect_layout(l->ffn_gate_inp, DS4_TENSOR_F16, 2, DS4_N_EMBD, DS4_N_EXPERT, 0); + tensor_expect_dispatch_layout(l->ffn_gate_inp, 2, DS4_N_EMBD, DS4_N_EXPERT, 0); tensor_expect_optional(l->ffn_exp_probs_b, DS4_TENSOR_F32, 1, DS4_N_EXPERT, 0, 0); tensor_expect_routed_expert(l->ffn_gate_exps, 3, DS4_N_EMBD, DS4_N_FF_EXP, DS4_N_EXPERT); tensor_expect_routed_expert(l->ffn_up_exps, 3, DS4_N_EMBD, DS4_N_FF_EXP, DS4_N_EXPERT); @@ -2534,7 +2565,8 @@ static void weights_free(ds4_weights *w) { memset(w, 0, sizeof(*w)); } -/* Load one token embedding row and expand it to float activations. */ +/* Load one token embedding row and expand it to float activations. + * The CPU reference path mirrors the Metal kernel_get_rows_{f16,q8_0} dispatch. */ static void embed_token_f16(const ds4_model *m, const ds4_weights *w, int token, float *out) { ds4_tensor *te = w->token_embd; if (token < 0 || (uint64_t)token >= te->dim[1]) { @@ -2550,6 +2582,37 @@ static void embed_token_f16(const ds4_model *m, const ds4_weights *w, int token, } } +/* Q8_0 token embedding: source row is `(stride / 32)` block_q8_0 records, + * each [uint16_t scale][int8_t qs[32]]; dequantize to float on the fly. */ +static void embed_token_q8_0(const ds4_model *m, const ds4_weights *w, int token, float *out) { + ds4_tensor *te = w->token_embd; + if (token < 0 || (uint64_t)token >= te->dim[1]) { + ds4_die("token id is outside the embedding table"); + } + const uint64_t stride = te->dim[0]; + if ((stride % 32) != 0) ds4_die("Q8_0 token embedding stride is not 32-aligned"); + const uint64_t blocks = stride / 32; + const uint8_t *base = tensor_data(m, te); + const uint8_t *row = base + (uint64_t)token * blocks * 34; + for (uint64_t b = 0; b < blocks; b++) { + uint16_t scale_bits; + memcpy(&scale_bits, row + b * 34, sizeof(scale_bits)); + const int8_t *qs = (const int8_t *)(row + b * 34 + 2); + const float scale = f16_to_f32(scale_bits); + for (int i = 0; i < 32; i++) { + out[b * 32 + i] = (float)qs[i] * scale; + } + } +} + +static void embed_token(const ds4_model *m, const ds4_weights *w, int token, float *out) { + switch (w->token_embd->type) { + case DS4_TENSOR_F16: embed_token_f16(m, w, token, out); return; + case DS4_TENSOR_Q8_0: embed_token_q8_0(m, w, token, out); return; + default: ds4_die("unsupported token_embd tensor type"); + } +} + /* RMSNorm without a learned scale, used by hyper-connection control vectors. */ static void rms_norm_no_weight(float *out, const float *x, uint64_t n, float eps) { double ss = 0.0; @@ -3560,6 +3623,17 @@ static void matvec_any(float *out, const ds4_model *m, const ds4_tensor *w, cons } } +/* Serial dispatcher; for tiny matrices where parallelization overhead exceeds the gain. */ +static void matvec_any_serial(float *out, const ds4_model *m, const ds4_tensor *w, const float *x) { + switch (w->type) { + case 1: matvec_f16_serial(out, m, w, x); return; + case 0: + case 8: matvec_any(out, m, w, x); return; /* no specialized serial path; parallel is cheap here */ + default: + ds4_die("unsupported tensor type for dense matvec_serial"); + } +} + static float tensor_1d_value(const ds4_model *m, const ds4_tensor *t, uint64_t i) { if (i >= t->elements) ds4_die("tensor scalar index is out of bounds"); if (t->type == 0) { @@ -4154,9 +4228,9 @@ static void hc_pre_from_state_one_scratch( rms_norm_no_weight(flat, residual_hc, hc_dim, DS4_RMS_EPS); if (serial_fn) { - matvec_f16_serial(mix, model, fn, flat); + matvec_any_serial(mix, model, fn, flat); } else { - matvec_f16(mix, model, fn, flat); + matvec_any(mix, model, fn, flat); } const float *scale = tensor_data(model, scale_tensor); @@ -5027,7 +5101,7 @@ static void layer_router_probs_one( const float * x) { float logits[DS4_N_EXPERT]; - matvec_f16(logits, model, layer->ffn_gate_inp, x); + matvec_any(logits, model, layer->ffn_gate_inp, x); for (int i = 0; i < DS4_N_EXPERT; i++) { probs[i] = sqrtf(softplus_stable(logits[i])); } @@ -7391,7 +7465,7 @@ static void forward_token_raw_swa_cpu_decode_scratch( float *cur = scratch->cur; float *next = scratch->next; - embed_token_f16(model, weights, token, scratch->plain); + embed_token(model, weights, token, scratch->plain); hc_from_plain_embedding(cur, scratch->plain, DS4_N_EMBD, DS4_N_HC); for (uint32_t il = 0; il < DS4_N_LAYER; il++) { @@ -7458,7 +7532,7 @@ static void prefill_layer_major_cpu( } for (uint64_t t = 0; t < n_tok; t++) { - embed_token_f16(model, weights, prompt->v[t], plain); + embed_token(model, weights, prompt->v[t], plain); hc_from_plain_embedding(cur + t * hc_dim, plain, DS4_N_EMBD, DS4_N_HC); } @@ -7637,7 +7711,7 @@ static void forward_first_token_cpu( float *cur = xmalloc((size_t)DS4_N_HC * DS4_N_EMBD * sizeof(cur[0])); float *next = xmalloc((size_t)DS4_N_HC * DS4_N_EMBD * sizeof(next[0])); - embed_token_f16(model, weights, token, plain); + embed_token(model, weights, token, plain); hc_from_plain_embedding(cur, plain, DS4_N_EMBD, DS4_N_HC); for (uint32_t il = 0; il < DS4_N_LAYER; il++) { @@ -7668,7 +7742,7 @@ static void output_hc_head_one( float *w = xmalloc((size_t)n_hc * sizeof(w[0])); rms_norm_no_weight(flat, inp_hc, hc_dim, DS4_RMS_EPS); - matvec_f16(pre, model, weights->output_hc_fn, flat); + matvec_any(pre, model, weights->output_hc_fn, flat); const float *scale = tensor_data(model, weights->output_hc_scale); const float *base = tensor_data(model, weights->output_hc_base); @@ -7712,7 +7786,7 @@ static void output_logits_one_decode_scratch( const uint64_t hc_dim = (uint64_t)DS4_N_EMBD * n_hc; rms_norm_no_weight(scratch->output_flat, inp_hc, hc_dim, DS4_RMS_EPS); - matvec_f16(scratch->output_pre, model, weights->output_hc_fn, scratch->output_flat); + matvec_any(scratch->output_pre, model, weights->output_hc_fn, scratch->output_flat); const float *scale = tensor_data(model, weights->output_hc_scale); const float *base = tensor_data(model, weights->output_hc_base); @@ -8683,6 +8757,16 @@ static bool metal_graph_matmul_plain_tensor( uint64_t out_dim, const ds4_metal_tensor *x, uint64_t n_tok); +static bool metal_graph_matmul_pair_plain_tensor( + ds4_metal_tensor *out_a, + ds4_metal_tensor *out_b, + const ds4_model *model, + const ds4_tensor *w_a, + const ds4_tensor *w_b, + uint64_t in_dim, + uint64_t out_dim, + const ds4_metal_tensor *x, + uint64_t n_tok); static bool metal_graph_encode_decode_layer( ds4_metal_graph *g, @@ -8880,13 +8964,14 @@ static bool metal_graph_encode_decode_layer( const bool emit = ((pos + 1u) % ratio) == 0u; if (!layer->attn_compressor_kv || !layer->attn_compressor_gate || !layer->attn_compressor_ape || !layer->attn_compressor_norm || - layer->attn_compressor_kv->type != DS4_TENSOR_F16 || - layer->attn_compressor_gate->type != DS4_TENSOR_F16 || + (layer->attn_compressor_kv->type != DS4_TENSOR_F16 && + layer->attn_compressor_kv->type != DS4_TENSOR_Q8_0) || + layer->attn_compressor_kv->type != layer->attn_compressor_gate->type || layer->attn_compressor_kv->dim[0] != DS4_N_EMBD || layer->attn_compressor_gate->dim[0] != DS4_N_EMBD || layer->attn_compressor_kv->dim[1] != comp_width || layer->attn_compressor_gate->dim[1] != comp_width) { - fprintf(stderr, "ds4: Metal graph compressor expects paired F16 compressor projections\n"); + fprintf(stderr, "ds4: Metal graph compressor expects paired F16 or Q8_0 compressor projections\n"); ok = false; } if (ok && emit && g->layer_n_comp[il] >= g->comp_cap) { @@ -8894,25 +8979,10 @@ static bool metal_graph_encode_decode_layer( ok = false; } if (ok && !metal_graph_use_reference_compressor_pair_proj()) { - ok = ds4_metal_matmul_f16_pair_tensor(g->comp_kv_cur, - g->comp_sc_cur, - model->map, - model->size, - layer->attn_compressor_kv->abs_offset, - layer->attn_compressor_gate->abs_offset, - DS4_N_EMBD, - comp_width, - g->attn_norm, - 1) != 0; + ok = metal_graph_matmul_pair_plain_tensor(g->comp_kv_cur, g->comp_sc_cur, model, layer->attn_compressor_kv, layer->attn_compressor_gate, DS4_N_EMBD, comp_width, g->attn_norm, 1); } else { - if (ok) ok = ds4_metal_matmul_f16_tensor(g->comp_kv_cur, model->map, model->size, - layer->attn_compressor_kv->abs_offset, - DS4_N_EMBD, comp_width, - g->attn_norm, 1) != 0; - if (ok) ok = ds4_metal_matmul_f16_tensor(g->comp_sc_cur, model->map, model->size, - layer->attn_compressor_gate->abs_offset, - DS4_N_EMBD, comp_width, - g->attn_norm, 1) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->comp_kv_cur, model, layer->attn_compressor_kv, DS4_N_EMBD, comp_width, g->attn_norm, 1); + if (ok) ok = metal_graph_matmul_plain_tensor(g->comp_sc_cur, model, layer->attn_compressor_gate, DS4_N_EMBD, comp_width, g->attn_norm, 1); } const uint32_t comp_row = g->layer_n_comp[il]; if (ok) ok = ds4_metal_compressor_update_tensor(g->comp_kv_cur, @@ -8960,13 +9030,14 @@ static bool metal_graph_encode_decode_layer( const uint32_t index_width = coff * DS4_N_INDEXER_HEAD_DIM; if (!layer->indexer_compressor_kv || !layer->indexer_compressor_gate || !layer->indexer_compressor_ape || !layer->indexer_compressor_norm || - layer->indexer_compressor_kv->type != DS4_TENSOR_F16 || - layer->indexer_compressor_gate->type != DS4_TENSOR_F16 || + (layer->indexer_compressor_kv->type != DS4_TENSOR_F16 && + layer->indexer_compressor_kv->type != DS4_TENSOR_Q8_0) || + layer->indexer_compressor_kv->type != layer->indexer_compressor_gate->type || layer->indexer_compressor_kv->dim[0] != DS4_N_EMBD || layer->indexer_compressor_gate->dim[0] != DS4_N_EMBD || layer->indexer_compressor_kv->dim[1] != index_width || layer->indexer_compressor_gate->dim[1] != index_width) { - fprintf(stderr, "ds4: Metal graph indexer compressor expects paired F16 projections\n"); + fprintf(stderr, "ds4: Metal graph indexer compressor expects paired F16 or Q8_0 projections\n"); ok = false; } if (ok && emit && g->layer_n_index_comp[il] >= g->comp_cap) { @@ -8974,25 +9045,10 @@ static bool metal_graph_encode_decode_layer( ok = false; } if (ok && !metal_graph_use_reference_compressor_pair_proj()) { - ok = ds4_metal_matmul_f16_pair_tensor(g->comp_kv_cur, - g->comp_sc_cur, - model->map, - model->size, - layer->indexer_compressor_kv->abs_offset, - layer->indexer_compressor_gate->abs_offset, - DS4_N_EMBD, - index_width, - g->attn_norm, - 1) != 0; + ok = metal_graph_matmul_pair_plain_tensor(g->comp_kv_cur, g->comp_sc_cur, model, layer->indexer_compressor_kv, layer->indexer_compressor_gate, DS4_N_EMBD, index_width, g->attn_norm, 1); } else { - if (ok) ok = ds4_metal_matmul_f16_tensor(g->comp_kv_cur, model->map, model->size, - layer->indexer_compressor_kv->abs_offset, - DS4_N_EMBD, index_width, - g->attn_norm, 1) != 0; - if (ok) ok = ds4_metal_matmul_f16_tensor(g->comp_sc_cur, model->map, model->size, - layer->indexer_compressor_gate->abs_offset, - DS4_N_EMBD, index_width, - g->attn_norm, 1) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->comp_kv_cur, model, layer->indexer_compressor_kv, DS4_N_EMBD, index_width, g->attn_norm, 1); + if (ok) ok = metal_graph_matmul_plain_tensor(g->comp_sc_cur, model, layer->indexer_compressor_gate, DS4_N_EMBD, index_width, g->attn_norm, 1); } const uint32_t index_row = g->layer_n_index_comp[il]; if (ok) ok = ds4_metal_compressor_update_tensor(g->comp_kv_cur, @@ -9037,10 +9093,7 @@ static bool metal_graph_encode_decode_layer( fprintf(stderr, "ds4: Metal graph indexer weight projection expects F16 weights\n"); ok = false; } - if (ok) ok = ds4_metal_matmul_f16_tensor(g->indexer_q, model->map, model->size, - layer->indexer_attn_q_b->abs_offset, - q_rank, indexer_q_dim, - g->qr_norm, 1) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->indexer_q, model, layer->indexer_attn_q_b, q_rank, indexer_q_dim, g->qr_norm, 1); if (ok) ok = ds4_metal_rope_tail_tensor(g->indexer_q, 1, DS4_N_INDEXER_HEAD, DS4_N_INDEXER_HEAD_DIM, @@ -9054,10 +9107,7 @@ static bool metal_graph_encode_decode_layer( attn_factor, DS4_ROPE_YARN_BETA_FAST, DS4_ROPE_YARN_BETA_SLOW) != 0; - if (ok) ok = ds4_metal_matmul_f16_tensor(g->indexer_weights, model->map, model->size, - layer->indexer_proj->abs_offset, - DS4_N_EMBD, DS4_N_INDEXER_HEAD, - g->attn_norm, 1) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->indexer_weights, model, layer->indexer_proj, DS4_N_EMBD, DS4_N_INDEXER_HEAD, g->attn_norm, 1); const float index_scale = 1.0f / sqrtf((float)(DS4_N_INDEXER_HEAD_DIM * DS4_N_INDEXER_HEAD)); if (ok && decode_index_stage_profile) { ok = metal_graph_indexer_stage_profile_boundary(NULL, @@ -9424,14 +9474,7 @@ static bool metal_graph_encode_output_head( uint64_t vocab_dim) { const uint64_t hc_dim = (uint64_t)DS4_N_HC * DS4_N_EMBD; bool ok = ds4_metal_rms_norm_plain_tensor(g->flat_hc, g->cur_hc, (uint32_t)hc_dim, DS4_RMS_EPS) != 0; - if (ok) ok = ds4_metal_matmul_f16_tensor(g->output_pre, - model->map, - model->size, - weights->output_hc_fn->abs_offset, - hc_dim, - DS4_N_HC, - g->flat_hc, - 1) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->output_pre, model, weights->output_hc_fn, hc_dim, DS4_N_HC, g->flat_hc, 1); if (ok) { metal_graph_debug_dump_tensor("result_hc_pre", g->output_pre, DS4_N_HC, DS4_N_LAYER, 0); } @@ -9524,14 +9567,7 @@ static bool metal_graph_encode_output_head_batch( (uint32_t)hc_dim, n_tokens, DS4_RMS_EPS) != 0; - if (ok) ok = ds4_metal_matmul_f16_tensor(output_pre, - model->map, - model->size, - weights->output_hc_fn->abs_offset, - hc_dim, - DS4_N_HC, - g->batch_flat_hc, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(output_pre, model, weights->output_hc_fn, hc_dim, DS4_N_HC, g->batch_flat_hc, n_tokens); if (ok) ok = ds4_metal_output_hc_weights_tensor(output_weights, output_pre, model->map, @@ -9578,16 +9614,54 @@ static bool metal_graph_matmul_plain_tensor( uint64_t out_dim, const ds4_metal_tensor *x, uint64_t n_tok) { + int ok = 0; const char *kind = "?"; if (w->type == DS4_TENSOR_F16) { - return ds4_metal_matmul_f16_tensor(out, model->map, model->size, - w->abs_offset, in_dim, out_dim, x, n_tok) != 0; + kind = "f16"; + ok = ds4_metal_matmul_f16_tensor(out, model->map, model->size, + w->abs_offset, in_dim, out_dim, x, n_tok); + } else if (w->type == DS4_TENSOR_F32) { + kind = "f32"; + ok = ds4_metal_matmul_f32_tensor(out, model->map, model->size, + w->abs_offset, in_dim, out_dim, x, n_tok); + } else if (w->type == DS4_TENSOR_Q8_0) { + kind = "q8_0"; + ok = ds4_metal_matmul_q8_0_tensor(out, model->map, model->size, + w->abs_offset, in_dim, out_dim, x, n_tok); + } else { + fprintf(stderr, "ds4: Metal plain matmul does not support %s\n", tensor_type_name(w->type)); + return false; } - if (w->type == DS4_TENSOR_F32) { - return ds4_metal_matmul_f32_tensor(out, model->map, model->size, - w->abs_offset, in_dim, out_dim, x, n_tok) != 0; + if (!ok) { + fprintf(stderr, "ds4: PLAIN_MATMUL FAIL kind=%s tensor=%.*s in_dim=%llu out_dim=%llu n_tok=%llu\n", + kind, (int)w->name.len, w->name.ptr, + (unsigned long long)in_dim, (unsigned long long)out_dim, (unsigned long long)n_tok); } - fprintf(stderr, "ds4: Metal plain matmul does not support %s\n", tensor_type_name(w->type)); - return false; + return ok != 0; +} + +/* Pair variant of metal_graph_matmul_plain_tensor: fused F16-pair kernel when + * both tensors are F16, otherwise dispatch to two single-tensor matmuls. */ +static bool metal_graph_matmul_pair_plain_tensor( + ds4_metal_tensor *out_a, + ds4_metal_tensor *out_b, + const ds4_model *model, + const ds4_tensor *w_a, + const ds4_tensor *w_b, + uint64_t in_dim, + uint64_t out_dim, + const ds4_metal_tensor *x, + uint64_t n_tok) { + if (w_a->type != w_b->type) { + fprintf(stderr, "ds4: Metal plain pair matmul: paired tensors must share type\n"); + return false; + } + if (w_a->type == DS4_TENSOR_F16) { + return ds4_metal_matmul_f16_pair_tensor(out_a, out_b, model->map, model->size, + w_a->abs_offset, w_b->abs_offset, + in_dim, out_dim, x, n_tok) != 0; + } + if (!metal_graph_matmul_plain_tensor(out_a, model, w_a, in_dim, out_dim, x, n_tok)) return false; + return metal_graph_matmul_plain_tensor(out_b, model, w_b, in_dim, out_dim, x, n_tok); } static bool metal_graph_encode_output_head_mtp( @@ -9917,7 +9991,7 @@ static int metal_graph_decode_test( int selected[DS4_N_EXPERT_USED]; float expert_weight[DS4_N_EXPERT_USED]; - embed_token_f16(model, weights, token, plain); + embed_token(model, weights, token, plain); hc_from_plain_embedding(cpu_hc, plain, DS4_N_EMBD, DS4_N_HC); hc_pre_from_state_one(model, layer->hc_attn_fn, @@ -9979,7 +10053,8 @@ static int metal_graph_decode_test( (uint32_t)weights->token_embd->dim[1], (uint32_t)token, DS4_N_EMBD, - DS4_N_HC) != 0; + DS4_N_HC, + weights->token_embd->type) != 0; if (ok) ok = metal_graph_encode_decode_layer(&g, model, layer, @@ -10128,7 +10203,7 @@ static int metal_graph_first_token_full_test( float *cpu_cur = xmalloc((size_t)hc_dim * sizeof(float)); float *cpu_next = xmalloc((size_t)hc_dim * sizeof(float)); - embed_token_f16(model, weights, token, plain); + embed_token(model, weights, token, plain); hc_from_plain_embedding(cpu_cur, plain, DS4_N_EMBD, DS4_N_HC); ok = ds4_metal_begin_commands() != 0; if (ok) ok = ds4_metal_embed_token_hc_tensor(g.cur_hc, @@ -10138,7 +10213,8 @@ static int metal_graph_first_token_full_test( (uint32_t)weights->token_embd->dim[1], (uint32_t)token, DS4_N_EMBD, - DS4_N_HC) != 0; + DS4_N_HC, + weights->token_embd->type) != 0; if (ok) ok = ds4_metal_end_commands() != 0; for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { @@ -10187,7 +10263,8 @@ static int metal_graph_first_token_full_test( (uint32_t)weights->token_embd->dim[1], (uint32_t)token, DS4_N_EMBD, - DS4_N_HC) != 0; + DS4_N_HC, + weights->token_embd->type) != 0; for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { ok = metal_graph_encode_decode_layer(&g, model, &weights->layer[il], @@ -10268,7 +10345,8 @@ static bool metal_graph_encode_token_raw_swa( (uint32_t)weights->token_embd->dim[1], (uint32_t)token, DS4_N_EMBD, - DS4_N_HC) != 0; + DS4_N_HC, + weights->token_embd->type) != 0; /* * Start executing the prefix of the decode graph while the CPU is still @@ -10372,24 +10450,10 @@ static bool metal_graph_refresh_ratio4_compressor_state( 4ull * DS4_N_EMBD * sizeof(float)); bool ok = tail_hc != NULL; if (ok) { - ok = ds4_metal_matmul_f16_tensor(g->batch_comp_kv, - model->map, - model->size, - kv_weight->abs_offset, - DS4_N_EMBD, - width, - tail_hc, - 4) != 0; + ok = metal_graph_matmul_plain_tensor(g->batch_comp_kv, model, kv_weight, DS4_N_EMBD, width, tail_hc, 4); } if (ok) { - ok = ds4_metal_matmul_f16_tensor(g->batch_comp_sc, - model->map, - model->size, - score_weight->abs_offset, - DS4_N_EMBD, - width, - tail_hc, - 4) != 0; + ok = metal_graph_matmul_plain_tensor(g->batch_comp_sc, model, score_weight, DS4_N_EMBD, width, tail_hc, 4); } if (ok) { ok = ds4_metal_compressor_prefill_state_ratio4_tensor(state_kv, @@ -10424,7 +10488,7 @@ static bool metal_graph_upload_prompt_embeddings_hc_cpu( float *plain = xmalloc((size_t)DS4_N_EMBD * sizeof(plain[0])); for (uint32_t t = 0; t < n_tokens; t++) { - embed_token_f16(model, weights, prompt->v[pos0 + t], plain); + embed_token(model, weights, prompt->v[pos0 + t], plain); float *dst = hc + (uint64_t)t * hc_dim; for (uint32_t h = 0; h < DS4_N_HC; h++) { memcpy(dst + (uint64_t)h * DS4_N_EMBD, @@ -10469,7 +10533,8 @@ static bool metal_graph_upload_prompt_embeddings_hc( (uint32_t)weights->token_embd->dim[1], n_tokens, DS4_N_EMBD, - DS4_N_HC) != 0; + DS4_N_HC, + weights->token_embd->type) != 0; } return metal_graph_upload_prompt_embeddings_hc_cpu(out_hc, @@ -10500,14 +10565,7 @@ static bool metal_graph_warmup_prefill_kernels( bool ok = ds4_metal_begin_commands() != 0; if (ok) { - ok = ds4_metal_matmul_f16_tensor(g->batch_hc_mix, - model->map, - model->size, - weights->layer[0].hc_attn_fn->abs_offset, - hc_dim, - mix_hc, - g->batch_flat_hc, - n_tokens) != 0; + ok = metal_graph_matmul_plain_tensor(g->batch_hc_mix, model, weights->layer[0].hc_attn_fn, hc_dim, mix_hc, g->batch_flat_hc, n_tokens); } if (ok) ok = ds4_metal_end_commands() != 0; if (!ok) { @@ -10647,14 +10705,7 @@ static bool metal_graph_encode_layer_attention_batch( (uint32_t)hc_dim, n_tokens, DS4_RMS_EPS) != 0; - if (ok) ok = ds4_metal_matmul_f16_tensor(hc_mix_view, - model->map, - model->size, - layer->hc_attn_fn->abs_offset, - hc_dim, - mix_hc, - g->batch_flat_hc, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(hc_mix_view, model, layer->hc_attn_fn, hc_dim, mix_hc, g->batch_flat_hc, n_tokens); if (metal_graph_use_reference_hc_decode()) { if (ok) ok = ds4_metal_hc_split_sinkhorn_tensor(hc_split_view, hc_mix_view, @@ -10938,27 +10989,13 @@ static bool metal_graph_encode_layer_attention_batch( fprintf(stderr, "ds4: Metal layer-major prefill needs attention compressor weights\n"); ok = false; } - if (ok) ok = ds4_metal_matmul_f16_tensor(g->batch_comp_kv, - model->map, - model->size, - layer->attn_compressor_kv->abs_offset, - DS4_N_EMBD, - comp_width, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->batch_comp_kv, model, layer->attn_compressor_kv, DS4_N_EMBD, comp_width, g->batch_attn_norm, n_tokens); if (ok) metal_graph_debug_dump_tensor("attn_comp_kv_raw", g->batch_comp_kv, (uint64_t)comp_width * n_tokens, il, pos0); - if (ok) ok = ds4_metal_matmul_f16_tensor(g->batch_comp_sc, - model->map, - model->size, - layer->attn_compressor_gate->abs_offset, - DS4_N_EMBD, - comp_width, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->batch_comp_sc, model, layer->attn_compressor_gate, DS4_N_EMBD, comp_width, g->batch_attn_norm, n_tokens); if (ok) metal_graph_debug_dump_tensor("attn_comp_score_raw", g->batch_comp_sc, (uint64_t)comp_width * n_tokens, @@ -11216,40 +11253,22 @@ static bool metal_graph_encode_layer_attention_batch( fprintf(stderr, "ds4: Metal layer-major prefill needs indexer weights\n"); ok = false; } - if (ok) ok = ds4_metal_matmul_f16_tensor(g->batch_comp_kv, - model->map, - model->size, - layer->indexer_compressor_kv->abs_offset, - DS4_N_EMBD, - index_width, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->batch_comp_kv, model, layer->indexer_compressor_kv, DS4_N_EMBD, index_width, g->batch_attn_norm, n_tokens); if (ok) metal_graph_debug_dump_tensor("indexer_comp_kv_raw", g->batch_comp_kv, (uint64_t)index_width * n_tokens, il, pos0); - if (ok) ok = ds4_metal_matmul_f16_tensor(g->batch_comp_sc, - model->map, - model->size, - layer->indexer_compressor_gate->abs_offset, - DS4_N_EMBD, - index_width, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->batch_comp_sc, model, layer->indexer_compressor_gate, DS4_N_EMBD, index_width, g->batch_attn_norm, n_tokens); if (ok) metal_graph_debug_dump_tensor("indexer_comp_score_raw", g->batch_comp_sc, (uint64_t)index_width * n_tokens, il, pos0); - if (ok) ok = ds4_metal_matmul_f16_tensor(g->batch_indexer_q, - model->map, - model->size, - layer->indexer_attn_q_b->abs_offset, - q_rank, - (uint64_t)DS4_N_INDEXER_HEAD * DS4_N_INDEXER_HEAD_DIM, - g->batch_qr_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->batch_indexer_q, model, layer->indexer_attn_q_b, + q_rank, + (uint64_t)DS4_N_INDEXER_HEAD * DS4_N_INDEXER_HEAD_DIM, + g->batch_qr_norm, n_tokens); if (ok) ok = ds4_metal_rope_tail_tensor(g->batch_indexer_q, n_tokens, DS4_N_INDEXER_HEAD, @@ -11264,14 +11283,7 @@ static bool metal_graph_encode_layer_attention_batch( attn_factor, DS4_ROPE_YARN_BETA_FAST, DS4_ROPE_YARN_BETA_SLOW) != 0; - if (ok) ok = ds4_metal_matmul_f16_tensor(g->batch_indexer_weights, - model->map, - model->size, - layer->indexer_proj->abs_offset, - DS4_N_EMBD, - DS4_N_INDEXER_HEAD, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->batch_indexer_weights, model, layer->indexer_proj, DS4_N_EMBD, DS4_N_INDEXER_HEAD, g->batch_attn_norm, n_tokens); if (zero_prefix) { if (ok && n_comp > g->comp_cap) { fprintf(stderr, "ds4: Metal layer-major indexer cache capacity exceeded at layer %u\n", il); @@ -11926,14 +11938,7 @@ static bool metal_graph_encode_layer_ffn_batch( (uint32_t)hc_dim, n_tokens, DS4_RMS_EPS) != 0; - if (ok) ok = ds4_metal_matmul_f16_tensor(hc_mix_view, - model->map, - model->size, - layer->hc_ffn_fn->abs_offset, - hc_dim, - mix_hc, - g->batch_flat_hc, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(hc_mix_view, model, layer->hc_ffn_fn, hc_dim, mix_hc, g->batch_flat_hc, n_tokens); if (metal_graph_use_reference_hc_decode()) { if (ok) ok = ds4_metal_hc_split_sinkhorn_tensor(hc_split_view, hc_mix_view, @@ -11981,14 +11986,7 @@ static bool metal_graph_encode_layer_ffn_batch( (uint64_t)n_tokens * DS4_N_EMBD, il, pos0); } DS4_METAL_PROFILE_FFN_STAGE("norm"); - if (ok) ok = ds4_metal_matmul_f16_tensor(g->batch_router_logits, - model->map, - model->size, - layer->ffn_gate_inp->abs_offset, - DS4_N_EMBD, - DS4_N_EXPERT, - g->batch_ffn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_plain_tensor(g->batch_router_logits, model, layer->ffn_gate_inp, DS4_N_EMBD, DS4_N_EXPERT, g->batch_ffn_norm, n_tokens); if (ok) ok = ds4_metal_router_select_batch_tensor(g->batch_router_selected, g->batch_router_weights, @@ -12139,7 +12137,9 @@ static bool metal_graph_encode_layer_batch( uint32_t pos0, uint32_t n_tokens) { bool ok = metal_graph_encode_layer_attention_batch(g, model, layer, il, pos0, n_tokens); - if (ok) ok = metal_graph_encode_layer_ffn_batch(g, model, layer, il, pos0, n_tokens); + if (!ok) { fprintf(stderr, "ds4: TRACE attn_batch FAILED il=%u pos0=%u n_tokens=%u\n", il, pos0, n_tokens); return false; } + ok = metal_graph_encode_layer_ffn_batch(g, model, layer, il, pos0, n_tokens); + if (!ok) { fprintf(stderr, "ds4: TRACE ffn_batch FAILED il=%u pos0=%u n_tokens=%u\n", il, pos0, n_tokens); return false; } if (ok) { ds4_metal_tensor *tmp = g->batch_cur_hc; g->batch_cur_hc = g->batch_next_hc; @@ -12255,7 +12255,8 @@ static bool metal_graph_eval_mtp_draft_from_hc( (uint32_t)base_weights->token_embd->dim[1], (uint32_t)token, DS4_N_EMBD, - 1) != 0; + 1, + base_weights->token_embd->type) != 0; if (ok) ok = ds4_metal_rms_norm_weight_tensor(g->mtp_enorm, g->mtp_embed, mtp_model->map, @@ -12976,7 +12977,8 @@ static bool metal_graph_verify_decode2_exact( (uint32_t)weights->token_embd->dim[1], (uint32_t)token0, DS4_N_EMBD, - DS4_N_HC) != 0; + DS4_N_HC, + weights->token_embd->type) != 0; if (ok) ok = ds4_metal_embed_token_hc_tensor(cur1, model->map, model->size, @@ -12984,7 +12986,8 @@ static bool metal_graph_verify_decode2_exact( (uint32_t)weights->token_embd->dim[1], (uint32_t)token1, DS4_N_EMBD, - DS4_N_HC) != 0; + DS4_N_HC, + weights->token_embd->type) != 0; ds4_metal_tensor *saved_cur = g->cur_hc; ds4_metal_tensor *saved_after = g->after_ffn_hc; @@ -13405,7 +13408,7 @@ static void embed_prompt( uint32_t n_embd, float * out) { for (int i = 0; i < tokens->len; i++) { - embed_token_f16(model, weights, tokens->v[i], out + (uint64_t)i * n_embd); + embed_token(model, weights, tokens->v[i], out + (uint64_t)i * n_embd); } } diff --git a/ds4_metal.h b/ds4_metal.h index f84f78b8..4a68a7c2 100644 --- a/ds4_metal.h +++ b/ds4_metal.h @@ -47,6 +47,7 @@ void ds4_metal_print_memory_report(const char *label); * compressed-attention indexer that chooses visible compressed rows. */ +/* `weight_type` is the GGUF tensor type code (1 for F16, 8 for Q8_0). */ int ds4_metal_embed_token_hc_tensor( ds4_metal_tensor *out_hc, const void *model_map, @@ -55,7 +56,8 @@ int ds4_metal_embed_token_hc_tensor( uint32_t n_vocab, uint32_t token, uint32_t n_embd, - uint32_t n_hc); + uint32_t n_hc, + uint32_t weight_type); int ds4_metal_embed_tokens_hc_tensor( ds4_metal_tensor *out_hc, @@ -66,7 +68,8 @@ int ds4_metal_embed_tokens_hc_tensor( uint32_t n_vocab, uint32_t n_tokens, uint32_t n_embd, - uint32_t n_hc); + uint32_t n_hc, + uint32_t weight_type); int ds4_metal_indexer_score_one_tensor( ds4_metal_tensor *scores, diff --git a/ds4_metal.m b/ds4_metal.m index 3bdacac6..c9e63ac7 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -25,6 +25,8 @@ */ enum { + DS4_METAL_TENSOR_F16 = 1, + DS4_METAL_TENSOR_Q8_0 = 8, DS4_METAL_TENSOR_Q2_K = 10, DS4_METAL_TENSOR_Q4_K = 12, DS4_METAL_TENSOR_IQ2_XXS = 16, @@ -39,12 +41,14 @@ static id g_set_rows_f32_i32_pipeline; static id g_get_rows_f32_pipeline; static id g_get_rows_f16_pipeline; +static id g_get_rows_q8_0_pipeline; static id g_get_rows_i32_pipeline; static id g_repeat_f32_pipeline; static id g_concat_pipeline; static id g_cpy_f32_f32_pipeline; static id g_cpy_f32_f16_pipeline; static id g_cpy_f16_f32_pipeline; +static id g_cpy_q8_0_f32_pipeline; static id g_swiglu_pipeline; static id g_add_pipeline; static id g_mul_pipeline; @@ -1382,6 +1386,14 @@ static float ds4_metal_negative_infinity(void) { return v.f; } +/* Total bytes for `n_elems` of an APE-style weight tensor of `ape_type`. + * F32 (0) = 4 bytes/elem; F16 (1) = 2 bytes/elem; Q8_0 (8) = 34 bytes per + * QK8_0=32 elements (caller must ensure n_elems is 32-aligned for Q8_0). */ +static uint64_t ds4_metal_ape_bytes(uint32_t ape_type, uint64_t n_elems) { + if (ape_type == 8u) return (n_elems / 32u) * 34u; + return n_elems * (ape_type == 1u ? 2u : 4u); +} + static float ds4_metal_positive_infinity(void) { union { uint32_t u; float f; } v = { 0x7f800000u }; return v.f; @@ -1451,6 +1463,14 @@ static int ds4_metal_encode_cpy_f16_f32_1d( NSUInteger dst_off, uint32_t n); +static int ds4_metal_encode_cpy_q8_0_f32_1d( + id cb, + id src, + NSUInteger src_off, + id dst, + NSUInteger dst_off, + uint32_t n); + static int ds4_metal_encode_fill_f32_rows( id cb, id buf, @@ -2722,6 +2742,23 @@ int ds4_metal_init(void) { return 0; } + fn = [library newFunctionWithName:@"kernel_get_rows_q8_0"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_get_rows_q8_0 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + + g_get_rows_q8_0_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_get_rows_q8_0_pipeline) { + fprintf(stderr, "ds4: Metal kernel_get_rows_q8_0 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + fn = [library newFunctionWithName:@"kernel_get_rows_i32"]; if (!fn) { fprintf(stderr, "ds4: Metal kernel_get_rows_i32 function not found\n"); @@ -2841,6 +2878,22 @@ int ds4_metal_init(void) { return 0; } + fn = [library newFunctionWithName:@"kernel_cpy_q8_0_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_cpy_q8_0_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_cpy_q8_0_f32_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_cpy_q8_0_f32_pipeline) { + fprintf(stderr, "ds4: Metal kernel_cpy_q8_0_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + fn = [library newFunctionWithName:@"kernel_dsv4_fp8_kv_quantize_f32"]; if (!fn) { fprintf(stderr, "ds4: Metal kernel_dsv4_fp8_kv_quantize_f32 function not found\n"); @@ -3935,12 +3988,14 @@ void ds4_metal_cleanup(void) { g_set_rows_f32_i32_pipeline = nil; g_get_rows_f32_pipeline = nil; g_get_rows_f16_pipeline = nil; + g_get_rows_q8_0_pipeline = nil; g_get_rows_i32_pipeline = nil; g_repeat_f32_pipeline = nil; g_concat_pipeline = nil; g_cpy_f32_f32_pipeline = nil; g_cpy_f32_f16_pipeline = nil; g_cpy_f16_f32_pipeline = nil; + g_cpy_q8_0_f32_pipeline = nil; g_swiglu_pipeline = nil; g_add_pipeline = nil; g_mul_pipeline = nil; @@ -4070,7 +4125,9 @@ void ds4_metal_cleanup(void) { } } -static int ds4_metal_encode_get_rows_f16( +static int ds4_metal_embed_row_layout(uint32_t, uint32_t, uint64_t *, id *); + +static int ds4_metal_encode_get_rows( id cb, id weight, NSUInteger weight_offset, @@ -4080,12 +4137,16 @@ static int ds4_metal_encode_get_rows_f16( NSUInteger out_offset, uint32_t n_vocab, uint32_t n_tokens, - uint32_t n_embd) { + uint32_t n_embd, + uint32_t weight_type) { if (!cb || !weight || !tokens || !out || n_vocab == 0 || n_tokens == 0 || n_embd == 0) { return 0; } - const uint64_t src_row_bytes = (uint64_t)n_embd * sizeof(uint16_t); + uint64_t src_row_bytes = 0; + id pipeline = nil; + if (!ds4_metal_embed_row_layout(weight_type, n_embd, &src_row_bytes, &pipeline)) return 0; + const uint64_t dst_row_bytes = (uint64_t)n_embd * sizeof(float); const uint64_t token_bytes = (uint64_t)n_tokens * sizeof(int32_t); ds4_metal_get_rows_args args = { @@ -4104,13 +4165,13 @@ static int ds4_metal_encode_get_rows_f16( }; NSUInteger nth = (NSUInteger)n_embd; - const NSUInteger max_threads = g_get_rows_f16_pipeline.maxTotalThreadsPerThreadgroup; + const NSUInteger max_threads = pipeline.maxTotalThreadsPerThreadgroup; if (nth > max_threads) nth = max_threads; if (nth == 0) nth = 1; const NSUInteger nw0 = ((NSUInteger)n_embd + nth - 1u) / nth; id enc = ds4_metal_compute_encoder(cb); - [enc setComputePipelineState:g_get_rows_f16_pipeline]; + [enc setComputePipelineState:pipeline]; [enc setBytes:&args length:sizeof(args) atIndex:0]; [enc setBuffer:weight offset:weight_offset atIndex:1]; [enc setBuffer:tokens offset:tokens_offset atIndex:2]; @@ -4168,6 +4229,33 @@ static int ds4_metal_encode_repeat_hc_embedding( return 1; } +/* Per-row stride and gather pipeline for a token-embedding tensor type. + * Returns 0 (and a NULL pipeline) if the type is not supported. */ +static int ds4_metal_embed_row_layout( + uint32_t weight_type, + uint32_t n_embd, + uint64_t *out_src_row_bytes, + id *out_pipeline) { + switch (weight_type) { + case DS4_METAL_TENSOR_F16: + *out_src_row_bytes = (uint64_t)n_embd * sizeof(uint16_t); + *out_pipeline = g_get_rows_f16_pipeline; + return 1; + case DS4_METAL_TENSOR_Q8_0: + if ((n_embd % 32) != 0) { + fprintf(stderr, "ds4: Q8_0 token embedding stride %u is not 32-aligned\n", n_embd); + return 0; + } + /* block_q8_0 is sizeof(half) + 32*sizeof(int8_t) = 34 bytes per QK8_0 elements. */ + *out_src_row_bytes = (uint64_t)(n_embd / 32) * 34; + *out_pipeline = g_get_rows_q8_0_pipeline; + return 1; + default: + fprintf(stderr, "ds4: token embedding type %u not supported by Metal embed\n", weight_type); + return 0; + } +} + int ds4_metal_embed_token_hc_tensor( ds4_metal_tensor *out_hc, const void *model_map, @@ -4176,7 +4264,8 @@ int ds4_metal_embed_token_hc_tensor( uint32_t n_vocab, uint32_t token, uint32_t n_embd, - uint32_t n_hc) { + uint32_t n_hc, + uint32_t weight_type) { if (!g_initialized && !ds4_metal_init()) return 0; if (!out_hc || !model_map || n_vocab == 0 || token >= n_vocab || n_embd == 0 || n_hc == 0) { return 0; @@ -4190,7 +4279,11 @@ int ds4_metal_embed_token_hc_tensor( return 0; } - const uint64_t weight_bytes = (uint64_t)n_vocab * n_embd * sizeof(uint16_t); + uint64_t src_row_bytes = 0; + id pipeline = nil; + if (!ds4_metal_embed_row_layout(weight_type, n_embd, &src_row_bytes, &pipeline)) return 0; + + const uint64_t weight_bytes = (uint64_t)n_vocab * src_row_bytes; if (weight_offset > model_size || weight_bytes > model_size - weight_offset) { fprintf(stderr, "ds4: Metal graph embedding range is outside the mapped model\n"); return 0; @@ -4213,7 +4306,6 @@ int ds4_metal_embed_token_hc_tensor( if (!cb) return 0; const int32_t token_i32 = (int32_t)token; - const uint64_t src_row_bytes = (uint64_t)n_embd * sizeof(uint16_t); const uint64_t dst_row_bytes = (uint64_t)n_embd * sizeof(float); ds4_metal_get_rows_args args = { .ne00t = (int32_t)n_embd, @@ -4230,12 +4322,12 @@ int ds4_metal_embed_token_hc_tensor( .nb3 = dst_row_bytes, }; NSUInteger nth = (NSUInteger)n_embd; - const NSUInteger max_threads = g_get_rows_f16_pipeline.maxTotalThreadsPerThreadgroup; + const NSUInteger max_threads = pipeline.maxTotalThreadsPerThreadgroup; if (nth > max_threads) nth = max_threads; if (nth == 0) nth = 1; const NSUInteger nw0 = ((NSUInteger)n_embd + nth - 1u) / nth; id enc = ds4_metal_compute_encoder(cb); - [enc setComputePipelineState:g_get_rows_f16_pipeline]; + [enc setComputePipelineState:pipeline]; [enc setBytes:&args length:sizeof(args) atIndex:0]; [enc setBuffer:wbuf offset:(NSUInteger)inner_offset atIndex:1]; [enc setBytes:&token_i32 length:sizeof(token_i32) atIndex:2]; @@ -4270,7 +4362,8 @@ int ds4_metal_embed_tokens_hc_tensor( uint32_t n_vocab, uint32_t n_tokens, uint32_t n_embd, - uint32_t n_hc) { + uint32_t n_hc, + uint32_t weight_type) { if (!g_initialized && !ds4_metal_init()) return 0; if (!out_hc || !tokens || !model_map || n_vocab == 0 || n_tokens == 0 || n_embd == 0 || n_hc == 0) { return 0; @@ -4288,7 +4381,11 @@ int ds4_metal_embed_tokens_hc_tensor( return 0; } - const uint64_t weight_bytes = (uint64_t)n_vocab * n_embd * sizeof(uint16_t); + uint64_t src_row_bytes = 0; + id _row_pipeline = nil; + if (!ds4_metal_embed_row_layout(weight_type, n_embd, &src_row_bytes, &_row_pipeline)) return 0; + + const uint64_t weight_bytes = (uint64_t)n_vocab * src_row_bytes; if (weight_offset > model_size || weight_bytes > model_size - weight_offset) { fprintf(stderr, "ds4: Metal graph batched embedding range is outside the mapped model\n"); return 0; @@ -4310,16 +4407,17 @@ int ds4_metal_embed_tokens_hc_tensor( id cb = ds4_metal_command_buffer(&owned); if (!cb) return 0; - if (!ds4_metal_encode_get_rows_f16(cb, - wbuf, - (NSUInteger)inner_offset, - tokbuf, - ds4_metal_tensor_offset(tokens), - g_embed_rows_buffer, - 0, - n_vocab, - n_tokens, - n_embd) || + if (!ds4_metal_encode_get_rows(cb, + wbuf, + (NSUInteger)inner_offset, + tokbuf, + ds4_metal_tensor_offset(tokens), + g_embed_rows_buffer, + 0, + n_vocab, + n_tokens, + n_embd, + weight_type) || !ds4_metal_encode_repeat_hc_embedding(cb, g_embed_rows_buffer, 0, @@ -5328,13 +5426,13 @@ int ds4_metal_matmul_f32_tensor( const ds4_metal_tensor *x, uint64_t n_tok) { if (!g_initialized && !ds4_metal_init()) return 0; - if (in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok > UINT32_MAX || n_tok != 1) return 0; + if (in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok > UINT32_MAX) return 0; @autoreleasepool { id xbuf = ds4_metal_tensor_buffer(x); id outbuf = ds4_metal_tensor_buffer(out); - const uint64_t x_bytes = in_dim * sizeof(float); - const uint64_t out_bytes = out_dim * sizeof(float); + const uint64_t x_bytes = n_tok * in_dim * sizeof(float); + const uint64_t out_bytes = n_tok * out_dim * sizeof(float); if (!xbuf || !outbuf || ds4_metal_tensor_bytes(x) < x_bytes || ds4_metal_tensor_bytes(out) < out_bytes) { @@ -5357,6 +5455,36 @@ int ds4_metal_matmul_f32_tensor( id cb = ds4_metal_command_buffer(&owned); if (!cb) return 0; + if (n_tok > 1) { + /* Multi-token prefill path: kernel_mul_mm_f32_f32 mirrors the + * existing F16/Q8_0 mul_mm_t instantiations in dense.metal. */ + const bool bc_inp = (in_dim % 32u) != 0; + const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0; + id pipeline = + ds4_metal_get_mul_mm_pipeline("kernel_mul_mm_f32_f32", bc_inp, bc_out); + if (!pipeline) return 0; + + ds4_metal_mul_mm_args args = ds4_metal_make_mm_args(in_dim, out_dim, n_tok, row_bytes); + + id enc = ds4_metal_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:wbuf offset:(NSUInteger)inner_offset atIndex:1]; + [enc setBuffer:xbuf offset:ds4_metal_tensor_offset(x) atIndex:2]; + [enc setBuffer:outbuf offset:ds4_metal_tensor_offset(out) atIndex:3]; + [enc setThreadgroupMemoryLength:(bc_out ? 8192u : 6144u) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u, + ((NSUInteger)out_dim + 63u) / 64u, + 1) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_metal_end_compute_encoder(cb, enc); + + if (!ds4_metal_finish_command_buffer(cb, owned, "F32 tensor matmul")) { + return 0; + } + return 1; + } + ds4_metal_q8_0_matvec_args mv_args = ds4_metal_make_f32_mv_args(in_dim, out_dim, 1); ds4_metal_mv_dispatch mv_dispatch = ds4_metal_make_plain_mv_dispatch(in_dim, 1); mv_args.nr0 = mv_dispatch.nr0; @@ -6146,7 +6274,15 @@ static int ds4_metal_encode_compressor_score_with_ape( uint32_t n_tokens) { if (!cb || !score_src || !score_dst || !apebuf || width == 0 || ratio == 0 || n_tokens == 0 || - (ape_type != 0u && ape_type != 1u)) { + (ape_type != 0u && ape_type != 1u && ape_type != 8u)) { + return 0; + } + /* Q8_0 ape rows must be QK8_0-aligned (the dequant kernel walks whole + * blocks and src_off must land on a block boundary). */ + if (ape_type == 8u && (width % 32u) != 0u) { + fprintf(stderr, + "ds4: Metal compressor APE Q8_0 width %u is not 32-aligned\n", + width); return 0; } @@ -6164,14 +6300,18 @@ static int ds4_metal_encode_compressor_score_with_ape( return 0; } - const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; + /* Per-row source stride. F16/F32 use natural element bytes; Q8_0 uses + * (width/32) * sizeof(block_q8_0) = (width/32) * 34 bytes. */ + const uint64_t row_src_bytes = ape_type == 8u + ? (uint64_t)(width / 32u) * 34u + : (uint64_t)width * (ape_type == 1u ? 2u : 4u); uint32_t copied_rows = 0; uint32_t pos_mod = pos0 % ratio; while (copied_rows < n_tokens) { uint32_t seg_rows = ratio - pos_mod; if (seg_rows > n_tokens - copied_rows) seg_rows = n_tokens - copied_rows; const uint32_t seg_elems = seg_rows * width; - const NSUInteger src_off = ape_offset + (NSUInteger)pos_mod * width * elem_ape; + const NSUInteger src_off = ape_offset + (NSUInteger)pos_mod * row_src_bytes; const NSUInteger dst_off = (NSUInteger)copied_rows * width * sizeof(float); int ok; if (ape_type == 1u) { @@ -6181,6 +6321,13 @@ static int ds4_metal_encode_compressor_score_with_ape( g_compressor_store_ape_buffer, dst_off, seg_elems); + } else if (ape_type == 8u) { + ok = ds4_metal_encode_cpy_q8_0_f32_1d(cb, + apebuf, + src_off, + g_compressor_store_ape_buffer, + dst_off, + seg_elems); } else { ok = ds4_metal_encode_cpy_f32_f32_1d(cb, apebuf, @@ -6277,7 +6424,7 @@ static int ds4_metal_compressor_store_one_tensor( uint32_t ratio, uint32_t pos) { if (!kv || !sc || !state_kv || !state_score || !model_map || - width == 0 || ratio == 0 || (ape_type != 0u && ape_type != 1u)) { + width == 0 || ratio == 0 || (ape_type != 0u && ape_type != 1u && ape_type != 8u)) { return 0; } @@ -6287,10 +6434,9 @@ static int ds4_metal_compressor_store_one_tensor( if (!pipeline) return 0; const uint32_t state_rows = ratio == 4u ? 2u * ratio : ratio; - const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; const uint64_t row_bytes = (uint64_t)width * sizeof(float); const uint64_t state_bytes = (uint64_t)state_rows * row_bytes; - const uint64_t ape_bytes = (uint64_t)width * ratio * elem_ape; + const uint64_t ape_bytes = ds4_metal_ape_bytes(ape_type, (uint64_t)width * ratio); if (ape_offset > model_size || ape_bytes > model_size - ape_offset || ds4_metal_tensor_bytes(kv) < row_bytes || ds4_metal_tensor_bytes(sc) < row_bytes || @@ -6352,7 +6498,7 @@ int ds4_metal_compressor_store_batch_tensor( if (!g_initialized && !ds4_metal_init()) return 0; if (!kv || !sc || !state_kv || !state_score || !model_map || head_dim == 0 || ratio == 0 || n_tokens == 0 || - (ape_type != 0u && ape_type != 1u)) { + (ape_type != 0u && ape_type != 1u && ape_type != 8u)) { return 0; } @@ -6360,10 +6506,9 @@ int ds4_metal_compressor_store_batch_tensor( const uint32_t coff = ratio == 4u ? 2u : 1u; const uint32_t width = coff * head_dim; const uint32_t state_rows = coff * ratio; - const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; const uint64_t kv_bytes = (uint64_t)n_tokens * width * sizeof(float); const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); - const uint64_t ape_bytes = (uint64_t)width * ratio * elem_ape; + const uint64_t ape_bytes = ds4_metal_ape_bytes(ape_type, (uint64_t)width * ratio); if (ape_offset > model_size || ape_bytes > model_size - ape_offset) { fprintf(stderr, "ds4: Metal compressor batch APE range is outside the mapped model\n"); @@ -6427,12 +6572,13 @@ int ds4_metal_compressor_store_batch_tensor( int ok = 1; uint32_t copied_rows = 0; uint32_t pos_mod = pos0 % ratio; + const uint64_t row_src_bytes = ds4_metal_ape_bytes(ape_type, (uint64_t)width); while (ok && copied_rows < n_tokens) { uint32_t seg_rows = ratio - pos_mod; if (seg_rows > n_tokens - copied_rows) seg_rows = n_tokens - copied_rows; const uint32_t seg_elems = seg_rows * width; const NSUInteger src_off = (NSUInteger)ape_inner + - (NSUInteger)pos_mod * width * elem_ape; + (NSUInteger)pos_mod * row_src_bytes; const NSUInteger dst_off = (NSUInteger)copied_rows * width * sizeof(float); if (ape_type == 1u) { ok = ds4_metal_encode_cpy_f16_f32_1d(cb, @@ -6441,6 +6587,13 @@ int ds4_metal_compressor_store_batch_tensor( g_compressor_store_ape_buffer, dst_off, seg_elems); + } else if (ape_type == 8u) { + ok = ds4_metal_encode_cpy_q8_0_f32_1d(cb, + apebuf, + src_off, + g_compressor_store_ape_buffer, + dst_off, + seg_elems); } else { ok = ds4_metal_encode_cpy_f32_f32_1d(cb, apebuf, @@ -6995,7 +7148,7 @@ int ds4_metal_compressor_prefill_tensor( if (!comp_cache || !state_kv || !state_score || !kv || !sc || !model_map || head_dim == 0 || ratio == 0 || n_tokens == 0 || n_rot > head_dim || (n_rot & 1u) != 0 || - (ape_type != 0u && ape_type != 1u) || + (ape_type != 0u && ape_type != 1u && ape_type != 8u) || norm_type != 0u) { return 0; } @@ -7007,11 +7160,10 @@ int ds4_metal_compressor_prefill_tensor( const uint32_t n_comp = n_tokens / ratio; const uint32_t cutoff = n_comp * ratio; const uint32_t rem = n_tokens - cutoff; - const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; const uint64_t kv_bytes = (uint64_t)n_tokens * width * sizeof(float); const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); const uint64_t comp_bytes = (uint64_t)n_comp * head_dim * sizeof(float); - const uint64_t ape_bytes = (uint64_t)width * ratio * elem_ape; + const uint64_t ape_bytes = ds4_metal_ape_bytes(ape_type, (uint64_t)width * ratio); const uint64_t norm_bytes = (uint64_t)head_dim * sizeof(float); if (ape_offset > model_size || ape_bytes > model_size - ape_offset || @@ -7350,7 +7502,7 @@ int ds4_metal_compressor_prefill_ratio4_replay_tensor( if (!comp_cache || !state_kv || !state_score || !kv || !sc || !model_map || head_dim == 0 || n_tokens == 0 || (n_tokens & 3u) != 0 || (pos0 & 3u) != 0 || n_rot > head_dim || (n_rot & 1u) != 0 || - (ape_type != 0u && ape_type != 1u) || + (ape_type != 0u && ape_type != 1u && ape_type != 8u) || norm_type != 0u) { return 0; } @@ -7360,11 +7512,10 @@ int ds4_metal_compressor_prefill_ratio4_replay_tensor( const uint32_t width = 2u * head_dim; const uint32_t state_rows = 8u; const uint32_t n_comp = n_tokens / ratio; - const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; const uint64_t kv_bytes = (uint64_t)n_tokens * width * sizeof(float); const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); const uint64_t comp_bytes = (uint64_t)n_comp * head_dim * sizeof(float); - const uint64_t ape_bytes = (uint64_t)width * ratio * elem_ape; + const uint64_t ape_bytes = ds4_metal_ape_bytes(ape_type, (uint64_t)width * ratio); const uint64_t norm_bytes = (uint64_t)head_dim * sizeof(float); if (ape_offset > model_size || ape_bytes > model_size - ape_offset || @@ -7644,7 +7795,7 @@ int ds4_metal_compressor_prefill_state_ratio4_tensor( uint32_t pos0) { if (!g_initialized && !ds4_metal_init()) return 0; if (!state_kv || !state_score || !kv_tail || !sc_tail || !model_map || - head_dim == 0 || (ape_type != 0u && ape_type != 1u)) { + head_dim == 0 || (ape_type != 0u && ape_type != 1u && ape_type != 8u)) { return 0; } @@ -7652,10 +7803,9 @@ int ds4_metal_compressor_prefill_state_ratio4_tensor( const uint32_t ratio = 4u; const uint32_t width = 2u * head_dim; const uint32_t state_rows = 8u; - const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; const uint64_t tail_bytes = (uint64_t)ratio * width * sizeof(float); const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); - const uint64_t ape_bytes = (uint64_t)ratio * width * elem_ape; + const uint64_t ape_bytes = ds4_metal_ape_bytes(ape_type, (uint64_t)ratio * width); if (ape_offset > model_size || ape_bytes > model_size - ape_offset) { fprintf(stderr, "ds4: Metal compressor prefill-state APE range is outside the mapped model\n"); @@ -7758,7 +7908,7 @@ int ds4_metal_compressor_update_tensor( if (!kv_cur || !sc_cur || !state_kv || !state_score || !comp_cache || !model_map || head_dim == 0 || ratio == 0 || n_rot > head_dim || (n_rot & 1u) != 0 || - (ape_type != 0u && ape_type != 1u) || + (ape_type != 0u && ape_type != 1u && ape_type != 8u) || norm_type != 0u) { return 0; } @@ -7768,11 +7918,10 @@ int ds4_metal_compressor_update_tensor( const uint32_t width = coff * head_dim; const uint32_t state_rows = coff * ratio; const uint32_t emit = ((pos + 1u) % ratio) == 0u ? 1u : 0u; - const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; const uint64_t kv_bytes = (uint64_t)width * sizeof(float); const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); const uint64_t comp_bytes = (uint64_t)(comp_row + (emit ? 1u : 0u)) * head_dim * sizeof(float); - const uint64_t ape_bytes = (uint64_t)width * ratio * elem_ape; + const uint64_t ape_bytes = ds4_metal_ape_bytes(ape_type, (uint64_t)width * ratio); const uint64_t norm_bytes = (uint64_t)head_dim * sizeof(float); if (ape_offset > model_size || ape_bytes > model_size - ape_offset || @@ -8493,6 +8642,34 @@ static int ds4_metal_encode_cpy_f16_f32_1d( return 1; } +/* Dequantize a contiguous Q8_0 region into F32. `n` must be a multiple of + * QK8_0 (32) and `src_off` must point at a block_q8_0 boundary in src. */ +static int ds4_metal_encode_cpy_q8_0_f32_1d( + id cb, + id src, + NSUInteger src_off, + id dst, + NSUInteger dst_off, + uint32_t n) { + if (!cb || !src || !dst || n == 0 || (n % 32u) != 0) return 0; + + ds4_metal_cpy_args args = + ds4_metal_make_cpy_1d_args(n, /*src_elem_bytes=*/0, sizeof(float)); + const NSUInteger nth = ds4_metal_cpy_threads(n, g_cpy_q8_0_f32_pipeline); + const NSUInteger groups = ((NSUInteger)n + nth - 1u) / nth; + + id enc = ds4_metal_compute_encoder(cb); + [enc setComputePipelineState:g_cpy_q8_0_f32_pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:src offset:src_off atIndex:1]; + [enc setBuffer:dst offset:dst_off atIndex:2]; + [enc dispatchThreadgroups:MTLSizeMake(groups, 1, 1) + threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + ds4_metal_end_compute_encoder(cb, enc); + + return 1; +} + static int ds4_metal_encode_fill_f16_1d( id cb, id buf, diff --git a/metal/cpy.metal b/metal/cpy.metal index 3aa00ac1..f6b10d2b 100644 --- a/metal/cpy.metal +++ b/metal/cpy.metal @@ -55,3 +55,27 @@ typedef decltype(kernel_cpy_t_t) kernel_cpy_t; template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; + +// Q8_0 -> F32 dequantizing 1D copy. Used by the compressor APE path so +// stock-recipe GGUFs that ship `*.compressor_ape.weight` as Q8_0 can be +// read through the same byte-strided copy that the F16/F32 ape paths use. +// args.ne00 is the total element count (must be divisible by QK8_0 = 32); +// src is a packed Q8_0 region and dst is contiguous F32. +kernel void kernel_cpy_q8_0_f32( + constant ds4_metal_args_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int n = (int) args.ne00; + const int gid = (int)(tgpig.x * ntg.x + tiitg); + if (gid >= n) return; + + device const block_q8_0 *blocks = (device const block_q8_0 *) src0; + const int blk = gid / QK8_0; + const int idx = gid - blk * QK8_0; + const float d = (float) blocks[blk].d; + device float *out = (device float *) dst; + out[gid] = (float) blocks[blk].qs[idx] * d; +} diff --git a/metal/dense.metal b/metal/dense.metal index a84927e9..fa922b45 100644 --- a/metal/dense.metal +++ b/metal/dense.metal @@ -1116,6 +1116,7 @@ kernel void kernel_mul_mm( typedef decltype(kernel_mul_mm) mul_mm_t; -// Host-visible prefill matmul variants for F16 and Q8_0 weights. +// Host-visible prefill matmul variants for F16, Q8_0, and F32 weights. template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; diff --git a/metal/get_rows.metal b/metal/get_rows.metal index 31d7aab6..34ef0098 100644 --- a/metal/get_rows.metal +++ b/metal/get_rows.metal @@ -52,3 +52,37 @@ typedef decltype(kernel_get_rows_f) get_rows_f_t; template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f; + +// Q8_0 token-embedding gather: dequantize blocks while reading the row. +// Source row is `(ne00 / QK8_0)` `block_q8_0` records ([half scale][int8 qs[32]]); +// destination is `ne00` floats. Layout dispatch matches kernel_get_rows_f. +kernel void kernel_get_rows_q8_0( + constant ds4_metal_args_get_rows & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; + + const int32_t r = ((const device int32_t *) (src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; + + const int32_t i02 = i11; + const int32_t i03 = i12; + + auto psrc = (device const block_q8_0 *) (src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = ( device float *) (dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { + const int blk = ind / QK8_0; + const int idx = ind - blk * QK8_0; + const float d = (float) psrc[blk].d; + pdst[ind] = (float) psrc[blk].qs[idx] * d; + + break; + } +} From 431943fdeeede527d21369a78c845516193b2e82 Mon Sep 17 00:00:00 2001 From: Audrey Tang Date: Sun, 10 May 2026 08:22:58 -0400 Subject: [PATCH 2/4] fix(metal): correct Q8_0 ape compressor paths This PR's loader changes accept Q8_0 `*compressor_ape*` weights at the validator level, but two follow-on Metal paths still treat them as F16 (or fall through to F32) and produce silently wrong output, which shows up as -token spam in generation for any prompt long enough to exercise the multi-token compressor path on M-series hardware. 1. `kernel_cpy_q8_0_f32` (added in this PR for the prefill APE byte-strided dequant) compiles cleanly and follows the same block_q8_0 indexing pattern used by other working Q8_0 kernels in dense.metal, but emits silently wrong values for the actual ape shapes (4 rows x 1024 cols of block_q8_0). Confirmed by isolating the kernel: a CPU-side dequant of the same byte region matches gguf-py's `dequantize` reference byte-for-byte, while the Metal kernel's output is wrong. 2. `kernel_dsv4_compressor_store_one` (decode-time single-row store in metal/dsv4_kv.metal): only handled `ape_type == 1` (F16) and fell through to F32 for everything else, so Q8_0 ape was reading garbage at decode time. Fix: * Replace the prefill APE Q8_0 path in `ds4_metal_encode_compressor_score_with_ape` and `ds4_metal_compressor_store_batch_tensor` with a CPU-side dequant via two new helpers (`ds4_metal_half_bits_to_float` and `ds4_metal_cpu_dequant_q8_0_rows`) into a *per-call* private MTLBuffer. A per-call buffer is required because multiple CPU writes to the previously-shared `g_compressor_store_ape_buffer` within one command buffer collapse to the last write at execute time (Metal kernels run in encode order, but CPU writes don't participate in that ordering when the same scratch is reused). The per-call buffer is retained until cb completion via `addCompletedHandler` because Metal does not strongly retain buffers bound to encoders. * Add a Q8_0 branch to `kernel_dsv4_compressor_store_one` that walks block_q8_0 layout (uint16_t scale + 32 int8 quants per 34-byte block) inline. The buggy `kernel_cpy_q8_0_f32` Metal kernel is left in place but is no longer reached from the compressor paths; its registration in ds4_metal.m is harmless and a future debug session can either fix it or drop it. Tested on macOS / M-series / Metal: * make ds4-server clean (one pre-existing -Wpointer-sign warning from the unrelated MoE path). * Cyberneurova Q2_K GGUF entirely unmodified, default flags: 21-token prompt -> coherent generation ("An LLM, or Large Language Model, is a type of artificial intelligence"). Previously this prompt generated a few coherent tokens then token spam. * Pre-harmonized variant (token_embd / hc / compressor / indexer all F16): still works byte-for-byte the same as before this fix; no F16 / F32 path regressions. --- ds4_metal.m | 268 +++++++++++++++++++++++++++++++++----------- metal/cpy.metal | 22 +++- metal/dsv4_kv.metal | 9 ++ 3 files changed, 228 insertions(+), 71 deletions(-) diff --git a/ds4_metal.m b/ds4_metal.m index c9e63ac7..88d6dd41 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -1394,6 +1394,61 @@ static uint64_t ds4_metal_ape_bytes(uint32_t ape_type, uint64_t n_elems) { return n_elems * (ape_type == 1u ? 2u : 4u); } +/* Convert IEEE-754 half (bits) to float on the CPU. Used by the Q8_0 ape + * CPU-side dequant; the Metal-side `kernel_cpy_q8_0_f32` path turned out to + * produce silently wrong output on M5 Max for the compressor APE shapes we + * hit, so the pre-add scratch is filled from the host instead. */ +static float ds4_metal_half_bits_to_float(uint16_t bits) { + const uint32_t sign = (uint32_t)(bits >> 15) & 0x1u; + const uint32_t exph = (uint32_t)(bits >> 10) & 0x1fu; + const uint32_t mant = (uint32_t)bits & 0x3ffu; + uint32_t f32bits; + if (exph == 0u) { + if (mant == 0u) { + f32bits = sign << 31; + } else { + uint32_t e = 1u, m = mant; + while ((m & 0x400u) == 0u) { m <<= 1; e++; } + m &= 0x3ffu; + f32bits = (sign << 31) | ((127u - 15u - e + 1u) << 23) | (m << 13); + } + } else if (exph == 0x1fu) { + f32bits = (sign << 31) | (0xffu << 23) | (mant << 13); + } else { + f32bits = (sign << 31) | ((exph + 127u - 15u) << 23) | (mant << 13); + } + float v; + memcpy(&v, &f32bits, sizeof(v)); + return v; +} + +/* Dequantize `row_count` rows of `row_elems` Q8_0-stored values from `src` + * into contiguous F32 at `dst`. `row_elems` must be a multiple of QK8_0=32. + * Source layout is contiguous block_q8_0 records (no per-row padding), so + * one row is `row_elems / 32` blocks of 34 bytes each. */ +static void ds4_metal_cpu_dequant_q8_0_rows( + float *dst, + const uint8_t *src, + uint32_t row_count, + uint32_t row_elems) { + const uint32_t blocks_per_row = row_elems / 32u; + for (uint32_t r = 0; r < row_count; r++) { + const uint8_t *row = src + (uint64_t)r * (uint64_t)blocks_per_row * 34u; + float *out = dst + (uint64_t)r * row_elems; + for (uint32_t b = 0; b < blocks_per_row; b++) { + const uint8_t *blk = row + (uint64_t)b * 34u; + uint16_t scale_bits; + memcpy(&scale_bits, blk, sizeof(scale_bits)); + const float scale = ds4_metal_half_bits_to_float(scale_bits); + const int8_t *qs = (const int8_t *)(blk + 2); + float *out_blk = out + (uint64_t)b * 32u; + for (uint32_t k = 0; k < 32u; k++) { + out_blk[k] = (float)qs[k] * scale; + } + } + } +} + static float ds4_metal_positive_infinity(void) { union { uint32_t u; float f; } v = { 0x7f800000u }; return v.f; @@ -6305,40 +6360,88 @@ static int ds4_metal_encode_compressor_score_with_ape( const uint64_t row_src_bytes = ape_type == 8u ? (uint64_t)(width / 32u) * 34u : (uint64_t)width * (ape_type == 1u ? 2u : 4u); - uint32_t copied_rows = 0; - uint32_t pos_mod = pos0 % ratio; - while (copied_rows < n_tokens) { - uint32_t seg_rows = ratio - pos_mod; - if (seg_rows > n_tokens - copied_rows) seg_rows = n_tokens - copied_rows; - const uint32_t seg_elems = seg_rows * width; - const NSUInteger src_off = ape_offset + (NSUInteger)pos_mod * row_src_bytes; - const NSUInteger dst_off = (NSUInteger)copied_rows * width * sizeof(float); - int ok; - if (ape_type == 1u) { - ok = ds4_metal_encode_cpy_f16_f32_1d(cb, - apebuf, - src_off, - g_compressor_store_ape_buffer, - dst_off, - seg_elems); - } else if (ape_type == 8u) { - ok = ds4_metal_encode_cpy_q8_0_f32_1d(cb, - apebuf, - src_off, - g_compressor_store_ape_buffer, - dst_off, - seg_elems); - } else { - ok = ds4_metal_encode_cpy_f32_f32_1d(cb, - apebuf, - src_off, - g_compressor_store_ape_buffer, - dst_off, - seg_elems); + + /* For Q8_0 ape, dequant on CPU into a *per-call* private buffer. The + * Metal-side `kernel_cpy_q8_0_f32` path produced silently wrong output on + * M5 Max for compressor APE shapes; using the shared g_compressor_store_ape_buffer + * with CPU writes also produces wrong output because multiple CPU writes + * to the same scratch in one command buffer collapse to the last write + * at execute time (Metal kernels run in encode order, but CPU writes + * don't participate in that ordering when the same scratch is reused). + * + * The local buffer is retained until cb completes via addCompletedHandler + * because Metal does NOT strongly retain buffers bound to encoders. */ + if (ape_type == 8u) { + const uint8_t *apebytes = (const uint8_t *) [apebuf contents]; + if (!apebytes) { + fprintf(stderr, "ds4: Metal compressor APE Q8_0: source buffer has no CPU contents\n"); + return 0; + } + const NSUInteger ape_call_bytes = (NSUInteger)total_elems * sizeof(float); + id ape_call_buf = [g_device newBufferWithLength:ape_call_bytes + options:MTLResourceStorageModeShared]; + if (!ape_call_buf) { + fprintf(stderr, "ds4: Metal compressor APE Q8_0: per-call scratch alloc failed\n"); + return 0; + } + float *scratch = (float *) [ape_call_buf contents]; + uint32_t copied_rows = 0; + uint32_t pos_mod = pos0 % ratio; + while (copied_rows < n_tokens) { + uint32_t seg_rows = ratio - pos_mod; + if (seg_rows > n_tokens - copied_rows) seg_rows = n_tokens - copied_rows; + const uint64_t src_off = (uint64_t)ape_offset + (uint64_t)pos_mod * row_src_bytes; + ds4_metal_cpu_dequant_q8_0_rows(scratch + (uint64_t)copied_rows * width, + apebytes + src_off, + seg_rows, + width); + copied_rows += seg_rows; + pos_mod = 0; + } + const int add_ok = ds4_metal_encode_add_f32_1d(cb, + score_src, + score_src_offset, + ape_call_buf, + 0, + score_dst, + score_dst_offset, + total_elems); + if (add_ok) { + /* Keep ape_call_buf alive until the GPU is done reading it. */ + [cb addCompletedHandler:^(id _Nonnull __unused done) { + (void) ape_call_buf; + }]; + } + return add_ok; + } else { + uint32_t copied_rows = 0; + uint32_t pos_mod = pos0 % ratio; + while (copied_rows < n_tokens) { + uint32_t seg_rows = ratio - pos_mod; + if (seg_rows > n_tokens - copied_rows) seg_rows = n_tokens - copied_rows; + const uint32_t seg_elems = seg_rows * width; + const NSUInteger src_off = ape_offset + (NSUInteger)pos_mod * row_src_bytes; + const NSUInteger dst_off = (NSUInteger)copied_rows * width * sizeof(float); + int ok; + if (ape_type == 1u) { + ok = ds4_metal_encode_cpy_f16_f32_1d(cb, + apebuf, + src_off, + g_compressor_store_ape_buffer, + dst_off, + seg_elems); + } else { + ok = ds4_metal_encode_cpy_f32_f32_1d(cb, + apebuf, + src_off, + g_compressor_store_ape_buffer, + dst_off, + seg_elems); + } + if (!ok) return 0; + copied_rows += seg_rows; + pos_mod = 0; } - if (!ok) return 0; - copied_rows += seg_rows; - pos_mod = 0; } return ds4_metal_encode_add_f32_1d(cb, @@ -6570,52 +6673,83 @@ int ds4_metal_compressor_store_batch_tensor( } int ok = 1; - uint32_t copied_rows = 0; - uint32_t pos_mod = pos0 % ratio; const uint64_t row_src_bytes = ds4_metal_ape_bytes(ape_type, (uint64_t)width); - while (ok && copied_rows < n_tokens) { - uint32_t seg_rows = ratio - pos_mod; - if (seg_rows > n_tokens - copied_rows) seg_rows = n_tokens - copied_rows; - const uint32_t seg_elems = seg_rows * width; - const NSUInteger src_off = (NSUInteger)ape_inner + - (NSUInteger)pos_mod * row_src_bytes; - const NSUInteger dst_off = (NSUInteger)copied_rows * width * sizeof(float); - if (ape_type == 1u) { - ok = ds4_metal_encode_cpy_f16_f32_1d(cb, - apebuf, - src_off, - g_compressor_store_ape_buffer, - dst_off, - seg_elems); - } else if (ape_type == 8u) { - ok = ds4_metal_encode_cpy_q8_0_f32_1d(cb, - apebuf, - src_off, - g_compressor_store_ape_buffer, - dst_off, - seg_elems); - } else { - ok = ds4_metal_encode_cpy_f32_f32_1d(cb, - apebuf, - src_off, - g_compressor_store_ape_buffer, - dst_off, - seg_elems); + id ape_call_buf = nil; + id ape_consume_buf = g_compressor_store_ape_buffer; + NSUInteger ape_consume_off = 0; + if (ape_type == 8u) { + /* CPU-side dequant for Q8_0: write to a fresh per-call shared + * buffer so the data survives encode-time race vs subsequent + * kernel reads (see ds4_metal_encode_compressor_score_with_ape). */ + const uint8_t *apebytes = (const uint8_t *) [apebuf contents]; + const NSUInteger ape_call_bytes = (NSUInteger)total_elems * sizeof(float); + ape_call_buf = [g_device newBufferWithLength:ape_call_bytes + options:MTLResourceStorageModeShared]; + float *scratch = ape_call_buf ? (float *) [ape_call_buf contents] : NULL; + if (!apebytes || !scratch) { + fprintf(stderr, "ds4: Metal compressor APE Q8_0: per-call scratch alloc failed\n"); + ok = 0; + } + ape_consume_buf = ape_call_buf; + uint32_t copied_rows = 0; + uint32_t pos_mod = pos0 % ratio; + while (ok && copied_rows < n_tokens) { + uint32_t seg_rows = ratio - pos_mod; + if (seg_rows > n_tokens - copied_rows) seg_rows = n_tokens - copied_rows; + const uint64_t src_off = (uint64_t)ape_inner + (uint64_t)pos_mod * row_src_bytes; + ds4_metal_cpu_dequant_q8_0_rows(scratch + (uint64_t)copied_rows * width, + apebytes + src_off, + seg_rows, + width); + copied_rows += seg_rows; + pos_mod = 0; + } + } else { + uint32_t copied_rows = 0; + uint32_t pos_mod = pos0 % ratio; + while (ok && copied_rows < n_tokens) { + uint32_t seg_rows = ratio - pos_mod; + if (seg_rows > n_tokens - copied_rows) seg_rows = n_tokens - copied_rows; + const uint32_t seg_elems = seg_rows * width; + const NSUInteger src_off = (NSUInteger)ape_inner + + (NSUInteger)pos_mod * row_src_bytes; + const NSUInteger dst_off = (NSUInteger)copied_rows * width * sizeof(float); + if (ape_type == 1u) { + ok = ds4_metal_encode_cpy_f16_f32_1d(cb, + apebuf, + src_off, + g_compressor_store_ape_buffer, + dst_off, + seg_elems); + } else { + ok = ds4_metal_encode_cpy_f32_f32_1d(cb, + apebuf, + src_off, + g_compressor_store_ape_buffer, + dst_off, + seg_elems); + } + copied_rows += seg_rows; + pos_mod = 0; } - copied_rows += seg_rows; - pos_mod = 0; } if (ok) { ok = ds4_metal_encode_add_f32_1d(cb, scbuf, ds4_metal_tensor_offset(sc), - g_compressor_store_ape_buffer, - 0, + ape_consume_buf, + ape_consume_off, g_compressor_store_score_buffer, 0, total_elems); } + if (ok && ape_call_buf) { + /* Retain the per-call ape buffer until the GPU is done. */ + [cb addCompletedHandler:^(id _Nonnull __unused done) { + (void) ape_call_buf; + }]; + } if (ok) { ok = ds4_metal_encode_set_rows_f32_i32(cb, state_kv, diff --git a/metal/cpy.metal b/metal/cpy.metal index f6b10d2b..b2558bab 100644 --- a/metal/cpy.metal +++ b/metal/cpy.metal @@ -60,7 +60,15 @@ template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t< // stock-recipe GGUFs that ship `*.compressor_ape.weight` as Q8_0 can be // read through the same byte-strided copy that the F16/F32 ape paths use. // args.ne00 is the total element count (must be divisible by QK8_0 = 32); -// src is a packed Q8_0 region and dst is contiguous F32. +// src is a packed Q8_0 region (sizeof(block_q8_0) = 34 bytes per QK8_0 elements) +// and dst is contiguous F32. +// +// Uses explicit byte arithmetic instead of `block_q8_0 *` indexing because +// the GGUF byte stride (34) does not match Metal's natural struct alignment +// for `block_q8_0` (which would be padded to a multiple of `alignof(half)` +// in some cases). Each thread handles one output element and re-reads the +// half scale from its block's first two bytes; that's redundant but cheap +// and the compressor APE total element count is tiny (a few thousand). kernel void kernel_cpy_q8_0_f32( constant ds4_metal_args_cpy & args, device const char * src0, @@ -72,10 +80,16 @@ kernel void kernel_cpy_q8_0_f32( const int gid = (int)(tgpig.x * ntg.x + tiitg); if (gid >= n) return; - device const block_q8_0 *blocks = (device const block_q8_0 *) src0; + constexpr int BLOCK_BYTES = 34; const int blk = gid / QK8_0; const int idx = gid - blk * QK8_0; - const float d = (float) blocks[blk].d; + device const char *bp = src0 + (uint64_t)blk * BLOCK_BYTES; + half d_h; + /* half scale lives at the first 2 bytes of the block */ + thread char *dp = (thread char *) &d_h; + dp[0] = bp[0]; dp[1] = bp[1]; + const float d = (float) d_h; + const int8_t q = (int8_t) bp[2 + idx]; device float *out = (device float *) dst; - out[gid] = (float) blocks[blk].qs[idx] * d; + out[gid] = (float) q * d; } diff --git a/metal/dsv4_kv.metal b/metal/dsv4_kv.metal index 89bd7d3a..80d5846c 100644 --- a/metal/dsv4_kv.metal +++ b/metal/dsv4_kv.metal @@ -218,6 +218,15 @@ kernel void kernel_dsv4_compressor_store_one( float ape_v; if (args.ape_type == 1u) { ape_v = (float)(((device const half *)ape)[ape_i]); + } else if (args.ape_type == 8u) { + /* Q8_0: 32 elements per 34-byte block (uint16_t scale + 32 int8 quants). */ + const uint blk = ape_i / 32u; + const uint idx = ape_i - blk * 32u; + device const uchar *bp = (device const uchar *)ape + (uint64_t)blk * 34u; + const ushort sb = (ushort)bp[0] | ((ushort)bp[1] << 8); + const half d = as_type(sb); + const int q = (int)(int8_t)bp[2u + idx]; + ape_v = (float)q * (float)d; } else { ape_v = ((device const float *)ape)[ape_i]; } From c2144e50a397ed42554f7226d119f2ca12ab1538 Mon Sep 17 00:00:00 2001 From: Audrey Tang Date: Sun, 10 May 2026 14:20:40 -0400 Subject: [PATCH 3/4] fix(loader): relax indexer decode F16-only validators to accept Q8_0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The decode-time indexer code at metal_graph_encode_decode_layer (ds4.c:9082-9095) still has two F16-only validators on indexer_attn_q_b and indexer_proj that I missed in the initial loader pass. These validators only fire after `g->layer_n_comp[il] > decode_top_k` — i.e. once the compressor has accumulated more rows than the decode-time top-k. For short generations the path isn't reached; for ~400+ token generations on stock-recipe (Q8_0) GGUFs the validator trips and the request finishes with finish_reason="error" / "Metal decode failed". The downstream calls already use metal_graph_matmul_plain_tensor (which dispatches to ds4_metal_matmul_q8_0_tensor for Q8_0). The loader-time validator at line 2211-2212 already uses tensor_expect_dispatch_layout, which accepts F16/F32/Q8_0. Only these runtime guards were stuck on F16. Reproducer (cyberneurova Q2_K, default flags): a "write a long story" prompt that generates ~800 tokens hits the validator after ~400 tokens and the request errors out. After this fix, the same prompt streams 800+ tokens cleanly. --- ds4.c | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ds4.c b/ds4.c index 0349b455..57472a75 100644 --- a/ds4.c +++ b/ds4.c @@ -9080,17 +9080,19 @@ static bool metal_graph_encode_decode_layer( if (ok && g->layer_n_comp[il] > decode_top_k) { const uint64_t indexer_q_dim = (uint64_t)DS4_N_INDEXER_HEAD * DS4_N_INDEXER_HEAD_DIM; if (!layer->indexer_attn_q_b || - layer->indexer_attn_q_b->type != DS4_TENSOR_F16 || + (layer->indexer_attn_q_b->type != DS4_TENSOR_F16 && + layer->indexer_attn_q_b->type != DS4_TENSOR_Q8_0) || layer->indexer_attn_q_b->dim[0] != q_rank || layer->indexer_attn_q_b->dim[1] != indexer_q_dim) { - fprintf(stderr, "ds4: Metal graph indexer q projection expects F16 weights\n"); + fprintf(stderr, "ds4: Metal graph indexer q projection expects F16 or Q8_0 weights\n"); ok = false; } if (ok && (!layer->indexer_proj || - layer->indexer_proj->type != DS4_TENSOR_F16 || + (layer->indexer_proj->type != DS4_TENSOR_F16 && + layer->indexer_proj->type != DS4_TENSOR_Q8_0) || layer->indexer_proj->dim[0] != DS4_N_EMBD || layer->indexer_proj->dim[1] != DS4_N_INDEXER_HEAD)) { - fprintf(stderr, "ds4: Metal graph indexer weight projection expects F16 weights\n"); + fprintf(stderr, "ds4: Metal graph indexer weight projection expects F16 or Q8_0 weights\n"); ok = false; } if (ok) ok = metal_graph_matmul_plain_tensor(g->indexer_q, model, layer->indexer_attn_q_b, q_rank, indexer_q_dim, g->qr_norm, 1); From d624188dc0cce63854c366e757e510e8fb9982ba Mon Sep 17 00:00:00 2001 From: Audrey Tang Date: Sun, 10 May 2026 14:47:01 -0400 Subject: [PATCH 4/4] chore(metal): drop dead Q8_0->F32 cpy encoder The two callers of ds4_metal_encode_cpy_q8_0_f32_1d were removed in 79b08bb (switched to CPU-side dequant to avoid an encode-time race on the shared compressor scratch buffer), leaving the function unused and tripping -Wunused-function on stock Make builds. Co-Authored-By: Claude Opus 4.7 (1M context) --- ds4_metal.m | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/ds4_metal.m b/ds4_metal.m index f6cde680..95e6a27f 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -1537,14 +1537,6 @@ static int ds4_metal_encode_cpy_f16_f32_1d( NSUInteger dst_off, uint32_t n); -static int ds4_metal_encode_cpy_q8_0_f32_1d( - id cb, - id src, - NSUInteger src_off, - id dst, - NSUInteger dst_off, - uint32_t n); - static int ds4_metal_encode_fill_f32_rows( id cb, id buf, @@ -8796,34 +8788,6 @@ static int ds4_metal_encode_cpy_f16_f32_1d( return 1; } -/* Dequantize a contiguous Q8_0 region into F32. `n` must be a multiple of - * QK8_0 (32) and `src_off` must point at a block_q8_0 boundary in src. */ -static int ds4_metal_encode_cpy_q8_0_f32_1d( - id cb, - id src, - NSUInteger src_off, - id dst, - NSUInteger dst_off, - uint32_t n) { - if (!cb || !src || !dst || n == 0 || (n % 32u) != 0) return 0; - - ds4_metal_cpy_args args = - ds4_metal_make_cpy_1d_args(n, /*src_elem_bytes=*/0, sizeof(float)); - const NSUInteger nth = ds4_metal_cpy_threads(n, g_cpy_q8_0_f32_pipeline); - const NSUInteger groups = ((NSUInteger)n + nth - 1u) / nth; - - id enc = ds4_metal_compute_encoder(cb); - [enc setComputePipelineState:g_cpy_q8_0_f32_pipeline]; - [enc setBytes:&args length:sizeof(args) atIndex:0]; - [enc setBuffer:src offset:src_off atIndex:1]; - [enc setBuffer:dst offset:dst_off atIndex:2]; - [enc dispatchThreadgroups:MTLSizeMake(groups, 1, 1) - threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - ds4_metal_end_compute_encoder(cb, enc); - - return 1; -} - static int ds4_metal_encode_fill_f16_1d( id cb, id buf,