From 87c6d3e763305483107dab33cf9f9f927b023f19 Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Sun, 10 May 2026 16:54:17 +0200 Subject: [PATCH 01/16] Add Metal 4 M5 scaffold --- README.md | 52 ++++ ds4.c | 1 + ds4_gpu.h | 11 + ds4_metal.m | 629 +++++++++++++++++++++++++++++++++++++++++++--- metal/dense.metal | 99 ++++++++ metal/moe.metal | 180 +++++++++++++ tests/ds4_test.c | 125 ++++++++- 7 files changed, 1059 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 4b7c69ec..63a91e88 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,8 @@ Q4 requires the larger-memory machine class, so M3 Max Q4 numbers are `N/A`. | MacBook Pro M3 Max, 128 GB | q2 | 11709 tokens | 250.11 t/s | 21.47 t/s | | MacBook Pro M3 Max, 128 GB | q4 | short | N/A | N/A | | MacBook Pro M3 Max, 128 GB | q4 | long | N/A | N/A | +| MacBook Pro M5 Max, 128 GB | q2 | short | 87.25 t/s | 34.27 t/s | +| MacBook Pro M5 Max, 128 GB | q2 | 11707 tokens | 463.44 t/s | 25.90 t/s | | Mac Studio M3 Ultra, 512 GB | q2 | short | 84.43 t/s | 36.86 t/s | | Mac Studio M3 Ultra, 512 GB | q2 | 11709 tokens | 468.03 t/s | 27.39 t/s | | Mac Studio M3 Ultra, 512 GB | q4 | short | 78.95 t/s | 35.50 t/s | @@ -194,6 +196,56 @@ exponential sweeps. Output is CSV with one row per frontier: latest prefill interval tokens/sec, generation tokens/sec at that frontier, and `kvcache_bytes`. +## Metal 4 and M5 Neural Accelerators + +The current production path is still hand-written Metal compute kernels over +`MTLBuffer` storage. That is intentional: DS4's hot path is dominated by +quantized routed-MoE matvec/matmul, sparse compressed attention, and mmap-backed +model views, which do not map cleanly to a whole-model Core ML package. + +Metal 4 is the right next target, but it should be introduced as a feature-gated +kernel backend rather than a rewrite. 
On macOS 26+ with `MTLGPUFamilyMetal4`, +Apple exposes tensor resources and Metal 4 command infrastructure that can run +machine-learning work on the same GPU timeline as compute work. On M5 hardware, +Apple describes the per-GPU-core Neural Accelerators as available to developers +through the Metal 4 Tensor APIs. `DS4_METAL_MEMORY_REPORT=1` now reports the +device, Metal 4 family support, MTL4 queue availability, and whether the device +looks like an M5 Neural Accelerator target. + +The implementation follows the same conservative shape used by llama.cpp's +current Metal backend: the tensor API is disabled by default on pre-M5/pre-A19 +devices, can be forced with `DS4_METAL_TENSOR_ENABLE=1`, and can always be +disabled with `DS4_METAL_TENSOR_DISABLE=1`. At startup ds4 compiles a tiny MPP +tensor matmul probe before it lets the main Metal shader source see +`DS4_METAL_HAS_TENSOR`, so unsupported SDK/device combinations fall back to the +legacy kernels. + +The Q8_0 prefill MPP route is enabled automatically on M5/M6/A19/A20-class +Metal 4 tensor targets and can be forced with +`DS4_METAL_MPP_ENABLE=1 ./ds4 --prompt-file README.md`. It only affects prompt +batches larger than eight tokens, falls back to the legacy kernel if the Metal 4 +tensor path is unavailable, and is covered by the isolated +`./ds4_test --metal-kernels` numeric regression. It has also passed the +long-context and official logprob-vector regressions on M5. Set +`DS4_METAL_MPP_DISABLE=1` to compare or temporarily disable the MPP route. + +The routed-MoE projections also use MPP by default on M5-class Metal 4 tensor +targets for staged prefill layers: the down projection starts at layer 2, the +gate and up projections start at layer 13. This constrained route has passed +the long-context and official logprob-vector regressions. Starting down at +layer 1, or gate/up together at layer 12, fails the long-context regression, +so the boundaries are intentionally conservative. 
+ +For the common six-routed-expert prefill shape, the down-projection expert +outputs are summed with a single Metal kernel instead of five chained add +passes. Set `DS4_METAL_MOE_SUM6_DISABLE=1` to compare or temporarily disable +that fused sum route. + +The attention-output low-projection also uses MPP by default on Metal 4 tensor +targets for full 32-token tiles, falling back to the existing indexed simdgroup +kernel for partial tiles. Set `DS4_METAL_MPP_ATTN_OUT_DISABLE=1` to isolate or +temporarily disable this route. + ## CLI One-shot prompt: diff --git a/ds4.c b/ds4.c index 51410e33..c0866bc3 100644 --- a/ds4.c +++ b/ds4.c @@ -12446,6 +12446,7 @@ static bool metal_graph_encode_layer_ffn_batch( DS4_N_EXPERT_USED, DS4_SWIGLU_CLAMP_EXP, g->batch_ffn_norm, + il, n_tokens, &g->batch_routed_mid_is_f16) != 0; if (ok) { diff --git a/ds4_gpu.h b/ds4_gpu.h index 2d16c9c9..2b33b5ea 100644 --- a/ds4_gpu.h +++ b/ds4_gpu.h @@ -139,6 +139,16 @@ int ds4_gpu_matmul_q8_0_tensor( const ds4_gpu_tensor *x, uint64_t n_tok); +int ds4_gpu_matmul_q8_0_mpp_tensor( + ds4_gpu_tensor *out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok); + int ds4_gpu_shared_gate_up_swiglu_q8_0_tensor( ds4_gpu_tensor *gate, ds4_gpu_tensor *up, @@ -665,6 +675,7 @@ int ds4_gpu_routed_moe_batch_tensor( uint32_t n_expert, float clamp, const ds4_gpu_tensor *x, + uint32_t layer_index, uint32_t n_tokens, bool *mid_is_f16); diff --git a/ds4_metal.m b/ds4_metal.m index 0a6ae748..03a428b7 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -48,6 +48,7 @@ static id g_cpy_f16_f32_pipeline; static id g_swiglu_pipeline; static id g_add_pipeline; +static id g_moe_sum6_pipeline; static id g_mul_pipeline; static id g_rms_norm_pipeline; static id g_rms_norm_plain_pipeline; @@ -76,9 +77,6 @@ static id g_moe_mul_mv_id_q4_k_pair_pipeline; static id g_moe_mul_mv_id_q4_k_pair_swiglu_pipeline; static id 
g_moe_mul_mv_id_q4_k_sum6_pipeline; -static id g_moe_mul_mm_id_iq2_xxs_pipeline; -static id g_moe_mul_mm_id_q2_k_pipeline; -static id g_moe_mul_mm_id_q4_k_pipeline; static id g_rope_tail_batch_pipeline; static id g_dsv4_fp8_kv_quantize_pipeline; static id g_dsv4_kv_fp8_store_pipeline; @@ -140,6 +138,13 @@ static uint64_t g_model_wrap_bytes; static uint64_t g_model_wrap_max_bytes; static uint64_t g_model_residency_count; +static int g_metal4_runtime_available; +static int g_metal4_family_supported; +static int g_metal4_queue_supported; +static int g_metal4_m5_neural_accelerators_hint; +static int g_metal4_tensor_api_enabled; +static int g_metal4_tensor_api_compile_supported; +static char g_metal_device_name[128]; static NSUInteger g_flash_attn_mask_bytes; static NSUInteger g_flash_attn_pad_bytes; static NSUInteger g_flash_attn_tmp_bytes; @@ -589,14 +594,16 @@ static int ds4_gpu_map_model_views( static id ds4_gpu_get_mul_mm_id_pipeline( const char *function_name, - bool bc_inp) { - NSString *key = [NSString stringWithFormat:@"%s_bci=%d", - function_name, bc_inp ? 1 : 0]; + bool bc_inp, + bool use_mpp) { + NSString *key = [NSString stringWithFormat:@"%s_bci=%d_mpp=%d", + function_name, bc_inp ? 1 : 0, use_mpp ? 
1 : 0]; id cached = [g_pipeline_cache objectForKey:key]; if (cached) return cached; MTLFunctionConstantValues *constants = [[MTLFunctionConstantValues alloc] init]; [constants setConstantValue:&bc_inp type:MTLDataTypeBool atIndex:700]; + [constants setConstantValue:&use_mpp type:MTLDataTypeBool atIndex:702]; NSError *error = nil; NSString *name = [NSString stringWithUTF8String:function_name]; @@ -673,6 +680,245 @@ static int ds4_gpu_use_compressor_pair_nr4(void) { return enabled; } +static int ds4_gpu_device_name_contains(const char *needle); + +static int ds4_gpu_mpp_q8_0_default_target(void) { + return ds4_gpu_device_name_contains("M5") || + ds4_gpu_device_name_contains("M6") || + ds4_gpu_device_name_contains("A19") || + ds4_gpu_device_name_contains("A20"); +} + +static int ds4_gpu_mpp_q8_0_policy_enabled(void) { + if (!g_metal4_tensor_api_enabled) return 0; + if (getenv("DS4_METAL_MPP_DISABLE") != NULL) return 0; + if (getenv("DS4_METAL_MPP_ENABLE") != NULL) return 1; + return ds4_gpu_mpp_q8_0_default_target(); +} + +static int ds4_gpu_use_mpp_q8_0_matmul(void) { + static int initialized; + static int enabled; + if (!initialized) { + enabled = ds4_gpu_mpp_q8_0_policy_enabled(); + if (enabled) { + const int forced = getenv("DS4_METAL_MPP_ENABLE") != NULL; + fprintf(stderr, "ds4: Metal MPP Q8_0 prefill matmul enabled%s\n", + forced ? " by environment" : " by default"); + } + initialized = 1; + } + return enabled; +} + +static int ds4_gpu_use_mpp_f16_compressor_matmul(void) { + static int initialized; + static int enabled; + if (!initialized) { + enabled = ds4_gpu_mpp_q8_0_policy_enabled() && + getenv("DS4_METAL_MPP_F16_DISABLE") == NULL; + if (enabled) { + const int forced = getenv("DS4_METAL_MPP_ENABLE") != NULL; + fprintf(stderr, "ds4: Metal MPP F16 compressor prefill matmul enabled%s\n", + forced ? 
" by environment" : " by default"); + } + initialized = 1; + } + return enabled; +} + +static int ds4_gpu_use_mpp_attn_out_low_matmul(void) { + static int initialized; + static int enabled; + if (!initialized) { + enabled = g_metal4_tensor_api_enabled && + getenv("DS4_METAL_MPP_DISABLE") == NULL && + getenv("DS4_METAL_MPP_ATTN_OUT_DISABLE") == NULL; + if (enabled) { + fprintf(stderr, "ds4: Metal MPP attention-output low projection enabled by default\n"); + } + initialized = 1; + } + return enabled; +} + +enum { + DS4_METAL_MOE_MPP_GATE = 1 << 0, + DS4_METAL_MOE_MPP_UP = 1 << 1, + DS4_METAL_MOE_MPP_DOWN = 1 << 2, + + DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER = 13, + DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER = 13, + DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER = 2, +}; + +static int ds4_gpu_mpp_routed_moe_default_target(void) { + return ds4_gpu_device_name_contains("M5"); +} + +static int ds4_gpu_mpp_routed_moe_default_policy(void) { + return g_metal4_tensor_api_enabled && + getenv("DS4_METAL_MPP_DISABLE") == NULL && + ds4_gpu_mpp_routed_moe_default_target(); +} + +static int ds4_gpu_mpp_routed_moe_stage_mask(void) { + static int initialized; + static int mask; + if (!initialized) { + if (ds4_gpu_mpp_routed_moe_default_policy()) { + mask = DS4_METAL_MOE_MPP_GATE | DS4_METAL_MOE_MPP_UP | DS4_METAL_MOE_MPP_DOWN; + } + if (mask) { + fprintf(stderr, "ds4: Metal MPP routed MoE projections enabled by default for staged prefill layers\n"); + } + initialized = 1; + } + return mask; +} + +static int ds4_gpu_mpp_routed_moe_mask_for_layer(uint32_t layer_index) { + const int requested_mask = ds4_gpu_mpp_routed_moe_stage_mask(); + if (!requested_mask) return 0; + + if (ds4_gpu_mpp_routed_moe_default_policy()) { + static int initialized; + if (!initialized) { + fprintf(stderr, + "ds4: Metal MPP routed MoE default ranges down=%d..end up=%d..end gate=%d..end\n", + DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER, + DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER, + DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER); + initialized = 1; + } 
+ int mask = 0; + if (layer_index >= DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER) mask |= DS4_METAL_MOE_MPP_DOWN; + if (layer_index >= DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER) mask |= DS4_METAL_MOE_MPP_UP; + if (layer_index >= DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER) mask |= DS4_METAL_MOE_MPP_GATE; + return mask & requested_mask; + } + + return 0; +} + +static void ds4_gpu_warn_mpp_fallback(void) { + static int warned; + if (!warned) { + fprintf(stderr, "ds4: Metal MPP prefill matmul unavailable; falling back to legacy kernel\n"); + warned = 1; + } +} + +static int ds4_gpu_device_name_contains(const char *needle) { + return g_metal_device_name[0] != '\0' && strstr(g_metal_device_name, needle) != NULL; +} + +static int ds4_gpu_compile_tensor_probe(void) { +#if defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 260000 + if (!g_device) return 0; + if (@available(macOS 26.0, *)) { + const char *src = + "#include \n" + "#include \n" + "#include \n" + "using namespace metal;\n" + "using namespace mpp::tensor_ops;\n" + "kernel void ds4_tensor_probe(\n" + " tensor> A [[buffer(0)]],\n" + " tensor> B [[buffer(1)]],\n" + " device float *C [[buffer(2)]],\n" + " uint2 tgid [[threadgroup_position_in_grid]]) {\n" + " auto tA = A.slice(0, (int)tgid.y);\n" + " auto tB = B.slice((int)tgid.x, 0);\n" + " matmul2d> mm;\n" + " auto cT = mm.get_destination_cooperative_tensor();\n" + " auto sA = tA.slice(0, 0);\n" + " auto sB = tB.slice(0, 0);\n" + " mm.run(sB, sA, cT);\n" + " auto tC = tensor, tensor_inline>(C, dextents(16, 16));\n" + " cT.store(tC);\n" + "}\n"; + + NSError *error = nil; + NSString *source = [NSString stringWithUTF8String:src]; + id probe_library = [g_device newLibraryWithSource:source options:[MTLCompileOptions new] error:&error]; + if (!probe_library) { + fprintf(stderr, "ds4: Metal 4 tensor API probe compile failed: %s\n", + error ? 
[[error localizedDescription] UTF8String] : "(unknown)"); + return 0; + } + id fn = [probe_library newFunctionWithName:@"ds4_tensor_probe"]; + if (!fn) { + fprintf(stderr, "ds4: Metal 4 tensor API probe function missing\n"); + return 0; + } + error = nil; + id pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!pipeline) { + fprintf(stderr, "ds4: Metal 4 tensor API probe pipeline failed: %s\n", + error ? [[error localizedDescription] UTF8String] : "(unknown)"); + return 0; + } + return 1; + } +#endif + return 0; +} + +static void ds4_gpu_detect_metal4_features(void) { + g_metal4_runtime_available = 0; + g_metal4_family_supported = 0; + g_metal4_queue_supported = 0; + g_metal4_m5_neural_accelerators_hint = 0; + g_metal4_tensor_api_enabled = 0; + g_metal4_tensor_api_compile_supported = 0; + g_metal_device_name[0] = '\0'; + + if (!g_device) return; + + const char *name = [[g_device name] UTF8String]; + if (name) { + snprintf(g_metal_device_name, sizeof(g_metal_device_name), "%s", name); + } + +#if defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 260000 + if (@available(macOS 26.0, *)) { + g_metal4_runtime_available = 1; + g_metal4_family_supported = [g_device supportsFamily:MTLGPUFamilyMetal4] ? 1 : 0; + g_metal4_queue_supported = [g_device respondsToSelector:@selector(newMTL4CommandQueue)] ? 1 : 0; + + /* + * Apple does not currently expose a separate "Neural Accelerator" bit + * through Metal. On public M5 systems the hardware signal is the device + * generation plus Metal 4 support, so keep this as a conservative hint + * for diagnostics and future opt-in MPP/tensor kernels. 
+ */ + if (g_metal4_family_supported && ds4_gpu_device_name_contains("M5")) { + g_metal4_m5_neural_accelerators_hint = 1; + } + + if (g_metal4_family_supported && getenv("DS4_METAL_TENSOR_DISABLE") == NULL) { + const int explicit_enable = getenv("DS4_METAL_TENSOR_ENABLE") != NULL; + const int default_enable = + ds4_gpu_device_name_contains("M5") || + ds4_gpu_device_name_contains("M6") || + ds4_gpu_device_name_contains("A19") || + ds4_gpu_device_name_contains("A20"); + + if (explicit_enable || default_enable) { + g_metal4_tensor_api_compile_supported = ds4_gpu_compile_tensor_probe(); + g_metal4_tensor_api_enabled = g_metal4_tensor_api_compile_supported; + if (!g_metal4_tensor_api_enabled) { + fprintf(stderr, "ds4: Metal 4 tensor API probe failed; using legacy Metal kernels\n"); + } + } else { + fprintf(stderr, "ds4: Metal 4 tensor API disabled for pre-M5/pre-A19 devices (set DS4_METAL_TENSOR_ENABLE=1 to experiment)\n"); + } + } + } +#endif +} + static int ds4_gpu_warm_model_views(void) { if (g_model_view_count == 0) return 1; @@ -1112,6 +1358,19 @@ void ds4_gpu_print_memory_report(const char *label) { "ds4: model residency requests %llu%s\n", (unsigned long long)g_model_residency_count, getenv("DS4_METAL_NO_RESIDENCY") != NULL ? " (disabled)" : ""); + fprintf(stderr, + "ds4: device %s, Metal 4 runtime %s, family %s, MTL4 queue %s, tensor API %s, M5 neural accelerators %s\n", + g_metal_device_name[0] ? g_metal_device_name : "(unknown)", + g_metal4_runtime_available ? "yes" : "no", + g_metal4_family_supported ? "yes" : "no", + g_metal4_queue_supported ? "yes" : "no", + g_metal4_tensor_api_enabled ? "enabled" : + (g_metal4_tensor_api_compile_supported ? "available" : "disabled"), + g_metal4_m5_neural_accelerators_hint ? "likely" : "not detected"); + fprintf(stderr, + "ds4: MPP Q8_0 prefill %s%s\n", + ds4_gpu_mpp_q8_0_policy_enabled() ? "enabled" : "disabled", + getenv("DS4_METAL_MPP_DISABLE") != NULL ? 
" (disabled by DS4_METAL_MPP_DISABLE)" : "");
     fprintf(stderr,
             "ds4: scratch %.2f MiB (flash mask %.2f, pad %.2f, tmp %.2f, blk %.2f, ring %.2f, kv %.2f, compressor %.2f, router %.2f, indexer %.2f, moe %.2f, f16 %.2f, raw-store %.2f)\n",
             ds4_gpu_mib(scratch),
@@ -1154,7 +1413,14 @@ void ds4_gpu_set_quality(bool quality) {
 
 static const char *ds4_gpu_source =
 "#include <metal_stdlib>\n"
+"#ifdef DS4_METAL_HAS_TENSOR\n"
+"#include <metal_tensor>\n"
+"#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>\n"
+"#endif\n"
 "using namespace metal;\n"
+"#ifdef DS4_METAL_HAS_TENSOR\n"
+"using namespace mpp::tensor_ops;\n"
+"#endif\n"
 "\n"
 "#define MAX(x, y) ((x) > (y) ? (x) : (y))\n"
 "#define MIN(x, y) ((x) < (y) ? (x) : (y))\n"
@@ -2191,6 +2457,17 @@ static int ds4_gpu_encode_attn_out_low_q8_direct(
     NSUInteger threadgroup_bytes,
     NSUInteger nsg);
 
+static int ds4_gpu_encode_attn_out_low_q8_mpp(
+    id<MTLCommandBuffer> cb,
+    id<MTLComputePipelineState> pipeline,
+    const ds4_gpu_mul_mm_id_args *mm_args,
+    id<MTLBuffer> src0,
+    NSUInteger src0_off,
+    id<MTLBuffer> src1,
+    NSUInteger src1_off,
+    id<MTLBuffer> dst,
+    NSUInteger dst_off);
+
 static ds4_gpu_mul_mm_id_map_args ds4_gpu_make_mul_mm_id_map_args(
     uint32_t src0_cols,
     uint32_t src0_experts,
@@ -2654,6 +2931,13 @@ static int ds4_gpu_encode_rope_tail_inplace(
     float clamp_value;
 } ds4_gpu_dsv4_moe_swiglu_weight_args;
 
+typedef struct {
+    uint32_t width;
+    uint32_t tokens;
+    uint64_t src_token_stride;
+    uint64_t dst_token_stride;
+} ds4_gpu_dsv4_moe_sum6_args;
+
 /* Compile the single in-repo Metal source and create the pipelines that every
  * session uses. 
Shape-dependent kernels with function constants are built * lazily by the small ds4_gpu_get_* caches, so startup stays predictable @@ -2668,6 +2952,7 @@ int ds4_gpu_init(void) { return 0; } ds4_gpu_print_device_summary(); + ds4_gpu_detect_metal4_features(); g_queue = [g_device newCommandQueue]; if (!g_queue) { @@ -2698,6 +2983,10 @@ int ds4_gpu_init(void) { return 0; } MTLCompileOptions *options = [MTLCompileOptions new]; + if (g_metal4_tensor_api_enabled) { + options.preprocessorMacros = @{ @"DS4_METAL_HAS_TENSOR": @"1" }; + fprintf(stderr, "ds4: Metal 4 tensor API enabled for MPP tensor kernels\n"); + } id library = [g_device newLibraryWithSource:source options:options error:&error]; if (!library) { fprintf(stderr, "ds4: Metal shader compilation failed: %s\n", @@ -2926,6 +3215,23 @@ int ds4_gpu_init(void) { return 0; } + fn = [library newFunctionWithName:@"kernel_dsv4_moe_sum6_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_dsv4_moe_sum6_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + + g_moe_sum6_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_moe_sum6_pipeline) { + fprintf(stderr, "ds4: Metal kernel_dsv4_moe_sum6_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + MTLFunctionConstantValues *bin_constants = [[MTLFunctionConstantValues alloc] init]; int16_t bin_op = 0; int16_t bin_f = 1; @@ -3971,6 +4277,7 @@ void ds4_gpu_cleanup(void) { g_cpy_f16_f32_pipeline = nil; g_swiglu_pipeline = nil; g_add_pipeline = nil; + g_moe_sum6_pipeline = nil; g_mul_pipeline = nil; g_bin_mul_scalar_pipeline = nil; g_bin_div_row_pipeline = nil; @@ -3999,9 +4306,6 @@ void ds4_gpu_cleanup(void) { g_moe_mul_mv_id_q4_k_pair_pipeline = nil; g_moe_mul_mv_id_q4_k_pair_swiglu_pipeline = nil; g_moe_mul_mv_id_q4_k_sum6_pipeline = nil; - g_moe_mul_mm_id_iq2_xxs_pipeline = nil; - g_moe_mul_mm_id_q2_k_pipeline = nil; - 
g_moe_mul_mm_id_q4_k_pipeline = nil;
 g_rope_tail_batch_pipeline = nil;
 g_dsv4_fp8_kv_quantize_pipeline = nil;
 g_dsv4_kv_fp8_store_pipeline = nil;
@@ -4931,6 +5235,14 @@ int ds4_gpu_matmul_q8_0_tensor(
         return 0;
     }
 
+    if (n_tok > 8 && ds4_gpu_use_mpp_q8_0_matmul()) {
+        if (ds4_gpu_matmul_q8_0_mpp_tensor(out, model_map, model_size, weight_offset,
+                                           in_dim, out_dim, x, n_tok)) {
+            return 1;
+        }
+        ds4_gpu_warn_mpp_fallback();
+    }
+
     @autoreleasepool {
         id<MTLBuffer> xbuf = ds4_gpu_tensor_buffer(x);
         id<MTLBuffer> outbuf = ds4_gpu_tensor_buffer(out);
@@ -5050,6 +5362,77 @@ int ds4_gpu_matmul_q8_0_tensor(
     return 1;
 }
 
+int ds4_gpu_matmul_q8_0_mpp_tensor(
+    ds4_gpu_tensor *out,
+    const void *model_map,
+    uint64_t model_size,
+    uint64_t weight_offset,
+    uint64_t in_dim,
+    uint64_t out_dim,
+    const ds4_gpu_tensor *x,
+    uint64_t n_tok) {
+    if (!g_initialized && !ds4_gpu_init()) return 0;
+    if (!g_metal4_tensor_api_enabled) return 0;
+    if ((in_dim & 31u) != 0 || n_tok <= 8 ||
+        in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok > UINT32_MAX) {
+        return 0;
+    }
+
+    @autoreleasepool {
+        id<MTLBuffer> xbuf = ds4_gpu_tensor_buffer(x);
+        id<MTLBuffer> outbuf = ds4_gpu_tensor_buffer(out);
+        const uint64_t x_bytes = n_tok * in_dim * sizeof(float);
+        const uint64_t out_bytes = n_tok * out_dim * sizeof(float);
+        if (!xbuf || !outbuf ||
+            ds4_gpu_tensor_bytes(x) < x_bytes ||
+            ds4_gpu_tensor_bytes(out) < out_bytes) {
+            fprintf(stderr, "ds4: Metal MPP Q8_0 matmul received undersized activation buffers\n");
+            return 0;
+        }
+
+        const uint64_t blocks = in_dim / 32;
+        const uint64_t row_bytes = blocks * 34;
+        const uint64_t weight_bytes = out_dim * row_bytes;
+        if (weight_offset > model_size || weight_bytes > model_size - weight_offset) {
+            fprintf(stderr, "ds4: Metal MPP Q8_0 matmul range is outside the mapped model\n");
+            return 0;
+        }
+
+        uint64_t inner_offset = 0;
+        id<MTLBuffer> wbuf = ds4_gpu_wrap_model_range(model_map, model_size, weight_offset, weight_bytes, &inner_offset);
+        if (!wbuf) return 0;
+
+        const bool bc_inp = (in_dim % 32u) != 0;
+        const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0;
+        id<MTLComputePipelineState> pipeline =
+            ds4_gpu_get_mul_mm_pipeline("kernel_mul_mm_q8_0_f32_mpp", bc_inp, bc_out);
+        if (!pipeline) return 0;
+
+        int owned = 0;
+        id<MTLCommandBuffer> cb = ds4_gpu_command_buffer(&owned);
+        if (!cb) return 0;
+
+        ds4_gpu_mul_mm_args args = ds4_gpu_make_mm_args(in_dim, out_dim, n_tok, row_bytes);
+
+        id<MTLComputeCommandEncoder> enc = ds4_gpu_compute_encoder(cb);
+        [enc setComputePipelineState:pipeline];
+        [enc setBytes:&args length:sizeof(args) atIndex:0];
+        [enc setBuffer:wbuf offset:(NSUInteger)inner_offset atIndex:1];
+        [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:2];
+        [enc setBuffer:outbuf offset:ds4_gpu_tensor_offset(out) atIndex:3];
+        [enc setThreadgroupMemoryLength:4096u atIndex:0];
+        [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u,
+                                              ((NSUInteger)out_dim + 63u) / 64u,
+                                              1)
+            threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+        ds4_gpu_end_compute_encoder(cb, enc);
+
+        if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal MPP Q8_0 matmul")) return 0;
+    }
+
+    return 1;
+}
+
 int ds4_gpu_shared_gate_up_swiglu_q8_0_tensor(
     ds4_gpu_tensor *gate,
     ds4_gpu_tensor *up,
@@ -5241,6 +5624,32 @@ int ds4_gpu_matmul_f16_tensor(
     const bool bc_inp = (in_dim % 32u) != 0;
     const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0;
 
+    /* Keep MPP F16 limited to the exact-safe ratio-2 compressor shape. 
*/ + if (in_dim == 4096u && out_dim == 128u && !bc_inp && + ds4_gpu_use_mpp_f16_compressor_matmul()) { + id pipeline = + ds4_gpu_get_mul_mm_pipeline("kernel_mul_mm_f16_f32_mpp", false, bc_out); + if (pipeline) { + ds4_gpu_mul_mm_args args = ds4_gpu_make_mm_args(in_dim, out_dim, n_tok, row_bytes); + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:wbuf offset:(NSUInteger)inner_offset atIndex:1]; + [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:2]; + [enc setBuffer:outbuf offset:ds4_gpu_tensor_offset(out) atIndex:3]; + [enc setThreadgroupMemoryLength:4096u atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u, + ((NSUInteger)out_dim + 63u) / 64u, + 1) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal MPP F16 compressor matmul")) return 0; + return 1; + } + } + id pipeline = ds4_gpu_get_mul_mm_pipeline("kernel_mul_mm_f16_f32", bc_inp, bc_out); if (!pipeline) return 0; @@ -8001,9 +8410,14 @@ int ds4_gpu_attention_output_q8_batch_tensor( const bool use_direct_low = n_tokens < 32u && getenv("DS4_METAL_DISABLE_ATTN_OUT_LOW_DIRECT") == NULL; + /* The tensor tile store is only used on full token tiles; partial tails use the legacy path. */ + const bool use_mpp_low = + n_tokens >= 32u && + (n_tokens % 32u) == 0 && + ds4_gpu_use_mpp_attn_out_low_matmul(); const NSUInteger ids_bytes = (NSUInteger)n_tokens * (NSUInteger)n_groups * sizeof(int32_t); id group_ids_buffer = nil; - if (!use_direct_low) { + if (!use_direct_low && !use_mpp_low) { if (getenv("DS4_METAL_DISABLE_ATTN_OUT_IDS_CACHE") != NULL) { group_ids_buffer = ds4_gpu_new_transient_buffer(ids_bytes, "attention output group ids"); @@ -8073,7 +8487,73 @@ int ds4_gpu_attention_output_q8_batch_tensor( * tokens. 
This preserves the single-token generation path while * keeping prefill accumulation stable. */ - if (n_tokens >= 32u && ds4_gpu_mul_mm_id_map0_name(n_groups) != NULL) { + if (use_mpp_low) { + ds4_gpu_mul_mm_id_args mm_args = + ds4_gpu_make_mul_mm_id_args((uint32_t)group_dim, + (uint32_t)rank, + n_groups, + row_a_bytes, + (uint64_t)rank * row_a_bytes, + n_groups, + n_groups, + n_tokens); + id mm_pipeline = + ds4_gpu_get_mul_mm_id_pipeline("kernel_attn_out_low_q8_0_mpp", false, false); + ok = ds4_gpu_encode_attn_out_low_q8_mpp(cb, + mm_pipeline, + &mm_args, + out_a_buf, + (NSUInteger)out_a_inner, + ds4_gpu_tensor_buffer(heads), + ds4_gpu_tensor_offset(heads), + ds4_gpu_tensor_buffer(low), + ds4_gpu_tensor_offset(low)) != 0; + if (!ok) { + ds4_gpu_warn_mpp_fallback(); + if (ds4_gpu_mul_mm_id_map0_name(n_groups) != NULL) { + if (getenv("DS4_METAL_DISABLE_ATTN_OUT_IDS_CACHE") != NULL) { + group_ids_buffer = + ds4_gpu_new_transient_buffer(ids_bytes, "attention output group ids"); + } else if (ds4_gpu_ensure_scratch_buffer(&g_attn_out_group_ids_buffer, + &g_attn_out_group_ids_bytes, + ids_bytes, + "ds4_attention_output_group_ids")) { + group_ids_buffer = g_attn_out_group_ids_buffer; + } + if (group_ids_buffer) { + int32_t *ids = (int32_t *)[group_ids_buffer contents]; + for (uint32_t t = 0; t < n_tokens; t++) { + for (uint32_t group = 0; group < n_groups; group++) { + ids[(uint64_t)t * n_groups + group] = (int32_t)group; + } + } + ds4_gpu_mul_mm_id_map_args map_args = + ds4_gpu_make_mul_mm_id_map_args((uint32_t)group_dim, + n_groups, + n_groups, + n_groups, + n_tokens); + id map_pipeline = + ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_groups)); + id fallback_pipeline = + ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q8_0_f32", false, false); + ok = ds4_gpu_encode_mul_mm_id(cb, + map_pipeline, + fallback_pipeline, + &map_args, + &mm_args, + out_a_buf, + (NSUInteger)out_a_inner, + ds4_gpu_tensor_buffer(heads), + ds4_gpu_tensor_offset(heads), + 
ds4_gpu_tensor_buffer(low), + ds4_gpu_tensor_offset(low), + group_ids_buffer, + 0) != 0; + } + } + } + } else if (n_tokens >= 32u && ds4_gpu_mul_mm_id_map0_name(n_groups) != NULL) { ds4_gpu_mul_mm_id_map_args map_args = ds4_gpu_make_mul_mm_id_map_args((uint32_t)group_dim, n_groups, @@ -8092,7 +8572,7 @@ int ds4_gpu_attention_output_q8_batch_tensor( id map_pipeline = ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_groups)); id mm_pipeline = - ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q8_0_f32", false); + ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q8_0_f32", false, false); ok = ds4_gpu_encode_mul_mm_id(cb, map_pipeline, mm_pipeline, @@ -11590,39 +12070,27 @@ static NSUInteger ds4_gpu_routed_mv_smem(uint32_t type) { } } -static id ds4_gpu_routed_mm_pipeline(uint32_t type) { +static id ds4_gpu_routed_mm_pipeline(uint32_t type, bool use_mpp) { switch (type) { case DS4_METAL_TENSOR_IQ2_XXS: - if (!g_moe_mul_mm_id_iq2_xxs_pipeline) { - g_moe_mul_mm_id_iq2_xxs_pipeline = - ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32", false); - } - return g_moe_mul_mm_id_iq2_xxs_pipeline; + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32", false, use_mpp); case DS4_METAL_TENSOR_Q2_K: - if (!g_moe_mul_mm_id_q2_k_pipeline) { - g_moe_mul_mm_id_q2_k_pipeline = - ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32", false); - } - return g_moe_mul_mm_id_q2_k_pipeline; + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32", false, use_mpp); case DS4_METAL_TENSOR_Q4_K: - if (!g_moe_mul_mm_id_q4_k_pipeline) { - g_moe_mul_mm_id_q4_k_pipeline = - ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32", false); - } - return g_moe_mul_mm_id_q4_k_pipeline; + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32", false, use_mpp); default: return nil; } } -static id ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type) { +static id ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type, bool use_mpp) { switch (type) { case 
DS4_METAL_TENSOR_IQ2_XXS:
-        return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16", false);
+        return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16", false, use_mpp);
     case DS4_METAL_TENSOR_Q2_K:
-        return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16", false);
+        return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16", false, use_mpp);
     case DS4_METAL_TENSOR_Q4_K:
-        return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16", false);
+        return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16", false, use_mpp);
     default:
         return nil;
     }
@@ -11960,6 +12428,37 @@ static int ds4_gpu_encode_mul_mm_id_mapped(
     return 1;
 }
 
+static int ds4_gpu_encode_attn_out_low_q8_mpp(
+    id<MTLCommandBuffer> cb,
+    id<MTLComputePipelineState> pipeline,
+    const ds4_gpu_mul_mm_id_args *mm_args,
+    id<MTLBuffer> src0,
+    NSUInteger src0_off,
+    id<MTLBuffer> src1,
+    NSUInteger src1_off,
+    id<MTLBuffer> dst,
+    NSUInteger dst_off) {
+    if (!cb || !pipeline || !mm_args || !src0 || !src1 || !dst ||
+        mm_args->ne00 <= 0 || mm_args->ne0 <= 0 ||
+        mm_args->ne02 <= 0 || mm_args->ne1 <= 0 || mm_args->ne21 <= 0) {
+        return 0;
+    }
+
+    id<MTLComputeCommandEncoder> enc = ds4_gpu_compute_encoder(cb);
+    [enc setComputePipelineState:pipeline];
+    [enc setBytes:mm_args length:sizeof(*mm_args) atIndex:0];
+    [enc setBuffer:src0 offset:src0_off atIndex:1];
+    [enc setBuffer:src1 offset:src1_off atIndex:2];
+    [enc setBuffer:dst offset:dst_off atIndex:3];
+    [enc setThreadgroupMemoryLength:4096u atIndex:0];
+    [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + 31u) / 32u,
+                                          ((NSUInteger)mm_args->ne0 + 63u) / 64u,
+                                          (NSUInteger)mm_args->ne02)
+        threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+    ds4_gpu_end_compute_encoder(cb, enc);
+    return 1;
+}
+
 static int ds4_gpu_encode_swiglu_flat(
     id<MTLCommandBuffer> cb,
     id<MTLBuffer> gate,
@@ -12050,6 +12549,42 @@ static int ds4_gpu_encode_moe_swiglu_weight(
     return 1;
 }
 
+static int ds4_gpu_encode_moe_sum6(
+    id<MTLCommandBuffer> cb,
+    id<MTLBuffer> experts,
+    NSUInteger experts_off,
+    id<MTLBuffer> out,
+    NSUInteger out_off,
+    uint32_t out_dim,
+    uint32_t n_tokens) {
+    if (!cb 
|| !experts || !out || out_dim == 0 || n_tokens == 0) return 0; + + if (!g_moe_sum6_pipeline) return 0; + + const uint64_t out_row_bytes = (uint64_t)out_dim * sizeof(float); + ds4_gpu_dsv4_moe_sum6_args args = { + .width = out_dim, + .tokens = n_tokens, + .src_token_stride = 6u * out_row_bytes, + .dst_token_stride = out_row_bytes, + }; + + NSUInteger nth = g_moe_sum6_pipeline.maxTotalThreadsPerThreadgroup; + if (nth > 256u) nth = 256u; + if (nth > out_dim) nth = out_dim; + if (nth == 0) nth = 1u; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_moe_sum6_pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:experts offset:experts_off atIndex:1]; + [enc setBuffer:out offset:out_off atIndex:2]; + [enc dispatchThreadgroups:MTLSizeMake((NSUInteger)n_tokens, 1, 1) + threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + return 1; +} + static ds4_gpu_bin_args ds4_gpu_make_moe_add_args( uint32_t out_dim, uint32_t n_tokens, @@ -12100,6 +12635,18 @@ static int ds4_gpu_encode_moe_sum_experts( const uint64_t out_row_bytes = (uint64_t)out_dim * sizeof(float); const uint64_t expert_token_stride = (uint64_t)n_expert * out_row_bytes; + if (n_expert == 6 && + getenv("DS4_METAL_MOE_SUM6_DISABLE") == NULL && + ds4_gpu_encode_moe_sum6(cb, + experts, + experts_off, + out, + out_off, + out_dim, + n_tokens)) { + return 1; + } + ds4_gpu_bin_args first = ds4_gpu_make_moe_add_args(out_dim, n_tokens, expert_token_stride, expert_token_stride, out_row_bytes); if (!ds4_gpu_encode_bin_f32_rows(cb, @@ -13064,6 +13611,7 @@ int ds4_gpu_routed_moe_batch_tensor( uint32_t n_expert, float clamp, const ds4_gpu_tensor *x, + uint32_t layer_index, uint32_t n_tokens, bool *mid_is_f16) { if (!g_initialized && !ds4_gpu_init()) return 0; @@ -13130,6 +13678,7 @@ int ds4_gpu_routed_moe_batch_tensor( id gate_mv_pipeline = ds4_gpu_routed_mv_pipeline(gate_type); id down_mv_pipeline = ds4_gpu_routed_mv_pipeline(down_type); id 
gate_mm_pipeline = nil; + id up_mm_pipeline = nil; id down_mm_pipeline = nil; if (gate_nr0 == 0 || down_nr0 == 0 || !gate_mv_pipeline || !down_mv_pipeline) { fprintf(stderr, "ds4: unsupported Metal routed batch MoE quant types gate=%u down=%u\n", @@ -13166,6 +13715,7 @@ int ds4_gpu_routed_moe_batch_tensor( ds4_gpu_mul_mm_id_args gate_mm_args = { 0 }; ds4_gpu_mul_mm_id_args down_mm_args = { 0 }; id map_pipeline = nil; + const int moe_mpp_mask = ds4_gpu_mpp_routed_moe_mask_for_layer(layer_index); /* * The grouped routed-MoE matmul loads activation tiles as half before * using SIMD-group MMA. Store the SwiGLU/route-weight intermediate in @@ -13189,11 +13739,16 @@ int ds4_gpu_routed_moe_batch_tensor( request_mid_f16 ? sizeof(uint16_t) : sizeof(float)); map_pipeline = ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_expert)); - gate_mm_pipeline = ds4_gpu_routed_mm_pipeline(gate_type); + gate_mm_pipeline = ds4_gpu_routed_mm_pipeline( + gate_type, + (moe_mpp_mask & DS4_METAL_MOE_MPP_GATE) != 0); + up_mm_pipeline = ds4_gpu_routed_mm_pipeline( + gate_type, + (moe_mpp_mask & DS4_METAL_MOE_MPP_UP) != 0); down_mm_pipeline = request_mid_f16 ? 
- ds4_gpu_routed_mm_f16_rhs_pipeline(down_type) : - ds4_gpu_routed_mm_pipeline(down_type); - if (!map_pipeline || !gate_mm_pipeline || !down_mm_pipeline) { + ds4_gpu_routed_mm_f16_rhs_pipeline(down_type, (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0) : + ds4_gpu_routed_mm_pipeline(down_type, (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0); + if (!map_pipeline || !gate_mm_pipeline || !up_mm_pipeline || !down_mm_pipeline) { return 0; } } @@ -13274,7 +13829,7 @@ int ds4_gpu_routed_moe_batch_tensor( } if (ok) { ok = ds4_gpu_encode_mul_mm_id_mapped(cb, - gate_mm_pipeline, + up_mm_pipeline, &gate_mm_args, up_buf, (NSUInteger)up_inner, diff --git a/metal/dense.metal b/metal/dense.metal index a84927e9..0d7af3ba 100644 --- a/metal/dense.metal +++ b/metal/dense.metal @@ -910,6 +910,105 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; +#ifdef DS4_METAL_HAS_TENSOR +template< + typename SA, typename SA_4x4, typename block_q, short nl, + void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &), + typename T0, typename T0_4x4, typename T1> +kernel void kernel_mul_mm_mpp( + constant ds4_metal_args_mul_mm & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + constexpr int NR0 = 64; + constexpr int NR1 = 32; + constexpr int NK = 32; + constexpr int NL = NK/16; + constexpr int NUM_THREADS = 128; + + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + const int im = tgpig.z; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + const uint64_t offset0 = 
(i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + threadgroup SA *sa = (threadgroup SA *)shmem; + auto tA = tensor(sa, dextents(NK, NR0)); + + device T1 *ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11/sizeof(T1); + auto tB = tensor(ptrB, dextents(K, N), array({1, strideB})); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, true, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.get_destination_cooperative_tensor(); + + for (int loop_k = 0; loop_k < K; loop_k += NK) { + for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { + const int row = work/NL; + const int k_chunk = work%NL; + const int k_pos = loop_k + k_chunk*16; + const short k_base = k_chunk*16; + + if (r0 + row < M) { + if (is_same::value && FC_mul_mm_bc_inp) { + device const T0 *row_ptr = (device const T0 *)(srcA + args.nb01*(r0 + row) + offset0); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (k_pos + i < K) ? (SA)row_ptr[k_pos + i] : (SA)0; + } + } else { + const int block_idx = k_pos/(16*nl); + const short il = (k_pos/16)%nl; + device const block_q *row_ptr = (device const block_q *)(srcA + args.nb01*(r0 + row) + offset0); + + SA_4x4 temp_a; + dequantize_func(row_ptr + block_idx, il, temp_a); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (k_pos + i < K) ? 
temp_a[i/4][i%4] : (SA)0; + } + } + } else { + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (SA)0; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto mA = tA.slice(0, 0); + auto mB = tB.slice(loop_k, r1); + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float *dst_batch = (device float *)dst + im*N*M; + auto tD = tensor(dst_batch, dextents(M, N), array({1, M})); + auto mD = tD.slice(r0, r1); + cT.store(mD); +} + +typedef decltype(kernel_mul_mm_mpp) mul_mm_mpp_t; + +template [[host_name("kernel_mul_mm_f16_f32_mpp")]] kernel mul_mm_mpp_t kernel_mul_mm_mpp; +template [[host_name("kernel_mul_mm_q8_0_f32_mpp")]] kernel mul_mm_mpp_t kernel_mul_mm_mpp; +#endif + // Tiled matrix-matrix kernel used for prompt batches larger than 8. DS4 uses // this to turn prefill into large simdgroup matrix operations; each block_q // contains 16*nl weights. diff --git a/metal/moe.metal b/metal/moe.metal index 65074d7d..0cfd31ce 100644 --- a/metal/moe.metal +++ b/metal/moe.metal @@ -87,6 +87,8 @@ static constant ulong ds4_metal_iq2xxs_grid[256] = { 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, }; +constant bool FC_mul_mm_id_mpp [[function_constant(FC_MUL_MM + 2)]]; + #define kmask_iq2xs ds4_metal_kmask_iq2xs #define ksigns_iq2xs ds4_metal_ksigns_iq2xs #define iq2xxs_grid ds4_metal_iq2xxs_grid @@ -121,6 +123,13 @@ struct ds4_metal_dsv4_moe_swiglu_weight_args { float clamp_value; }; +struct ds4_metal_dsv4_moe_sum6_args { + uint32_t width; + uint32_t tokens; + uint64_t src_token_stride; + uint64_t dst_token_stride; +}; + // Routed-MoE activation for the selected experts: // clamp(gate), clamp(up), silu(gate) * up * route_weight. 
Normal inference // does not consume gate/up after this point, so the fast path avoids writing the @@ -198,6 +207,31 @@ kernel void kernel_dsv4_moe_swiglu_weight_f16( } } +kernel void kernel_dsv4_moe_sum6_f32( + constant ds4_metal_dsv4_moe_sum6_args &args, + device const char *src, + device char *dst, + uint token[[threadgroup_position_in_grid]], + uint tid[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + if (token >= args.tokens) return; + + device const float *s = + (device const float *)(src + (uint64_t)token * args.src_token_stride); + device float *d = + (device float *)(dst + (uint64_t)token * args.dst_token_stride); + + for (uint col = tid; col < args.width; col += ntg) { + float v = s[col]; + v += s[args.width + col]; + v += s[2u * args.width + col]; + v += s[3u * args.width + col]; + v += s[4u * args.width + col]; + v += s[5u * args.width + col]; + d[col] = v; + } +} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -1530,6 +1564,9 @@ kernel void kernel_mul_mm_id( ushort sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef DS4_METAL_HAS_TENSOR + threadgroup float *sc = (threadgroup float *)shmem; +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -1588,6 +1625,17 @@ kernel void kernel_mul_mm_id( for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } +#ifdef DS4_METAL_HAS_TENSOR + auto tA = tensor(sa, dextents(NK, NR0)); + auto tB = tensor(sb, dextents(NR1, NK)); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, false, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.get_destination_cooperative_tensor(); +#endif for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { if (is_same::value && FC_mul_mm_bc_inp) { @@ -1597,12 +1645,22 @@ kernel void kernel_mul_mm_id( const 
short sx = 2*il0 + i/8; const short sy = (tiitg/NL0)/8; +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + } else +#endif + { const short lx = (tiitg/NL0)%8; const short ly = i%8; const short ib = 8*sx + sy; *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + } } } else { S0_4x4 temp_a; @@ -1614,12 +1672,22 @@ kernel void kernel_mul_mm_id( const short sx = 2*il0 + i/8; const short sy = (tiitg/NL0)/8; +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; + } else +#endif + { const short lx = (tiitg/NL0)%8; const short ly = i%8; const short ib = 8*sx + sy; *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4]; + } } } @@ -1631,9 +1699,16 @@ kernel void kernel_mul_mm_id( const short lx = i; const short ly = (tiitg/NL1)%8; +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; + } else +#endif + { const short ib = 4*sx + sy; *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; + } } } else { const short sx = (tiitg%NL1); @@ -1641,9 +1716,16 @@ kernel void kernel_mul_mm_id( const short ly = (tiitg/NL1)%8; +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); + } else +#endif + { const short ib = 4*sx + sy; *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); + } } il = (il + 2 < nl) ? 
il + 2 : il % 2; @@ -1653,6 +1735,14 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + mm.run(sB, sA, cT); + } else +#endif + { threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); @@ -1678,15 +1768,24 @@ kernel void kernel_mul_mm_id( lsma += 8*64; lsmb += 4*64; } + } } threadgroup_barrier(mem_flags::mem_threadgroup); +#ifdef DS4_METAL_HAS_TENSOR + if (FC_mul_mm_id_mpp) { + auto tC = tensor(sc, dextents(NR0, NR1)); + cT.store(tC); + } else +#endif + { threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); } + } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -1727,6 +1826,87 @@ template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16_rhs template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id; +#ifdef DS4_METAL_HAS_TENSOR +kernel void kernel_attn_out_low_q8_0_mpp( + constant ds4_metal_args_mul_mm_id & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + constexpr int NR0 = 64; + constexpr int NR1 = 32; + constexpr int NK = 32; + constexpr int NL = NK/16; + constexpr int NUM_THREADS = 128; + + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne21; + const int G = args.ne1; + const int group = tgpig.z; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + threadgroup half *sa = (threadgroup half 
*)shmem; + auto tA = tensor(sa, dextents(NK, NR0)); + + device float *ptrB = (device float *)(srcB + args.nb11*group); + const int strideB = args.nb12/sizeof(float); + auto tB = tensor(ptrB, dextents(K, N), array({1, strideB})); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, true, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.get_destination_cooperative_tensor(); + + for (int loop_k = 0; loop_k < K; loop_k += NK) { + for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { + const int row = work/NL; + const int k_chunk = work%NL; + const int k_pos = loop_k + k_chunk*16; + const short k_base = k_chunk*16; + + if (r0 + row < M) { + const int block_idx = k_pos/32; + const short il = (k_pos/16)%2; + device const block_q8_0 *row_ptr = + (device const block_q8_0 *)(srcA + args.nb01*(r0 + row) + group*args.nb02); + + half4x4 temp_a; + dequantize_q8_0(row_ptr + block_idx, il, temp_a); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (k_pos + i < K) ? 
temp_a[i/4][i%4] : (half)0; + } + } else { + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (half)0; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto mA = tA.slice(0, 0); + auto mB = tB.slice(loop_k, r1); + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float *dst_group = (device float *)dst + group*M; + auto tD = tensor(dst_group, dextents(M, N), array({1, G*M})); + auto mD = tD.slice(r0, r1); + cT.store(mD); +} +#endif + #undef QK_NL #undef kmask_iq2xs #undef ksigns_iq2xs diff --git a/tests/ds4_test.c b/tests/ds4_test.c index 959367c2..dd45ba78 100644 --- a/tests/ds4_test.c +++ b/tests/ds4_test.c @@ -150,6 +150,129 @@ static void test_metal_f16_matvec_fast_nr0_4(void) { free(weights_raw); } +static void test_metal_q8_0_mpp_matmul(void) { + const uint32_t in_dim = 128; + const uint32_t out_dim = 96; + const uint32_t n_tok = 48; + const uint64_t blocks = in_dim / 32; + const uint64_t row_bytes = blocks * 34; + const uint64_t weight_bytes = (uint64_t)out_dim * row_bytes; + const uint64_t weight_alloc = test_round_up_u64(weight_bytes, (uint64_t)getpagesize()); + + void *weights_raw = NULL; + TEST_ASSERT(posix_memalign(&weights_raw, (size_t)getpagesize(), (size_t)weight_alloc) == 0); + if (!weights_raw) return; + + uint8_t *weights = weights_raw; + memset(weights, 0, (size_t)weight_alloc); + for (uint32_t o = 0; o < out_dim; o++) { + for (uint32_t b = 0; b < blocks; b++) { + uint8_t *block = weights + (uint64_t)o * row_bytes + (uint64_t)b * 34u; + uint16_t d = test_float_to_f16((float)((o + b) % 5u + 1u) / 128.0f); + memcpy(block, &d, sizeof(d)); + int8_t *qs = (int8_t *)(block + 2); + for (uint32_t i = 0; i < 32; i++) { + qs[i] = (int8_t)((int)((o * 5u + b * 7u + i * 3u) % 63u) - 31); + } + } + } + + const uint64_t x_bytes = (uint64_t)n_tok * in_dim * sizeof(float); + const uint64_t out_bytes = (uint64_t)n_tok * out_dim * sizeof(float); + ds4_gpu_tensor *x = 
ds4_gpu_tensor_alloc(x_bytes); + ds4_gpu_tensor *out_ref = ds4_gpu_tensor_alloc(out_bytes); + ds4_gpu_tensor *out_mpp = ds4_gpu_tensor_alloc(out_bytes); + TEST_ASSERT(x != NULL); + TEST_ASSERT(out_ref != NULL); + TEST_ASSERT(out_mpp != NULL); + if (!x || !out_ref || !out_mpp) { + ds4_gpu_tensor_free(x); + ds4_gpu_tensor_free(out_ref); + ds4_gpu_tensor_free(out_mpp); + free(weights_raw); + return; + } + + float *x_host = malloc((size_t)x_bytes); + float *ref_host = malloc((size_t)out_bytes); + float *mpp_host = malloc((size_t)out_bytes); + TEST_ASSERT(x_host != NULL); + TEST_ASSERT(ref_host != NULL); + TEST_ASSERT(mpp_host != NULL); + if (!x_host || !ref_host || !mpp_host) { + free(x_host); + free(ref_host); + free(mpp_host); + ds4_gpu_tensor_free(x); + ds4_gpu_tensor_free(out_ref); + ds4_gpu_tensor_free(out_mpp); + free(weights_raw); + return; + } + + for (uint32_t t = 0; t < n_tok; t++) { + for (uint32_t i = 0; i < in_dim; i++) { + x_host[(uint64_t)t * in_dim + i] = + (float)((int)((t * 19u + i * 23u) % 53u) - 26) / 80.0f; + } + } + + TEST_ASSERT(ds4_gpu_tensor_write(x, 0, x_host, x_bytes) != 0); + TEST_ASSERT(ds4_gpu_set_model_map(weights_raw, weight_alloc) != 0); + ds4_gpu_set_quality(false); + TEST_ASSERT(ds4_gpu_matmul_q8_0_tensor(out_ref, weights_raw, weight_alloc, 0, + in_dim, out_dim, x, n_tok) != 0); + + int have_mpp = ds4_gpu_matmul_q8_0_mpp_tensor( + out_mpp, weights_raw, weight_alloc, 0, in_dim, out_dim, x, n_tok); + if (!have_mpp) { + fprintf(stderr, "ds4-test: skipping MPP Q8_0 matmul; Metal 4 tensor API unavailable\n"); + free(x_host); + free(ref_host); + free(mpp_host); + ds4_gpu_tensor_free(x); + ds4_gpu_tensor_free(out_ref); + ds4_gpu_tensor_free(out_mpp); + free(weights_raw); + return; + } + + TEST_ASSERT(ds4_gpu_tensor_read(out_ref, 0, ref_host, out_bytes) != 0); + TEST_ASSERT(ds4_gpu_tensor_read(out_mpp, 0, mpp_host, out_bytes) != 0); + + float max_abs = 0.0f; + uint64_t max_index = 0; + for (uint64_t i = 0; i < (uint64_t)n_tok * out_dim; i++) 
{ + float err = fabsf(mpp_host[i] - ref_host[i]); + if (err > max_abs) { + max_abs = err; + max_index = i; + } + } + if (max_abs >= 0.10f) { + fprintf(stderr, "ds4-test: MPP Q8_0 matmul max_abs=%f at token=%llu out=%llu ref=%f mpp=%f\n", + max_abs, + (unsigned long long)(max_index / out_dim), + (unsigned long long)(max_index % out_dim), + ref_host[max_index], + mpp_host[max_index]); + } + TEST_ASSERT(max_abs < 0.10f); + + free(x_host); + free(ref_host); + free(mpp_host); + ds4_gpu_tensor_free(x); + ds4_gpu_tensor_free(out_ref); + ds4_gpu_tensor_free(out_mpp); + free(weights_raw); +} + +static void test_metal_kernel_group(void) { + test_metal_f16_matvec_fast_nr0_4(); + test_metal_q8_0_mpp_matmul(); +} + static char *test_read_file(const char *path) { FILE *fp = fopen(path, "rb"); if (!fp) return NULL; @@ -650,7 +773,7 @@ static const ds4_test_entry test_entries[] = { {"--long-context", "long-context", "long-context story fact-recall regression", test_long_story_fact_recall}, {"--tool-call-quality", "tool-call-quality", "model emits valid DSML tool calls", test_tool_call_quality}, {"--logprob-vectors", "logprob-vectors", "official API top-logprob vector comparison", test_official_logprob_vectors}, - {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_f16_matvec_fast_nr0_4}, + {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_kernel_group}, #endif {"--server", "server", "server parser/rendering/cache unit tests", test_server_unit_group}, }; From a50dd90c0ebe3d01cd45cd31b303c5ad91fa3257 Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Sun, 10 May 2026 23:40:55 +0200 Subject: [PATCH 02/16] Improve Metal MPP diagnostics and safe defaults --- README.md | 164 ++++- ds4.c | 411 ++++++++---- ds4.h | 10 + ds4_cli.c | 15 +- ds4_gpu.h | 5 + ds4_metal.m | 1539 +++++++++++++++++++++++++++++++++++++++++---- ds4_server.c | 15 +- metal/dense.metal | 493 ++++++++++++++- metal/moe.metal | 632 
+++++++++++++++++-- tests/ds4_test.c | 589 ++++++++++++++++- 10 files changed, 3563 insertions(+), 310 deletions(-) diff --git a/README.md b/README.md index 63a91e88..3667471d 100644 --- a/README.md +++ b/README.md @@ -220,31 +220,156 @@ tensor matmul probe before it lets the main Metal shader source see `DS4_METAL_HAS_TENSOR`, so unsupported SDK/device combinations fall back to the legacy kernels. -The Q8_0 prefill MPP route is enabled automatically on M5/M6/A19/A20-class -Metal 4 tensor targets and can be forced with -`DS4_METAL_MPP_ENABLE=1 ./ds4 --prompt-file README.md`. It only affects prompt -batches larger than eight tokens, falls back to the legacy kernel if the Metal 4 -tensor path is unavailable, and is covered by the isolated -`./ds4_test --metal-kernels` numeric regression. It has also passed the -long-context and official logprob-vector regressions on M5. Set -`DS4_METAL_MPP_DISABLE=1` to compare or temporarily disable the MPP route. - -The routed-MoE projections also use MPP by default on M5-class Metal 4 tensor -targets for staged prefill layers: the down projection starts at layer 2, the -gate and up projections start at layer 13. This constrained route has passed -the long-context and official logprob-vector regressions. Starting down at -layer 1, or gate/up together at layer 12, fails the long-context regression, -so the boundaries are intentionally conservative. +MPP policy is explicit and correctness-first. Use `--mpp auto` for the default +route policy, `--mpp on` to force MPP routes where the Metal 4 tensor path is +available, and `--mpp off` for the legacy Metal reference path. Auto currently +enables only the validated late-layer safe windows that pass full-model +equivalence and clear the benchmark gate; early-layer and all-layer MPP routes +remain opt-in diagnostics. 
+The environment variables `DS4_METAL_MPP_ENABLE` and `DS4_METAL_MPP_DISABLE`
+accept `1/true/yes/on` and `0/false/no/off`; `DS4_METAL_MPP_ENABLE=0` disables
+MPP instead of enabling it by mere presence. Passing `--quality` also disables
+MPP routes so strict/debug runs stay on the legacy Metal kernels. Set
+`DS4_METAL_MPP_FAST=1` to opt into the current same-top1/same-greedy fast
+profile: it widens Q8_0 and attention-output MPP to all layers, enables Q8_0
+partial token tiles, and uses earlier routed-MoE MPP windows. This profile is
+not the default because its whole-vocab and top-k drift are much larger than
+the correctness-first auto profile.
+Set `DS4_METAL_MPP_DIRECT_RHS=1` only for diagnostics of the first-PR MPP
+direct-RHS tensor layout; it is not part of the correctness-first default. Q8_0
+and attention-output direct-RHS diagnostics support both 32-token and 64-token
+MPP tiles, so they can be combined with `DS4_METAL_MPP_Q8_0_TILE_N=64` and
+`DS4_METAL_MPP_ATTN_OUT_TILE_N=64` for M5 throughput experiments. The
+route-specific `DS4_METAL_MPP_Q8_0_DIRECT_RHS=1`,
+`DS4_METAL_MPP_F16_DIRECT_RHS=1`, and
+`DS4_METAL_MPP_ATTN_OUT_DIRECT_RHS=1` switches isolate that diagnostic layout
+without turning on every direct-RHS route at once.
+
+The Q8_0 prefill MPP route can be isolated with
+`DS4_METAL_MPP_Q8_0_ENABLE=1` or `DS4_METAL_MPP_Q8_0_DISABLE=1`. It only
+affects prompt batches larger than eight tokens and is limited by default to
+the late full-model-safe layer window 38..42, plus the `attn_q_b` projection in
+layers 32..37. It uses only full 32-token tiles by default and falls back to
+the legacy kernel for partial token tiles or when the Metal 4 tensor path is
+unavailable. Set `DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=1` to reproduce or
+localize partial-tile drift while debugging.
Set `DS4_METAL_MPP_Q8_0_FILTER=all` to reproduce the +unsafe all-layer Q8 route, `DS4_METAL_MPP_Q8_0_FILTER=late_safe` to request the +default safe window explicitly, or +`DS4_METAL_MPP_Q8_0_FILTER=` to force named +full-graph Q8 modules such as `attn_q_a`, `attn_kv`, `attn_q_b`, `attn_out`, +`shared_gate`, `shared_up`, or `shared_down`. Use +`@layer=A..B` to test one module family only in a layer window, for +example `shared_up@layer=30..37`. Set +`DS4_METAL_MPP_Q8_0_TILE_N=64` to test the experimental wider MPP token tile +for performance against the default `32`. The isolated +`./ds4_test --metal-kernels` regression reports small/medium/model-ish kernel +deltas; the full-model +`./ds4_test --metal-mpp-equivalence` diagnostic compares default auto against +`--mpp off`. Set `DS4_TEST_MPP_EQ_FORCE_ON=1` to compare forced MPP against +`--mpp off` while working on a route. `DS4_TEST_MPP_EQ_CASE=` +limits the diagnostic to one prompt, and `DS4_TEST_MPP_EQ_MATRIX=1` prints +separate auto, fast-profile, Q8-only, attention-output-only, MoE gate/up/down-only, +and full-forced summary rows. The equivalence gate requires finite logits, the +same top-1 token, and matching greedy continuation; it also reports top-5/top-20 +overlap, top-20 rank displacement, top-20 logit deltas, and whole-vocab RMS/max +drift so route changes can be judged beyond pass/fail. + +Full-graph route localization is available with +`DS4_METAL_MPP_COMPARE_ROUTE=q8|attn_out|moe_gate|moe_up|moe_down` and optional +`DS4_METAL_MPP_COMPARE_MAX=N`. The comparator snapshots the candidate MPP +output, runs the legacy Metal route on the same tensor input, and reports the +first comparison that exceeds the kernel target, including module/layer context, +shape, max absolute error, RMS, and the largest element deltas. Set +`DS4_METAL_MPP_COMPARE_VERBOSE=1` to print passing comparisons as well. 
+ +Current MPP route status is intentionally conservative: `auto` enables Q8_0 +prefill, F16 compressor, attention-output low projection, and routed-MoE MPP +only in the full-model-safe windows. Attention-output low projection now uses +layers 32..42 by default, while Q8_0 keeps one narrower `attn_q_b` extension +for layers 32..37. The Q8_0 and attention-output low MPP +kernels stage activation tiles through half to match the legacy Metal matmul +input path, which brings the isolated model-ish Q8_0 regression under the +strict kernel target and removes the first attention-output comparator breach. +Most Q8_0 projection families stay restricted to layers 38..42 because earlier +layers can amplify small local differences through normalization/attention +enough to fail prompt-logit equivalence. The `attn_q_b` 32..37 extension is +kept because it is query-side only for full prompt tiles in the current +validation path, passes prompt-logit equivalence, and improves prefill +throughput. The F16 compressor route did not introduce measurable drift in the +current prompt set. + +The `DS4_METAL_MPP_FAST=1` profile is the measured high-throughput diagnostic +profile under the relaxed same-top1/same-greedy gate. In the current prompt +suite it keeps top-1 and greedy continuations stable, but reports much larger +distribution drift than auto (`worst_rms ~= 0.761`, +`worst_top20_max_abs ~= 2.28`, minimum top-20 overlap `18/20`). On the +long-code prefill benchmark it sampled around `360 t/s` in the same window +where auto sampled around `318 t/s`; benchmark variance is high when the +desktop is active. The more aggressive direct-RHS 64-token diagnostic +(`DS4_METAL_MPP_FAST=1 DS4_METAL_MPP_DIRECT_RHS=1 +DS4_METAL_MPP_Q8_0_TILE_N=64 DS4_METAL_MPP_ATTN_OUT_TILE_N=64`) passed the +relaxed top-1/greedy gate and `--logprob-vectors`, and in Automatic power mode +sampled around `324 t/s` versus `289 t/s` for auto in the same short benchmark +window. 
It remains diagnostic-only because its full-suite drift is higher +(`worst_rms ~= 0.846`, `worst_top20_max_abs ~= 2.07`, minimum top-20 overlap +`16/20`). + +The routed-MoE MPP projections are staged when forced and are limited to a +late full-model-safe layer window by default: gate/down start at layer 28, and +up starts at layer 30. For route isolation, use +`DS4_METAL_MPP_MOE_GATE_ENABLE/DISABLE`, +`DS4_METAL_MPP_MOE_UP_ENABLE/DISABLE`, and +`DS4_METAL_MPP_MOE_DOWN_ENABLE/DISABLE`; `DS4_METAL_MPP_MOE_DISABLE=1` +disables all routed-MoE MPP projections. Set the common +`DS4_METAL_MPP_MOE_FILTER` or route-specific +`DS4_METAL_MPP_MOE_GATE_FILTER`, `DS4_METAL_MPP_MOE_UP_FILTER`, and +`DS4_METAL_MPP_MOE_DOWN_FILTER` to `all`, `late_safe`, `none`, or +comma-separated full-graph context substrings to localize safe layer windows. +Use `layer=N` for an exact layer match or `layer=A..B` for an inclusive layer +range when testing sparse MPP windows. The same `@layer=A..B` +syntax can restrict a context substring to a layer window. +Set `DS4_METAL_MPP_MOE_TILE_N=64` to test the experimental wider routed-MoE +MPP token tile for performance against the default `32`. Set +`DS4_METAL_MPP_MOE_FAST_LAYOUT=1` to test the old first-PR routed-MoE MPP +threadgroup tensor layout as an explicit performance diagnostic. Set +`DS4_METAL_MPP_MOE_START_LAYER=N`, or the route-specific +`DS4_METAL_MPP_MOE_GATE_START_LAYER`, +`DS4_METAL_MPP_MOE_UP_START_LAYER`, and +`DS4_METAL_MPP_MOE_DOWN_START_LAYER`, to test earlier routed-MoE MPP start +layers before changing the conservative defaults. Set +`DS4_METAL_MPP_MOE_PAIR_GATE_UP=1` only to profile the experimental fused +gate/up MPP dispatch; it passes the current equivalence gate but is not a +default path because it is slower than separate gate and up dispatches. For the common six-routed-expert prefill shape, the down-projection expert outputs are summed with a single Metal kernel instead of five chained add passes. 
Set `DS4_METAL_MOE_SUM6_DISABLE=1` to compare or temporarily disable that fused sum route. -The attention-output low-projection also uses MPP by default on Metal 4 tensor -targets for full 32-token tiles, falling back to the existing indexed simdgroup -kernel for partial tiles. Set `DS4_METAL_MPP_ATTN_OUT_DISABLE=1` to isolate or -temporarily disable this route. +The attention-output low-projection MPP route applies to full 32-token tiles +in the default safe window, falling back to the existing indexed simdgroup +kernel for partial tiles. Attention-output MPP is limited to the measured +full-model-safe layer window 32..42 by default. Set +`DS4_METAL_MPP_ATTN_OUT_ENABLE=1` or `DS4_METAL_MPP_ATTN_OUT_DISABLE=1` to +isolate this route. Set `DS4_METAL_MPP_ATTN_OUT_FILTER=all`, `late_safe`, +`none`, or a comma-separated list of full-graph context substrings such as +`layer=42` to localize full-model-safe layer windows. Layer filters are exact, +and `layer=A..B` matches an inclusive range. Set +`DS4_METAL_MPP_ATTN_OUT_TILE_N=64` to test the experimental wider MPP token +tile for performance against the default `32`. The all-layer +attention-output MPP route still fails long-prompt full-model equivalence +despite per-layer low-projection differences below the current kernel target. +The ratio-2 F16 compressor route can similarly be controlled with +`DS4_METAL_MPP_F16_ENABLE=1` or `DS4_METAL_MPP_F16_DISABLE=1`. +`DS4_METAL_MPP_F16_PAIR=1` tests a paired KV/gate compressor dispatch that keeps +the standard simdgroup F16 matmul accumulation shape. It passes the current +full-model equivalence gate, but the measured long-code prefill change was +within noise (`~0.4%`), so it remains opt-in. `DS4_METAL_MPP_F16_WIDE=1` tests +wider 512/1024-column compressor MPP, including the paired MPP route when both +variables are set. 
The wide route is diagnostic only: the current long-code +prompt fails full-model equivalence with wide F16 MPP (`rms ~= 0.569`, +`top20_max_abs ~= 1.48`), so it is not enabled by `auto`. ## CLI @@ -757,6 +882,7 @@ All project tests are driven by the C runner: ```sh make test # ./ds4_test --all ./ds4_test --logprob-vectors +./ds4_test --metal-mpp-equivalence ./ds4_test --server ``` diff --git a/ds4.c b/ds4.c index c0866bc3..64aec52b 100644 --- a/ds4.c +++ b/ds4.c @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -9972,6 +9973,30 @@ static bool metal_graph_matmul_plain_tensor( return false; } +static bool metal_graph_matmul_q8_0_named_tensor( + const char *module, + uint32_t il, + uint32_t pos0, + ds4_gpu_tensor *out, + const ds4_model *model, + const ds4_tensor *w, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok) { + ds4_gpu_set_mpp_compare_context(module, il, pos0); + const bool ok = ds4_gpu_matmul_q8_0_tensor(out, + model->map, + model->size, + w->abs_offset, + in_dim, + out_dim, + x, + n_tok) != 0; + ds4_gpu_clear_mpp_compare_context(); + return ok; +} + static bool metal_graph_encode_output_head_mtp( ds4_gpu_graph *g, const ds4_model *base_model, @@ -10970,6 +10995,66 @@ static bool metal_graph_q_stage_profile_boundary( return ds4_gpu_begin_commands() != 0; } +static bool ds4_env_bool_enabled(const char *name) { + const char *v = getenv(name); + if (!v) return false; + + while (isspace((unsigned char)*v)) v++; + size_t n = strlen(v); + while (n > 0 && isspace((unsigned char)v[n - 1])) n--; + if (n == 0) return true; + + if ((n == 1 && v[0] == '0') || + (n == 2 && strncasecmp(v, "no", n) == 0) || + (n == 3 && strncasecmp(v, "off", n) == 0) || + (n == 5 && strncasecmp(v, "false", n) == 0)) { + return false; + } + return true; +} + +static bool metal_graph_matmul_f16_pair_or_separate( + ds4_gpu_tensor *out_a, + ds4_gpu_tensor *out_b, + const ds4_model *model, + uint64_t weight_a_offset, + 
uint64_t weight_b_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tokens) { + if (ds4_env_bool_enabled("DS4_METAL_MPP_F16_PAIR")) { + if (ds4_gpu_matmul_f16_pair_tensor(out_a, + out_b, + model->map, + model->size, + weight_a_offset, + weight_b_offset, + in_dim, + out_dim, + x, + n_tokens) != 0) { + return true; + } + } + return ds4_gpu_matmul_f16_tensor(out_a, + model->map, + model->size, + weight_a_offset, + in_dim, + out_dim, + x, + n_tokens) != 0 && + ds4_gpu_matmul_f16_tensor(out_b, + model->map, + model->size, + weight_b_offset, + in_dim, + out_dim, + x, + n_tokens) != 0; +} + static bool metal_graph_encode_layer_attention_batch( ds4_gpu_graph *g, const ds4_model *model, @@ -11085,28 +11170,32 @@ static bool metal_graph_encode_layer_attention_batch( } DS4_METAL_PROFILE_ATTN_STAGE("norm"); DS4_METAL_PROFILE_Q_STAGE("pre_q"); - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_qr, - model->map, - model->size, - layer->attn_q_a->abs_offset, - DS4_N_EMBD, - q_rank, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("attn_q_a", + il, + pos0, + g->batch_qr, + model, + layer->attn_q_a, + DS4_N_EMBD, + q_rank, + g->batch_attn_norm, + n_tokens); if (ok) { metal_graph_debug_dump_tensor("q_lora", g->batch_qr, (uint64_t)n_tokens * q_rank, il, pos0); } DS4_METAL_PROFILE_Q_STAGE("q_a"); if (qkv_rms_fused) { - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_kv_raw, - model->map, - model->size, - layer->attn_kv->abs_offset, - DS4_N_EMBD, - DS4_N_HEAD_DIM, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("attn_kv", + il, + pos0, + g->batch_kv_raw, + model, + layer->attn_kv, + DS4_N_EMBD, + DS4_N_HEAD_DIM, + g->batch_attn_norm, + n_tokens); if (ok) { metal_graph_debug_dump_tensor("KVraw", g->batch_kv_raw, (uint64_t)n_tokens * DS4_N_HEAD_DIM, il, pos0); @@ -11142,14 +11231,16 @@ static bool metal_graph_encode_layer_attention_batch( (uint64_t)n_tokens * 
DS4_N_HEAD_DIM, il, pos0); } DS4_METAL_PROFILE_Q_STAGE("q_a_norm"); - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_q, - model->map, - model->size, - layer->attn_q_b->abs_offset, - q_rank, - q_dim, - g->batch_qr_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("attn_q_b", + il, + pos0, + g->batch_q, + model, + layer->attn_q_b, + q_rank, + q_dim, + g->batch_qr_norm, + n_tokens); if (ok) { metal_graph_debug_dump_tensor("Qraw", g->batch_q, (uint64_t)n_tokens * q_dim, il, pos0); @@ -11186,14 +11277,16 @@ static bool metal_graph_encode_layer_attention_batch( DS4_METAL_PROFILE_Q_STAGE("rope"); DS4_METAL_PROFILE_ATTN_STAGE("q_path"); if (!qkv_rms_fused) { - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_kv_raw, - model->map, - model->size, - layer->attn_kv->abs_offset, - DS4_N_EMBD, - DS4_N_HEAD_DIM, - g->batch_attn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("attn_kv", + il, + pos0, + g->batch_kv_raw, + model, + layer->attn_kv, + DS4_N_EMBD, + DS4_N_HEAD_DIM, + g->batch_attn_norm, + n_tokens); if (ok) { metal_graph_debug_dump_tensor("KVraw", g->batch_kv_raw, (uint64_t)n_tokens * DS4_N_HEAD_DIM, il, pos0); @@ -11320,27 +11413,39 @@ static bool metal_graph_encode_layer_attention_batch( fprintf(stderr, "ds4: Metal layer-major prefill needs attention compressor weights\n"); ok = false; } - if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_kv, - model->map, - model->size, - layer->attn_compressor_kv->abs_offset, - DS4_N_EMBD, - comp_width, - g->batch_attn_norm, - n_tokens) != 0; + if (ok && ds4_env_bool_enabled("DS4_METAL_MPP_F16_PAIR")) { + ok = metal_graph_matmul_f16_pair_or_separate(g->batch_comp_kv, + g->batch_comp_sc, + model, + layer->attn_compressor_kv->abs_offset, + layer->attn_compressor_gate->abs_offset, + DS4_N_EMBD, + comp_width, + g->batch_attn_norm, + n_tokens); + } else if (ok) { + ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_kv, + model->map, + model->size, + 
layer->attn_compressor_kv->abs_offset, + DS4_N_EMBD, + comp_width, + g->batch_attn_norm, + n_tokens) != 0; + if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_sc, + model->map, + model->size, + layer->attn_compressor_gate->abs_offset, + DS4_N_EMBD, + comp_width, + g->batch_attn_norm, + n_tokens) != 0; + } if (ok) metal_graph_debug_dump_tensor("attn_comp_kv_raw", g->batch_comp_kv, (uint64_t)comp_width * n_tokens, il, pos0); - if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_sc, - model->map, - model->size, - layer->attn_compressor_gate->abs_offset, - DS4_N_EMBD, - comp_width, - g->batch_attn_norm, - n_tokens) != 0; if (ok) metal_graph_debug_dump_tensor("attn_comp_score_raw", g->batch_comp_sc, (uint64_t)comp_width * n_tokens, @@ -11598,27 +11703,39 @@ static bool metal_graph_encode_layer_attention_batch( fprintf(stderr, "ds4: Metal layer-major prefill needs indexer weights\n"); ok = false; } - if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_kv, - model->map, - model->size, - layer->indexer_compressor_kv->abs_offset, - DS4_N_EMBD, - index_width, - g->batch_attn_norm, - n_tokens) != 0; + if (ok && ds4_env_bool_enabled("DS4_METAL_MPP_F16_PAIR")) { + ok = metal_graph_matmul_f16_pair_or_separate(g->batch_comp_kv, + g->batch_comp_sc, + model, + layer->indexer_compressor_kv->abs_offset, + layer->indexer_compressor_gate->abs_offset, + DS4_N_EMBD, + index_width, + g->batch_attn_norm, + n_tokens); + } else if (ok) { + ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_kv, + model->map, + model->size, + layer->indexer_compressor_kv->abs_offset, + DS4_N_EMBD, + index_width, + g->batch_attn_norm, + n_tokens) != 0; + if (ok) ok = ds4_gpu_matmul_f16_tensor(g->batch_comp_sc, + model->map, + model->size, + layer->indexer_compressor_gate->abs_offset, + DS4_N_EMBD, + index_width, + g->batch_attn_norm, + n_tokens) != 0; + } if (ok) metal_graph_debug_dump_tensor("indexer_comp_kv_raw", g->batch_comp_kv, (uint64_t)index_width * n_tokens, il, pos0); - if (ok) ok = 
ds4_gpu_matmul_f16_tensor(g->batch_comp_sc, - model->map, - model->size, - layer->indexer_compressor_gate->abs_offset, - DS4_N_EMBD, - index_width, - g->batch_attn_norm, - n_tokens) != 0; if (ok) metal_graph_debug_dump_tensor("indexer_comp_score_raw", g->batch_comp_sc, (uint64_t)index_width * n_tokens, @@ -12237,20 +12354,24 @@ static bool metal_graph_encode_layer_attention_batch( (uint64_t)n_tokens * q_dim, il, pos0); } DS4_METAL_PROFILE_ATTN_STAGE("inv_rope"); - if (ok) ok = ds4_gpu_attention_output_q8_batch_tensor(g->batch_attn_out, - g->batch_attn_low, - g->batch_group_tmp, - g->batch_low_tmp, - model->map, - model->size, - layer->attn_output_a->abs_offset, - layer->attn_output_b->abs_offset, - group_dim, - rank, - n_groups, - DS4_N_EMBD, - g->batch_heads, - n_tokens) != 0; + if (ok) { + ds4_gpu_set_mpp_compare_context("attn_out", il, pos0); + ok = ds4_gpu_attention_output_q8_batch_tensor(g->batch_attn_out, + g->batch_attn_low, + g->batch_group_tmp, + g->batch_low_tmp, + model->map, + model->size, + layer->attn_output_a->abs_offset, + layer->attn_output_b->abs_offset, + group_dim, + rank, + n_groups, + DS4_N_EMBD, + g->batch_heads, + n_tokens) != 0; + ds4_gpu_clear_mpp_compare_context(); + } if (ok) { metal_graph_debug_dump_tensor("attn_low", g->batch_attn_low, (uint64_t)n_tokens * n_groups * rank, @@ -12422,33 +12543,37 @@ static bool metal_graph_encode_layer_ffn_batch( } DS4_METAL_PROFILE_FFN_STAGE("router"); - if (ok) ok = ds4_gpu_routed_moe_batch_tensor(g->batch_routed_out, - g->batch_routed_gate, - g->batch_routed_up, - g->batch_routed_mid, - g->batch_routed_down, - model->map, - model->size, - layer->ffn_gate_exps->abs_offset, - layer->ffn_up_exps->abs_offset, - layer->ffn_down_exps->abs_offset, - layer->ffn_gate_exps->type, - layer->ffn_down_exps->type, - gate_expert_bytes, - gate_row_bytes, - down_expert_bytes, - down_row_bytes, - (uint32_t)expert_in_dim, - (uint32_t)down_in_dim, - (uint32_t)routed_out_dim, - g->batch_router_selected, - 
g->batch_router_weights, - DS4_N_EXPERT_USED, - DS4_SWIGLU_CLAMP_EXP, - g->batch_ffn_norm, - il, - n_tokens, - &g->batch_routed_mid_is_f16) != 0; + if (ok) { + ds4_gpu_set_mpp_compare_context("routed_moe", il, pos0); + ok = ds4_gpu_routed_moe_batch_tensor(g->batch_routed_out, + g->batch_routed_gate, + g->batch_routed_up, + g->batch_routed_mid, + g->batch_routed_down, + model->map, + model->size, + layer->ffn_gate_exps->abs_offset, + layer->ffn_up_exps->abs_offset, + layer->ffn_down_exps->abs_offset, + layer->ffn_gate_exps->type, + layer->ffn_down_exps->type, + gate_expert_bytes, + gate_row_bytes, + down_expert_bytes, + down_row_bytes, + (uint32_t)expert_in_dim, + (uint32_t)down_in_dim, + (uint32_t)routed_out_dim, + g->batch_router_selected, + g->batch_router_weights, + DS4_N_EXPERT_USED, + DS4_SWIGLU_CLAMP_EXP, + g->batch_ffn_norm, + il, + n_tokens, + &g->batch_routed_mid_is_f16) != 0; + ds4_gpu_clear_mpp_compare_context(); + } if (ok) { metal_graph_debug_dump_tensor("ffn_moe_gate_clamped", g->batch_routed_gate, (uint64_t)n_tokens * DS4_N_EXPERT_USED * down_in_dim, il, pos0); @@ -12468,22 +12593,26 @@ static bool metal_graph_encode_layer_ffn_batch( (uint64_t)n_tokens * DS4_N_EMBD, il, pos0); } DS4_METAL_PROFILE_FFN_STAGE("routed_moe"); - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_shared_gate, - model->map, - model->size, - layer->ffn_gate_shexp->abs_offset, - DS4_N_EMBD, - shared_dim, - g->batch_ffn_norm, - n_tokens) != 0; - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_shared_up, - model->map, - model->size, - layer->ffn_up_shexp->abs_offset, - DS4_N_EMBD, - shared_dim, - g->batch_ffn_norm, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("shared_gate", + il, + pos0, + g->batch_shared_gate, + model, + layer->ffn_gate_shexp, + DS4_N_EMBD, + shared_dim, + g->batch_ffn_norm, + n_tokens); + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("shared_up", + il, + pos0, + g->batch_shared_up, + model, + layer->ffn_up_shexp, + DS4_N_EMBD, + 
shared_dim, + g->batch_ffn_norm, + n_tokens); DS4_METAL_PROFILE_FFN_STAGE("shared_gate_up"); if (ok) ok = ds4_gpu_swiglu_tensor(g->batch_shared_mid, g->batch_shared_gate, @@ -12491,14 +12620,16 @@ static bool metal_graph_encode_layer_ffn_batch( (uint32_t)((uint64_t)n_tokens * shared_dim), 0.0f, 1.0f) != 0; - if (ok) ok = ds4_gpu_matmul_q8_0_tensor(g->batch_shared_out, - model->map, - model->size, - layer->ffn_down_shexp->abs_offset, - shared_dim, - DS4_N_EMBD, - g->batch_shared_mid, - n_tokens) != 0; + if (ok) ok = metal_graph_matmul_q8_0_named_tensor("shared_down", + il, + pos0, + g->batch_shared_out, + model, + layer->ffn_down_shexp, + shared_dim, + DS4_N_EMBD, + g->batch_shared_mid, + n_tokens); DS4_METAL_PROFILE_FFN_STAGE("shared_down"); if (ok) { metal_graph_debug_dump_tensor("ffn_shexp", g->batch_shared_out, @@ -14177,6 +14308,7 @@ struct ds4_engine { float *directional_steering_dirs; float directional_steering_attn_scale; float directional_steering_ffn_scale; + ds4_mpp_mode mpp_mode; bool quality; bool metal_ready; bool mtp_ready; @@ -15418,6 +15550,15 @@ const char *ds4_backend_name(ds4_backend backend) { return "unknown"; } +const char *ds4_mpp_mode_name(ds4_mpp_mode mode) { + switch (mode) { + case DS4_MPP_AUTO: return "auto"; + case DS4_MPP_ON: return "on"; + case DS4_MPP_OFF: return "off"; + } + return "unknown"; +} + bool ds4_think_mode_enabled(ds4_think_mode mode) { return mode == DS4_THINK_HIGH || mode == DS4_THINK_MAX; } @@ -16954,6 +17095,7 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { e->mtp_model.fd = -1; e->backend = opt->backend; e->quality = opt->quality; + e->mpp_mode = opt->mpp_mode; e->mtp_draft_tokens = opt->mtp_draft_tokens > 0 ? opt->mtp_draft_tokens : 1; if (e->mtp_draft_tokens > 16) e->mtp_draft_tokens = 16; e->mtp_margin = opt->mtp_margin >= 0.0f ? 
opt->mtp_margin : 3.0f; @@ -17019,6 +17161,7 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { *out = NULL; return 1; } + ds4_gpu_set_mpp_mode(e->mpp_mode); ds4_gpu_set_quality(e->quality); (void)ds4_gpu_set_model_fd(e->model.fd); if (!ds4_gpu_set_model_map_range(e->model.map, @@ -17076,6 +17219,10 @@ void ds4_engine_summary(ds4_engine *e) { model_summary(&e->model); } +int ds4_engine_vocab_size(ds4_engine *e) { + return e ? e->vocab.n_vocab : 0; +} + void ds4_engine_close(ds4_engine *e) { if (!e) return; weights_free(&e->weights); @@ -17485,6 +17632,12 @@ int ds4_session_token_logprob(ds4_session *s, int token, ds4_token_score *out) { return 1; } +int ds4_session_copy_logits(ds4_session *s, float *out, int cap) { + if (!s || !out || cap < (int)DS4_N_VOCAB) return 0; + memcpy(out, s->logits, (size_t)DS4_N_VOCAB * sizeof(out[0])); + return (int)DS4_N_VOCAB; +} + static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, char *err, size_t errlen) { if (!s) return 1; diff --git a/ds4.h b/ds4.h index 950d8dca..c60105f7 100644 --- a/ds4.h +++ b/ds4.h @@ -20,6 +20,12 @@ typedef enum { DS4_BACKEND_CPU, } ds4_backend; +typedef enum { + DS4_MPP_AUTO = 0, + DS4_MPP_ON, + DS4_MPP_OFF, +} ds4_mpp_mode; + typedef enum { DS4_THINK_NONE, DS4_THINK_HIGH, @@ -67,6 +73,7 @@ typedef struct { float directional_steering_ffn; bool warm_weights; bool quality; + ds4_mpp_mode mpp_mode; } ds4_engine_options; typedef void (*ds4_token_emit_fn)(void *ud, int token); @@ -91,7 +98,9 @@ typedef struct { int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt); void ds4_engine_close(ds4_engine *e); void ds4_engine_summary(ds4_engine *e); +int ds4_engine_vocab_size(ds4_engine *e); const char *ds4_backend_name(ds4_backend backend); +const char *ds4_mpp_mode_name(ds4_mpp_mode mode); bool ds4_think_mode_enabled(ds4_think_mode mode); const char *ds4_think_mode_name(ds4_think_mode mode); const char *ds4_think_max_prefix(void); @@ -168,6 +177,7 @@ 
int ds4_session_argmax_excluding(ds4_session *s, int excluded_id); int ds4_session_sample(ds4_session *s, float temperature, int top_k, float top_p, float min_p, uint64_t *rng); int ds4_session_top_logprobs(ds4_session *s, ds4_token_score *out, int k); int ds4_session_token_logprob(ds4_session *s, int token, ds4_token_score *out); +int ds4_session_copy_logits(ds4_session *s, float *out, int cap); int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen); int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token, int max_tokens, int eos_token, diff --git a/ds4_cli.c b/ds4_cli.c index bc70e659..0bfd71e7 100644 --- a/ds4_cli.c +++ b/ds4_cli.c @@ -102,7 +102,9 @@ static void usage(FILE *fp) { " -t, --threads N\n" " CPU helper threads for host-side or reference work.\n" " --quality\n" - " Prefer exact kernels where faster approximate paths exist; MTP uses strict verification.\n" + " Prefer exact kernels where faster approximate paths exist; disables Metal 4 MPP routes; MTP uses strict verification.\n" + " --mpp MODE\n" + " Metal 4 MPP policy: auto, on, or off. Default: auto. 
Auto enables validated safe routes; 'on' is a route diagnostic and may change output.\n" " --dir-steering-file FILE\n" " Load one f32 direction vector per layer for directional steering.\n" " --dir-steering-ffn F\n" @@ -240,6 +242,15 @@ static ds4_backend default_backend(void) { #endif } +static ds4_mpp_mode parse_mpp_mode(const char *s) { + if (!strcmp(s, "auto")) return DS4_MPP_AUTO; + if (!strcmp(s, "on")) return DS4_MPP_ON; + if (!strcmp(s, "off")) return DS4_MPP_OFF; + fprintf(stderr, "ds4: invalid MPP mode: %s\n", s); + fprintf(stderr, "ds4: valid MPP modes are: auto, on, off\n"); + exit(2); +} + static void log_context_memory(ds4_backend backend, int ctx_size) { ds4_context_memory m = ds4_context_memory_estimate(backend, ctx_size); fprintf(stderr, @@ -1244,6 +1255,8 @@ static cli_config parse_options(int argc, char **argv) { c.gen.seed = parse_u64(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--quality")) { c.engine.quality = true; + } else if (!strcmp(arg, "--mpp")) { + c.engine.mpp_mode = parse_mpp_mode(need_arg(&i, argc, argv, arg)); } else if (!strcmp(arg, "--dir-steering-file")) { c.engine.directional_steering_file = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--dir-steering-ffn")) { diff --git a/ds4_gpu.h b/ds4_gpu.h index 2b33b5ea..b000af9f 100644 --- a/ds4_gpu.h +++ b/ds4_gpu.h @@ -4,6 +4,8 @@ #include #include +#include "ds4.h" + /* ========================================================================= * GPU Tensor and Command Lifetime. 
* ========================================================================= @@ -41,6 +43,9 @@ int ds4_gpu_set_model_map_range(const void *model_map, uint64_t model_size, uint int ds4_gpu_cache_model_range(const void *model_map, uint64_t model_size, uint64_t offset, uint64_t bytes, const char *label); int ds4_gpu_cache_q8_f16_range(const void *model_map, uint64_t model_size, uint64_t offset, uint64_t bytes, uint64_t in_dim, uint64_t out_dim, const char *label); void ds4_gpu_set_quality(bool quality); +void ds4_gpu_set_mpp_mode(ds4_mpp_mode mode); +void ds4_gpu_set_mpp_compare_context(const char *module, uint32_t layer_index, uint32_t pos0); +void ds4_gpu_clear_mpp_compare_context(void); void ds4_gpu_print_memory_report(const char *label); /* ========================================================================= diff --git a/ds4_metal.m b/ds4_metal.m index 03a428b7..741dc515 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -172,6 +173,38 @@ static NSUInteger g_attn_out_group_ids_bytes; static int g_initialized; static int g_quality_mode; +static ds4_mpp_mode g_mpp_mode = DS4_MPP_AUTO; +static int g_mpp_q8_reported; +static int g_mpp_q8_partial_skip_reported; +static int g_mpp_f16_reported; +static int g_mpp_f16_pair_reported; +static int g_mpp_attn_out_reported; +static int g_mpp_moe_reported; +static int g_mpp_moe_ranges_reported; +static int g_mpp_invalid_env_reported; +static char g_mpp_compare_context[128]; + +#define DS4_METAL_MPP_COMPARE_PENDING_MAX 64 +#define DS4_METAL_MPP_COMPARE_DELTAS 5 + +typedef struct { + __strong id ref_buffer; + __strong id cand_buffer; + NSUInteger ref_offset; + NSUInteger cand_offset; + uint64_t elements; + uint64_t dim0; + uint64_t dim1; + uint64_t dim2; + char route[16]; + char label[128]; +} ds4_gpu_mpp_compare_item; + +static ds4_gpu_mpp_compare_item g_mpp_compare_pending[DS4_METAL_MPP_COMPARE_PENDING_MAX]; +static int g_mpp_compare_pending_count; 
+static int g_mpp_compare_done_count; +static int g_mpp_compare_stopped; +static int g_mpp_compare_limit_reported; static uint64_t ds4_gpu_system_memory_bytes(void) { uint64_t bytes = 0; @@ -283,12 +316,260 @@ static int ds4_gpu_wait_pending_command_buffers(const char *label) { return ok; } +static int ds4_gpu_mpp_compare_max(void) { + const char *env = getenv("DS4_METAL_MPP_COMPARE_MAX"); + if (!env || !env[0]) return 20; + char *end = NULL; + unsigned long v = strtoul(env, &end, 10); + if (end == env) return 20; + if (v > 1000000ul) v = 1000000ul; + return (int)v; +} + +static int ds4_gpu_mpp_compare_verbose(void) { + const char *env = getenv("DS4_METAL_MPP_COMPARE_VERBOSE"); + return env && env[0] && strcmp(env, "0") != 0 && + strcmp(env, "false") != 0 && strcmp(env, "off") != 0; +} + +static int ds4_gpu_mpp_compare_route_matches(const char *route) { + if (g_mpp_compare_stopped) return 0; + const char *want = getenv("DS4_METAL_MPP_COMPARE_ROUTE"); + if (!want || !want[0] || !route || !route[0]) return 0; + if (strcmp(want, "all") == 0) return 1; + return strcmp(want, route) == 0; +} + +static const char *ds4_gpu_mpp_compare_label(const char *fallback, + char *buf, + size_t buflen) { + if (g_mpp_compare_context[0]) return g_mpp_compare_context; + snprintf(buf, buflen, "%s", fallback && fallback[0] ? 
fallback : "unknown"); + return buf; +} + +static void ds4_gpu_mpp_compare_note_delta( + uint64_t *idx, + float *ref_vals, + float *cand_vals, + float *abs_vals, + uint64_t id, + float ref, + float cand) { + const float abs_delta = fabsf(cand - ref); + for (int i = 0; i < DS4_METAL_MPP_COMPARE_DELTAS; i++) { + if (idx[i] == UINT64_MAX || abs_delta > abs_vals[i]) { + for (int j = DS4_METAL_MPP_COMPARE_DELTAS - 1; j > i; j--) { + idx[j] = idx[j - 1]; + ref_vals[j] = ref_vals[j - 1]; + cand_vals[j] = cand_vals[j - 1]; + abs_vals[j] = abs_vals[j - 1]; + } + idx[i] = id; + ref_vals[i] = ref; + cand_vals[i] = cand; + abs_vals[i] = abs_delta; + return; + } + } +} + +static void ds4_gpu_mpp_compare_clear_pending(void) { + for (int i = 0; i < g_mpp_compare_pending_count; i++) { + g_mpp_compare_pending[i].ref_buffer = nil; + g_mpp_compare_pending[i].cand_buffer = nil; + g_mpp_compare_pending[i].elements = 0; + g_mpp_compare_pending[i].route[0] = '\0'; + g_mpp_compare_pending[i].label[0] = '\0'; + } + g_mpp_compare_pending_count = 0; +} + +static void ds4_gpu_mpp_compare_reset(void) { + ds4_gpu_mpp_compare_clear_pending(); + g_mpp_compare_done_count = 0; + g_mpp_compare_stopped = 0; + g_mpp_compare_limit_reported = 0; +} + +static void ds4_gpu_mpp_compare_drain(const char *finish_label) { + (void)finish_label; + const int max_reports = ds4_gpu_mpp_compare_max(); + for (int i = 0; i < g_mpp_compare_pending_count; i++) { + ds4_gpu_mpp_compare_item *item = &g_mpp_compare_pending[i]; + if (g_mpp_compare_stopped || g_mpp_compare_done_count >= max_reports || + !item->ref_buffer || !item->cand_buffer || item->elements == 0) { + continue; + } + + const float *ref = (const float *)((const uint8_t *)[item->ref_buffer contents] + item->ref_offset); + const float *cand = (const float *)((const uint8_t *)[item->cand_buffer contents] + item->cand_offset); + double sumsq = 0.0; + float max_abs = 0.0f; + uint64_t max_index = 0; + int nonfinite = 0; + uint64_t 
delta_idx[DS4_METAL_MPP_COMPARE_DELTAS]; + float delta_ref[DS4_METAL_MPP_COMPARE_DELTAS]; + float delta_cand[DS4_METAL_MPP_COMPARE_DELTAS]; + float delta_abs[DS4_METAL_MPP_COMPARE_DELTAS]; + for (int j = 0; j < DS4_METAL_MPP_COMPARE_DELTAS; j++) { + delta_idx[j] = UINT64_MAX; + delta_ref[j] = 0.0f; + delta_cand[j] = 0.0f; + delta_abs[j] = 0.0f; + } + + for (uint64_t j = 0; j < item->elements; j++) { + if (!isfinite(ref[j]) || !isfinite(cand[j])) { + nonfinite++; + continue; + } + const float delta = cand[j] - ref[j]; + const float abs_delta = fabsf(delta); + sumsq += (double)delta * (double)delta; + if (abs_delta > max_abs) { + max_abs = abs_delta; + max_index = j; + } + ds4_gpu_mpp_compare_note_delta(delta_idx, delta_ref, delta_cand, delta_abs, + j, ref[j], cand[j]); + } + + const float rms = (float)sqrt(sumsq / (double)item->elements); + const int exceeds_target = (nonfinite != 0 || max_abs > 1.0e-3f || rms > 1.0e-4f); + if (ds4_gpu_mpp_compare_verbose() || exceeds_target) { + fprintf(stderr, + "ds4: Metal MPP compare route=%s module=%s shape=%llux%llux%llu max_abs=%g rms=%g nonfinite=%d max_index=%llu\n", + item->route, + item->label, + (unsigned long long)item->dim0, + (unsigned long long)item->dim1, + (unsigned long long)item->dim2, + max_abs, + rms, + nonfinite, + (unsigned long long)max_index); + fprintf(stderr, "ds4: Metal MPP compare route=%s module=%s largest deltas:", + item->route, item->label); + for (int j = 0; j < DS4_METAL_MPP_COMPARE_DELTAS && delta_idx[j] != UINT64_MAX; j++) { + fprintf(stderr, " idx=%llu ref=%g cand=%g abs=%g", + (unsigned long long)delta_idx[j], + delta_ref[j], + delta_cand[j], + delta_abs[j]); + } + fputc('\n', stderr); + } + + g_mpp_compare_done_count++; + if (exceeds_target) { + fprintf(stderr, + "ds4: Metal MPP compare route=%s module=%s exceeded target max_abs<=0.001 rms<=0.0001; stopping comparisons\n", + item->route, + item->label); + g_mpp_compare_stopped = 1; + } + } + if (!g_mpp_compare_stopped && 
!g_mpp_compare_limit_reported && + g_mpp_compare_done_count >= max_reports) { + fprintf(stderr, + "ds4: Metal MPP compare reached DS4_METAL_MPP_COMPARE_MAX=%d without a target breach\n", + max_reports); + g_mpp_compare_limit_reported = 1; + } + ds4_gpu_mpp_compare_clear_pending(); +} + +static void ds4_gpu_mpp_compare_register( + const char *route, + const char *fallback_label, + const ds4_gpu_tensor *ref, + const ds4_gpu_tensor *cand, + uint64_t elements, + uint64_t dim0, + uint64_t dim1, + uint64_t dim2) { + if (!ds4_gpu_mpp_compare_route_matches(route)) return; + if (g_mpp_compare_done_count + g_mpp_compare_pending_count >= ds4_gpu_mpp_compare_max()) return; + if (g_mpp_compare_pending_count >= DS4_METAL_MPP_COMPARE_PENDING_MAX) return; + id ref_buffer = ds4_gpu_tensor_buffer(ref); + id cand_buffer = ds4_gpu_tensor_buffer(cand); + if (!ref_buffer || !cand_buffer || elements == 0) return; + + ds4_gpu_mpp_compare_item *item = &g_mpp_compare_pending[g_mpp_compare_pending_count++]; + item->ref_buffer = nil; + item->cand_buffer = nil; + item->ref_offset = 0; + item->cand_offset = 0; + item->elements = 0; + item->dim0 = 0; + item->dim1 = 0; + item->dim2 = 0; + item->route[0] = '\0'; + item->label[0] = '\0'; + item->ref_buffer = ref_buffer; + item->cand_buffer = cand_buffer; + item->ref_offset = ds4_gpu_tensor_offset(ref); + item->cand_offset = ds4_gpu_tensor_offset(cand); + item->elements = elements; + item->dim0 = dim0; + item->dim1 = dim1; + item->dim2 = dim2; + snprintf(item->route, sizeof(item->route), "%s", route); + char label_buf[128]; + snprintf(item->label, sizeof(item->label), "%s", + ds4_gpu_mpp_compare_label(fallback_label, label_buf, sizeof(label_buf))); +} + +static ds4_gpu_tensor *ds4_gpu_mpp_compare_make_buffer_view( + id buffer, + NSUInteger offset, + uint64_t bytes) { + if (!buffer || bytes > (uint64_t)NSUIntegerMax) return NULL; + DS4MetalTensor *view = [DS4MetalTensor new]; + view.buffer = buffer; + view.offset = (uint64_t)offset; + view.bytes = 
bytes; + view.owner = 0; + return (__bridge_retained ds4_gpu_tensor *)view; +} + +static ds4_gpu_tensor *ds4_gpu_mpp_compare_snapshot_buffer( + id buffer, + NSUInteger offset, + uint64_t bytes) { + ds4_gpu_tensor *view = ds4_gpu_mpp_compare_make_buffer_view(buffer, offset, bytes); + ds4_gpu_tensor *snapshot = ds4_gpu_tensor_alloc(bytes); + if (!view || !snapshot) { + ds4_gpu_tensor_free(view); + ds4_gpu_tensor_free(snapshot); + return NULL; + } + + int ok = 0; + if (g_batch_cb) { + ok = ds4_gpu_tensor_copy(snapshot, 0, view, 0, bytes); + } else { + memcpy(ds4_gpu_tensor_contents(snapshot), + (const uint8_t *)[buffer contents] + offset, + (size_t)bytes); + ok = 1; + } + ds4_gpu_tensor_free(view); + if (!ok) { + ds4_gpu_tensor_free(snapshot); + return NULL; + } + return snapshot; +} + static int ds4_gpu_finish_command_buffer(id cb, int owned, const char *label) { if (!owned) return 1; [cb commit]; int ok = ds4_gpu_wait_pending_command_buffers(label); if (!ds4_gpu_wait_command_buffer(cb, label)) ok = 0; + if (ok) ds4_gpu_mpp_compare_drain(label); [g_transient_buffers removeAllObjects]; return ok; } @@ -683,61 +964,369 @@ static int ds4_gpu_use_compressor_pair_nr4(void) { static int ds4_gpu_device_name_contains(const char *needle); static int ds4_gpu_mpp_q8_0_default_target(void) { - return ds4_gpu_device_name_contains("M5") || - ds4_gpu_device_name_contains("M6") || - ds4_gpu_device_name_contains("A19") || - ds4_gpu_device_name_contains("A20"); + return 1; +} + +static int ds4_gpu_env_value_eq(const char *v, size_t n, const char *literal) { + size_t m = strlen(literal); + if (n != m) return 0; + for (size_t i = 0; i < n; i++) { + if (tolower((unsigned char)v[i]) != tolower((unsigned char)literal[i])) return 0; + } + return 1; +} + +static int ds4_gpu_env_bool(const char *name) { + const char *v = getenv(name); + if (!v) return -1; + + while (isspace((unsigned char)*v)) v++; + size_t n = strlen(v); + while (n > 0 && isspace((unsigned char)v[n - 1])) n--; + if (n == 0) 
return 1; + + if (ds4_gpu_env_value_eq(v, n, "1") || + ds4_gpu_env_value_eq(v, n, "true") || + ds4_gpu_env_value_eq(v, n, "yes") || + ds4_gpu_env_value_eq(v, n, "on")) { + return 1; + } + if (ds4_gpu_env_value_eq(v, n, "0") || + ds4_gpu_env_value_eq(v, n, "false") || + ds4_gpu_env_value_eq(v, n, "no") || + ds4_gpu_env_value_eq(v, n, "off")) { + return 0; + } + + if (!g_mpp_invalid_env_reported) { + fprintf(stderr, + "ds4: invalid Metal MPP boolean environment value %s=%.*s; treating presence as enabled\n", + name, (int)n, v); + g_mpp_invalid_env_reported = 1; + } + return 1; +} + +typedef enum { + DS4_METAL_MPP_GLOBAL_OFF, + DS4_METAL_MPP_GLOBAL_AUTO, + DS4_METAL_MPP_GLOBAL_ON, +} ds4_gpu_mpp_global_policy; + +static ds4_gpu_mpp_global_policy ds4_gpu_mpp_global_policy_mode(void) { + if (!g_metal4_tensor_api_enabled || g_quality_mode) return DS4_METAL_MPP_GLOBAL_OFF; + if (g_mpp_mode == DS4_MPP_OFF) return DS4_METAL_MPP_GLOBAL_OFF; + if (g_mpp_mode == DS4_MPP_ON) return DS4_METAL_MPP_GLOBAL_ON; + + const int disabled = ds4_gpu_env_bool("DS4_METAL_MPP_DISABLE"); + if (disabled > 0) return DS4_METAL_MPP_GLOBAL_OFF; + + const int enabled = ds4_gpu_env_bool("DS4_METAL_MPP_ENABLE"); + if (enabled >= 0) return enabled ? DS4_METAL_MPP_GLOBAL_ON : DS4_METAL_MPP_GLOBAL_OFF; + + return DS4_METAL_MPP_GLOBAL_AUTO; +} + +static int ds4_gpu_mpp_route_switch(const char *enable_env, const char *disable_env) { + const int disabled = ds4_gpu_env_bool(disable_env); + if (disabled > 0) return 0; + + const int enabled = ds4_gpu_env_bool(enable_env); + if (enabled >= 0) return enabled ? 
1 : 0; + + return -1; +} + +static int ds4_gpu_mpp_route_enabled( + int default_target, + const char *enable_env, + const char *disable_env) { + const ds4_gpu_mpp_global_policy policy = ds4_gpu_mpp_global_policy_mode(); + if (policy == DS4_METAL_MPP_GLOBAL_OFF) return 0; + + const int route = ds4_gpu_mpp_route_switch(enable_env, disable_env); + if (route >= 0) return route; + + if (policy == DS4_METAL_MPP_GLOBAL_ON) return 1; + return default_target; +} + +static int ds4_gpu_mpp_fast_profile(void) { + return ds4_gpu_env_bool("DS4_METAL_MPP_FAST") > 0; +} + +static const char *ds4_gpu_mpp_enabled_reason(void) { + if (g_mpp_mode == DS4_MPP_ON) return " by --mpp on"; + if (ds4_gpu_mpp_fast_profile()) return " by DS4_METAL_MPP_FAST"; + if (ds4_gpu_env_bool("DS4_METAL_MPP_ENABLE") > 0) return " by DS4_METAL_MPP_ENABLE"; + return " by default"; } static int ds4_gpu_mpp_q8_0_policy_enabled(void) { - if (!g_metal4_tensor_api_enabled) return 0; - if (getenv("DS4_METAL_MPP_DISABLE") != NULL) return 0; - if (getenv("DS4_METAL_MPP_ENABLE") != NULL) return 1; - return ds4_gpu_mpp_q8_0_default_target(); + return ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_q8_0_default_target(), + "DS4_METAL_MPP_Q8_0_ENABLE", + "DS4_METAL_MPP_Q8_0_DISABLE"); } static int ds4_gpu_use_mpp_q8_0_matmul(void) { - static int initialized; - static int enabled; - if (!initialized) { - enabled = ds4_gpu_mpp_q8_0_policy_enabled(); - if (enabled) { - const int forced = getenv("DS4_METAL_MPP_ENABLE") != NULL; - fprintf(stderr, "ds4: Metal MPP Q8_0 prefill matmul enabled%s\n", - forced ? 
" by environment" : " by default"); - } - initialized = 1; + const int enabled = ds4_gpu_mpp_q8_0_policy_enabled(); + if (enabled && !g_mpp_q8_reported) { + fprintf(stderr, "ds4: Metal MPP Q8_0 prefill matmul enabled%s\n", + ds4_gpu_mpp_enabled_reason()); + g_mpp_q8_reported = 1; } return enabled; } -static int ds4_gpu_use_mpp_f16_compressor_matmul(void) { - static int initialized; - static int enabled; - if (!initialized) { - enabled = ds4_gpu_mpp_q8_0_policy_enabled() && - getenv("DS4_METAL_MPP_F16_DISABLE") == NULL; - if (enabled) { - const int forced = getenv("DS4_METAL_MPP_ENABLE") != NULL; - fprintf(stderr, "ds4: Metal MPP F16 compressor prefill matmul enabled%s\n", - forced ? " by environment" : " by default"); +static int ds4_gpu_mpp_q8_0_partial_tiles_enabled(void) { + if (ds4_gpu_mpp_fast_profile()) return 1; + return ds4_gpu_env_bool("DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE") > 0; +} + +static uint32_t ds4_gpu_mpp_tile_n_env(const char *name) { + const char *env = getenv(name); + if (!env || !env[0]) return 32; + char *end = NULL; + long v = strtol(env, &end, 10); + while (end && isspace((unsigned char)*end)) end++; + if (end && *end == '\0' && v == 64) return 64; + if (end && *end == '\0' && v == 32) return 32; + fprintf(stderr, + "ds4: invalid %s=%s; expected 32 or 64, using 32\n", + name, env); + return 32; +} + +static uint32_t ds4_gpu_mpp_q8_0_tile_n(void) { + return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_Q8_0_TILE_N"); +} + +static uint32_t ds4_gpu_mpp_attn_out_tile_n(void) { + return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_ATTN_OUT_TILE_N"); +} + +static uint32_t ds4_gpu_mpp_moe_tile_n(void) { + return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_MOE_TILE_N"); +} + +static int ds4_gpu_mpp_moe_fast_layout(void) { + return ds4_gpu_env_bool("DS4_METAL_MPP_MOE_FAST_LAYOUT") > 0; +} + +static int ds4_gpu_mpp_moe_pair_gate_up(void) { + return ds4_gpu_env_bool("DS4_METAL_MPP_MOE_PAIR_GATE_UP") > 0; +} + +static int ds4_gpu_mpp_direct_rhs(void) { + return 
ds4_gpu_env_bool("DS4_METAL_MPP_DIRECT_RHS") > 0; +} + +static int ds4_gpu_mpp_q8_0_direct_rhs(void) { + return ds4_gpu_mpp_direct_rhs() || + ds4_gpu_env_bool("DS4_METAL_MPP_Q8_0_DIRECT_RHS") > 0; +} + +static int ds4_gpu_mpp_f16_direct_rhs(void) { + return ds4_gpu_mpp_direct_rhs() || + ds4_gpu_env_bool("DS4_METAL_MPP_F16_DIRECT_RHS") > 0; +} + +static int ds4_gpu_mpp_f16_wide_matmul(void) { + return ds4_gpu_env_bool("DS4_METAL_MPP_F16_WIDE") > 0; +} + +static int ds4_gpu_mpp_f16_pair_matmul(void) { + return ds4_gpu_env_bool("DS4_METAL_MPP_F16_PAIR") > 0; +} + +static int ds4_gpu_mpp_attn_out_direct_rhs(void) { + return ds4_gpu_mpp_direct_rhs() || + ds4_gpu_env_bool("DS4_METAL_MPP_ATTN_OUT_DIRECT_RHS") > 0; +} + +static int ds4_gpu_mpp_layer_env(const char *name, int fallback) { + const char *env = getenv(name); + if (!env || !env[0]) return fallback; + char *end = NULL; + long v = strtol(env, &end, 10); + while (end && isspace((unsigned char)*end)) end++; + if (end && *end == '\0' && v >= 0 && v <= 255) return (int)v; + fprintf(stderr, + "ds4: invalid %s=%s; expected layer index 0..255, using %d\n", + name, env, fallback); + return fallback; +} + +static int ds4_gpu_mpp_context_layer(void) { + if (!g_mpp_compare_context[0]) return -1; + int layer = -1; + if (sscanf(g_mpp_compare_context, "layer=%d", &layer) == 1) return layer; + return -1; +} + +static int ds4_gpu_mpp_late_safe_context_range(int first_layer) { + const int layer = ds4_gpu_mpp_context_layer(); + return layer >= first_layer && layer <= 42; +} + +static int ds4_gpu_mpp_q8_0_late_safe_context(void) { + const int layer = ds4_gpu_mpp_context_layer(); + if (layer >= 38 && layer <= 42) return 1; + if (layer >= 32 && layer <= 37 && + strstr(g_mpp_compare_context, "attn_q_b") != NULL) { + return 1; + } + return 0; +} + +static int ds4_gpu_mpp_attn_out_late_safe_context(void) { + return ds4_gpu_mpp_late_safe_context_range(32); +} + +static int ds4_gpu_mpp_layer_expr_matches(const char *layer_expr) { + if 
(!layer_expr || !*layer_expr) return 0; + const int layer = ds4_gpu_mpp_context_layer(); + char *parse_end = NULL; + long first = strtol(layer_expr, &parse_end, 10); + while (parse_end && isspace((unsigned char)*parse_end)) parse_end++; + if (!parse_end || parse_end == layer_expr || + first < 0 || first > 255 || + !(parse_end[0] == '\0' || + (parse_end[0] == '-' && parse_end[1] != '\0') || + (parse_end[0] == '.' && parse_end[1] == '.' && parse_end[2] != '\0'))) { + return 0; + } + + long last = first; + if (parse_end[0] == '-') { + const char *range_end = parse_end + 1; + while (isspace((unsigned char)*range_end)) range_end++; + char *end2 = NULL; + last = strtol(range_end, &end2, 10); + while (end2 && isspace((unsigned char)*end2)) end2++; + if (!end2 || end2 == range_end || *end2 != '\0') return 0; + } else if (parse_end[0] == '.') { + const char *range_end = parse_end + 2; + while (isspace((unsigned char)*range_end)) range_end++; + char *end2 = NULL; + last = strtol(range_end, &end2, 10); + while (end2 && isspace((unsigned char)*end2)) end2++; + if (!end2 || end2 == range_end || *end2 != '\0') return 0; + } + if (last < first || last < 0 || last > 255) return 0; + return layer >= first && layer <= last; +} + +static int ds4_gpu_mpp_context_matches_filter( + const char *env_name, + int default_match, + int late_safe_match) { + const char *filter = getenv(env_name); + if (!filter || !filter[0]) return default_match; + if (!g_mpp_compare_context[0]) return 0; + + const char *p = filter; + while (*p) { + while (*p == ',' || isspace((unsigned char)*p)) p++; + const char *start = p; + while (*p && *p != ',') p++; + const char *end = p; + while (end > start && isspace((unsigned char)end[-1])) end--; + if (end > start) { + char token[64]; + size_t n = (size_t)(end - start); + if (n >= sizeof(token)) n = sizeof(token) - 1u; + memcpy(token, start, n); + token[n] = '\0'; + if (ds4_gpu_env_value_eq(token, n, "all")) return 1; + if (ds4_gpu_env_value_eq(token, n, "none")) 
return 0; + if (ds4_gpu_env_value_eq(token, n, "late_safe")) return late_safe_match; + char *at = strchr(token, '@'); + if (at) { + *at = '\0'; + const char *module = token; + const char *expr = at + 1; + if (strncmp(expr, "layer=", 6) == 0) { + expr += 6; + } else if (strncmp(expr, "layer:", 6) == 0) { + expr += 6; + } else { + continue; + } + if (*module && + strstr(g_mpp_compare_context, module) != NULL && + ds4_gpu_mpp_layer_expr_matches(expr)) { + return 1; + } + continue; + } + const char *layer_expr = NULL; + if (strncmp(token, "layer=", 6) == 0) { + layer_expr = token + 6; + } else if (strncmp(token, "layer:", 6) == 0) { + layer_expr = token + 6; + } + if (layer_expr && *layer_expr) { + if (ds4_gpu_mpp_layer_expr_matches(layer_expr)) return 1; + continue; + } + if (strstr(g_mpp_compare_context, token) != NULL) return 1; } - initialized = 1; + } + return 0; +} + +static int ds4_gpu_mpp_q8_0_context_matches_filter(void) { + const int default_match = ds4_gpu_mpp_fast_profile() + ? 1 + : ds4_gpu_mpp_q8_0_late_safe_context(); + return ds4_gpu_mpp_context_matches_filter("DS4_METAL_MPP_Q8_0_FILTER", + default_match, + ds4_gpu_mpp_q8_0_late_safe_context()); +} + +static int ds4_gpu_can_use_mpp_q8_0_matmul(uint64_t n_tok) { + if (n_tok <= 8) return 0; + if (!ds4_gpu_use_mpp_q8_0_matmul()) return 0; + if (!ds4_gpu_mpp_q8_0_context_matches_filter()) return 0; + if ((n_tok % 32u) == 0 || ds4_gpu_mpp_q8_0_partial_tiles_enabled()) return 1; + + if (!g_mpp_q8_partial_skip_reported) { + fprintf(stderr, + "ds4: Metal MPP Q8_0 prefill matmul skipping partial token tiles; " + "set DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=1 to test them\n"); + g_mpp_q8_partial_skip_reported = 1; + } + return 0; +} + +static int ds4_gpu_use_mpp_f16_compressor_matmul(void) { + const int enabled = ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_q8_0_default_target(), + "DS4_METAL_MPP_F16_ENABLE", + "DS4_METAL_MPP_F16_DISABLE"); + if (enabled && !g_mpp_f16_reported) { + fprintf(stderr, "ds4: Metal MPP F16 
compressor prefill matmul enabled%s\n", + ds4_gpu_mpp_enabled_reason()); + g_mpp_f16_reported = 1; } return enabled; } static int ds4_gpu_use_mpp_attn_out_low_matmul(void) { - static int initialized; - static int enabled; - if (!initialized) { - enabled = g_metal4_tensor_api_enabled && - getenv("DS4_METAL_MPP_DISABLE") == NULL && - getenv("DS4_METAL_MPP_ATTN_OUT_DISABLE") == NULL; - if (enabled) { - fprintf(stderr, "ds4: Metal MPP attention-output low projection enabled by default\n"); - } - initialized = 1; + const int default_match = ds4_gpu_mpp_fast_profile() + ? 1 + : ds4_gpu_mpp_attn_out_late_safe_context(); + const int enabled = + ds4_gpu_mpp_route_enabled(1, + "DS4_METAL_MPP_ATTN_OUT_ENABLE", + "DS4_METAL_MPP_ATTN_OUT_DISABLE") && + ds4_gpu_mpp_context_matches_filter("DS4_METAL_MPP_ATTN_OUT_FILTER", + default_match, + ds4_gpu_mpp_attn_out_late_safe_context()); + if (enabled && !g_mpp_attn_out_reported) { + fprintf(stderr, "ds4: Metal MPP attention-output low projection enabled%s\n", + ds4_gpu_mpp_enabled_reason()); + g_mpp_attn_out_reported = 1; } return enabled; } @@ -747,54 +1336,137 @@ static int ds4_gpu_use_mpp_attn_out_low_matmul(void) { DS4_METAL_MOE_MPP_UP = 1 << 1, DS4_METAL_MOE_MPP_DOWN = 1 << 2, - DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER = 13, - DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER = 13, - DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER = 2, + DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER = 28, + DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER = 30, + DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER = 28, + DS4_METAL_MOE_MPP_FAST_GATE_LAYER = 13, + DS4_METAL_MOE_MPP_FAST_UP_LAYER = 13, + DS4_METAL_MOE_MPP_FAST_DOWN_LAYER = 2, }; static int ds4_gpu_mpp_routed_moe_default_target(void) { - return ds4_gpu_device_name_contains("M5"); + return 1; } static int ds4_gpu_mpp_routed_moe_default_policy(void) { - return g_metal4_tensor_api_enabled && - getenv("DS4_METAL_MPP_DISABLE") == NULL && - ds4_gpu_mpp_routed_moe_default_target(); + const ds4_gpu_mpp_global_policy policy = 
ds4_gpu_mpp_global_policy_mode(); + if (policy == DS4_METAL_MPP_GLOBAL_OFF) return 0; + if (policy == DS4_METAL_MPP_GLOBAL_ON) return 1; + + const int group = ds4_gpu_mpp_route_switch("DS4_METAL_MPP_MOE_ENABLE", + "DS4_METAL_MPP_MOE_DISABLE"); + if (group >= 0) return group; + + return ds4_gpu_mpp_routed_moe_default_target(); +} + +static int ds4_gpu_mpp_moe_route_enabled(const char *enable_env, const char *disable_env) { + const ds4_gpu_mpp_global_policy policy = ds4_gpu_mpp_global_policy_mode(); + if (policy == DS4_METAL_MPP_GLOBAL_OFF) return 0; + + const int group = ds4_gpu_mpp_route_switch("DS4_METAL_MPP_MOE_ENABLE", + "DS4_METAL_MPP_MOE_DISABLE"); + if (group == 0) return 0; + + const int route = ds4_gpu_mpp_route_switch(enable_env, disable_env); + if (route >= 0) return route; + + if (group == 1 || policy == DS4_METAL_MPP_GLOBAL_ON) return 1; + return ds4_gpu_mpp_routed_moe_default_target(); } static int ds4_gpu_mpp_routed_moe_stage_mask(void) { - static int initialized; - static int mask; - if (!initialized) { - if (ds4_gpu_mpp_routed_moe_default_policy()) { - mask = DS4_METAL_MOE_MPP_GATE | DS4_METAL_MOE_MPP_UP | DS4_METAL_MOE_MPP_DOWN; - } - if (mask) { - fprintf(stderr, "ds4: Metal MPP routed MoE projections enabled by default for staged prefill layers\n"); - } - initialized = 1; + int mask = 0; + if (ds4_gpu_mpp_moe_route_enabled("DS4_METAL_MPP_MOE_GATE_ENABLE", + "DS4_METAL_MPP_MOE_GATE_DISABLE")) { + mask |= DS4_METAL_MOE_MPP_GATE; + } + if (ds4_gpu_mpp_moe_route_enabled("DS4_METAL_MPP_MOE_UP_ENABLE", + "DS4_METAL_MPP_MOE_UP_DISABLE")) { + mask |= DS4_METAL_MOE_MPP_UP; + } + if (ds4_gpu_mpp_moe_route_enabled("DS4_METAL_MPP_MOE_DOWN_ENABLE", + "DS4_METAL_MPP_MOE_DOWN_DISABLE")) { + mask |= DS4_METAL_MOE_MPP_DOWN; + } + if (mask && !g_mpp_moe_reported) { + fprintf(stderr, "ds4: Metal MPP routed MoE projections enabled%s\n", + ds4_gpu_mpp_enabled_reason()); + g_mpp_moe_reported = 1; } return mask; } +static int ds4_gpu_mpp_moe_late_safe_context(int 
first_layer) { + return ds4_gpu_mpp_late_safe_context_range(first_layer); +} + +static int ds4_gpu_mpp_moe_context_matches_filter(const char *route_filter_env, + int first_layer) { + return ds4_gpu_mpp_context_matches_filter("DS4_METAL_MPP_MOE_FILTER", + 1, + ds4_gpu_mpp_moe_late_safe_context(first_layer)) && + ds4_gpu_mpp_context_matches_filter(route_filter_env, + 1, + ds4_gpu_mpp_moe_late_safe_context(first_layer)); +} + +static int ds4_gpu_mpp_moe_start_layer(const char *route_env, int fallback) { + const int common = ds4_gpu_mpp_layer_env("DS4_METAL_MPP_MOE_START_LAYER", fallback); + return ds4_gpu_mpp_layer_env(route_env, common); +} + static int ds4_gpu_mpp_routed_moe_mask_for_layer(uint32_t layer_index) { const int requested_mask = ds4_gpu_mpp_routed_moe_stage_mask(); if (!requested_mask) return 0; if (ds4_gpu_mpp_routed_moe_default_policy()) { - static int initialized; - if (!initialized) { + const int fast_profile = ds4_gpu_mpp_fast_profile(); + const int down_fallback = fast_profile ? + DS4_METAL_MOE_MPP_FAST_DOWN_LAYER : + DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER; + const int up_fallback = fast_profile ? + DS4_METAL_MOE_MPP_FAST_UP_LAYER : + DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER; + const int gate_fallback = fast_profile ? 
+ DS4_METAL_MOE_MPP_FAST_GATE_LAYER : + DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER; + const int down_start = ds4_gpu_mpp_moe_start_layer( + "DS4_METAL_MPP_MOE_DOWN_START_LAYER", + down_fallback); + const int up_start = ds4_gpu_mpp_moe_start_layer( + "DS4_METAL_MPP_MOE_UP_START_LAYER", + up_fallback); + const int gate_start = ds4_gpu_mpp_moe_start_layer( + "DS4_METAL_MPP_MOE_GATE_START_LAYER", + gate_fallback); + if (!g_mpp_moe_ranges_reported) { fprintf(stderr, "ds4: Metal MPP routed MoE default ranges down=%d..end up=%d..end gate=%d..end\n", - DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER, - DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER, - DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER); - initialized = 1; + down_start, + up_start, + gate_start); + g_mpp_moe_ranges_reported = 1; } int mask = 0; - if (layer_index >= DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER) mask |= DS4_METAL_MOE_MPP_DOWN; - if (layer_index >= DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER) mask |= DS4_METAL_MOE_MPP_UP; - if (layer_index >= DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER) mask |= DS4_METAL_MOE_MPP_GATE; + if ((int)layer_index >= down_start) mask |= DS4_METAL_MOE_MPP_DOWN; + if ((int)layer_index >= up_start) mask |= DS4_METAL_MOE_MPP_UP; + if ((int)layer_index >= gate_start) mask |= DS4_METAL_MOE_MPP_GATE; + if ((mask & DS4_METAL_MOE_MPP_DOWN) && + !ds4_gpu_mpp_moe_context_matches_filter("DS4_METAL_MPP_MOE_DOWN_FILTER", + DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER)) { + mask &= ~DS4_METAL_MOE_MPP_DOWN; + } + if ((mask & DS4_METAL_MOE_MPP_UP) && + !ds4_gpu_mpp_moe_context_matches_filter("DS4_METAL_MPP_MOE_UP_FILTER", + DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER)) { + mask &= ~DS4_METAL_MOE_MPP_UP; + } + if ((mask & DS4_METAL_MOE_MPP_GATE) && + !ds4_gpu_mpp_moe_context_matches_filter("DS4_METAL_MPP_MOE_GATE_FILTER", + DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER)) { + mask &= ~DS4_METAL_MOE_MPP_GATE; + } return mask & requested_mask; } @@ -1367,10 +2039,27 @@ void ds4_gpu_print_memory_report(const char *label) { g_metal4_tensor_api_enabled ? 
"enabled" : (g_metal4_tensor_api_compile_supported ? "available" : "disabled"), g_metal4_m5_neural_accelerators_hint ? "likely" : "not detected"); + const int mpp_q8 = ds4_gpu_mpp_q8_0_policy_enabled(); + const int mpp_f16 = ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_q8_0_default_target(), + "DS4_METAL_MPP_F16_ENABLE", + "DS4_METAL_MPP_F16_DISABLE"); + const int mpp_attn_out = ds4_gpu_mpp_route_enabled(0, + "DS4_METAL_MPP_ATTN_OUT_ENABLE", + "DS4_METAL_MPP_ATTN_OUT_DISABLE"); + const int mpp_moe = ds4_gpu_mpp_routed_moe_stage_mask(); fprintf(stderr, - "ds4: MPP Q8_0 prefill %s%s\n", - ds4_gpu_mpp_q8_0_policy_enabled() ? "enabled" : "disabled", - getenv("DS4_METAL_MPP_DISABLE") != NULL ? " (disabled by DS4_METAL_MPP_DISABLE)" : ""); + "ds4: MPP policy %s%s%s\n", + ds4_mpp_mode_name(g_mpp_mode), + g_quality_mode ? " (disabled by --quality)" : "", + !g_metal4_tensor_api_enabled ? " (tensor API unavailable)" : ""); + fprintf(stderr, + "ds4: MPP routes q8_0=%s f16_compressor=%s attn_out=%s moe_gate=%s moe_up=%s moe_down=%s\n", + mpp_q8 ? "on" : "off", + mpp_f16 ? "on" : "off", + mpp_attn_out ? "on" : "off", + (mpp_moe & DS4_METAL_MOE_MPP_GATE) ? "on" : "off", + (mpp_moe & DS4_METAL_MOE_MPP_UP) ? "on" : "off", + (mpp_moe & DS4_METAL_MOE_MPP_DOWN) ? "on" : "off"); fprintf(stderr, "ds4: scratch %.2f MiB (flash mask %.2f, pad %.2f, tmp %.2f, blk %.2f, ring %.2f, kv %.2f, compressor %.2f, router %.2f, indexer %.2f, moe %.2f, f16 %.2f, raw-store %.2f)\n", ds4_gpu_mib(scratch), @@ -1400,8 +2089,47 @@ void ds4_gpu_print_memory_report(const char *label) { ds4_gpu_mib((uint64_t)g_raw_store_round_bytes)); } +static void ds4_gpu_mpp_reset_reports(void) { + g_mpp_q8_reported = 0; + g_mpp_q8_partial_skip_reported = 0; + g_mpp_f16_reported = 0; + g_mpp_f16_pair_reported = 0; + g_mpp_attn_out_reported = 0; + g_mpp_moe_reported = 0; + g_mpp_moe_ranges_reported = 0; +} + void ds4_gpu_set_quality(bool quality) { - g_quality_mode = quality ? 1 : 0; + const int next = quality ? 
1 : 0; + if (g_quality_mode != next) { + ds4_gpu_mpp_reset_reports(); + ds4_gpu_mpp_compare_reset(); + } + g_quality_mode = next; +} + +void ds4_gpu_set_mpp_mode(ds4_mpp_mode mode) { + if (mode != DS4_MPP_AUTO && mode != DS4_MPP_ON && mode != DS4_MPP_OFF) { + mode = DS4_MPP_AUTO; + } + if (g_mpp_mode != mode) { + ds4_gpu_mpp_reset_reports(); + ds4_gpu_mpp_compare_reset(); + } + g_mpp_mode = mode; +} + +void ds4_gpu_set_mpp_compare_context(const char *module, uint32_t layer_index, uint32_t pos0) { + if (!module || !module[0]) { + g_mpp_compare_context[0] = '\0'; + return; + } + snprintf(g_mpp_compare_context, sizeof(g_mpp_compare_context), + "layer=%u pos=%u %s", layer_index, pos0, module); +} + +void ds4_gpu_clear_mpp_compare_context(void) { + g_mpp_compare_context[0] = '\0'; } static id ds4_gpu_wrap_model_range( @@ -2528,6 +3256,17 @@ static int ds4_gpu_encode_mul_mm_id_mapped( NSUInteger src1_off, id dst, NSUInteger dst_off); +static int ds4_gpu_encode_mul_mm_id_mapped_tile( + id cb, + id mm_pipeline, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0, + NSUInteger src0_off, + id src1, + NSUInteger src1_off, + id dst, + NSUInteger dst_off, + uint32_t tile_n); typedef struct { int32_t ne11; @@ -4245,6 +4984,7 @@ int ds4_gpu_synchronize(void) { if (g_batch_cb) return ds4_gpu_end_commands(); if ([g_pending_cbs count] != 0) { int ok = ds4_gpu_wait_pending_command_buffers("synchronize"); + if (ok) ds4_gpu_mpp_compare_drain("synchronize"); [g_transient_buffers removeAllObjects]; return ok; } @@ -4399,6 +5139,8 @@ void ds4_gpu_cleanup(void) { g_queue = nil; g_device = nil; g_initialized = 0; + ds4_gpu_mpp_reset_reports(); + ds4_gpu_mpp_compare_reset(); } } @@ -5220,7 +5962,7 @@ int ds4_gpu_dsv4_topk_mask_tensor( return 1; } -int ds4_gpu_matmul_q8_0_tensor( +static int ds4_gpu_matmul_q8_0_legacy_tensor( ds4_gpu_tensor *out, const void *model_map, uint64_t model_size, @@ -5235,14 +5977,6 @@ int ds4_gpu_matmul_q8_0_tensor( return 0; } - if (n_tok > 8 && 
ds4_gpu_use_mpp_q8_0_matmul()) { - if (ds4_gpu_matmul_q8_0_mpp_tensor(out, model_map, model_size, weight_offset, - in_dim, out_dim, x, n_tok)) { - return 1; - } - ds4_gpu_warn_mpp_fallback(); - } - @autoreleasepool { id xbuf = ds4_gpu_tensor_buffer(x); id outbuf = ds4_gpu_tensor_buffer(out); @@ -5362,6 +6096,82 @@ int ds4_gpu_matmul_q8_0_tensor( return 1; } +static void ds4_gpu_mpp_compare_q8_0_matmul( + ds4_gpu_tensor *out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok) { + if (!ds4_gpu_mpp_compare_route_matches("q8")) return; + const uint64_t out_bytes = n_tok * out_dim * sizeof(float); + ds4_gpu_tensor *ref = ds4_gpu_tensor_alloc(out_bytes); + ds4_gpu_tensor *cand = ds4_gpu_mpp_compare_snapshot_buffer(ds4_gpu_tensor_buffer(out), + ds4_gpu_tensor_offset(out), + out_bytes); + if (!ref || !cand) { + ds4_gpu_tensor_free(ref); + ds4_gpu_tensor_free(cand); + return; + } + + if (ds4_gpu_matmul_q8_0_legacy_tensor(ref, model_map, model_size, + weight_offset, in_dim, out_dim, + x, n_tok)) { + char fallback[128]; + snprintf(fallback, sizeof(fallback), + "q8 weight_off=%llu in=%llu out=%llu tok=%llu", + (unsigned long long)weight_offset, + (unsigned long long)in_dim, + (unsigned long long)out_dim, + (unsigned long long)n_tok); + ds4_gpu_mpp_compare_register("q8", + fallback, + ref, + cand, + n_tok * out_dim, + n_tok, + out_dim, + in_dim); + if (!g_batch_cb) ds4_gpu_mpp_compare_drain("q8 compare"); + } + ds4_gpu_tensor_free(cand); + ds4_gpu_tensor_free(ref); +} + +int ds4_gpu_matmul_q8_0_tensor( + ds4_gpu_tensor *out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if ((in_dim & 31u) != 0 || + in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok > UINT32_MAX) { + return 0; + } + + if 
(ds4_gpu_can_use_mpp_q8_0_matmul(n_tok)) { + if (ds4_gpu_matmul_q8_0_mpp_tensor(out, model_map, model_size, weight_offset, + in_dim, out_dim, x, n_tok)) { + ds4_gpu_mpp_compare_q8_0_matmul(out, model_map, model_size, + weight_offset, in_dim, out_dim, + x, n_tok); + return 1; + } + ds4_gpu_warn_mpp_fallback(); + } + + return ds4_gpu_matmul_q8_0_legacy_tensor(out, model_map, model_size, + weight_offset, in_dim, out_dim, + x, n_tok); +} + int ds4_gpu_matmul_q8_0_mpp_tensor( ds4_gpu_tensor *out, const void *model_map, @@ -5402,10 +6212,21 @@ int ds4_gpu_matmul_q8_0_mpp_tensor( id wbuf = ds4_gpu_wrap_model_range(model_map, model_size, weight_offset, weight_bytes, &inner_offset); if (!wbuf) return 0; + const uint32_t tile_n = ds4_gpu_mpp_q8_0_tile_n(); + const bool direct_rhs = + (tile_n == 32u || tile_n == 64u) && + ds4_gpu_mpp_q8_0_direct_rhs(); const bool bc_inp = (in_dim % 32u) != 0; - const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0; + const bool bc_out = (out_dim % 64u) != 0 || (n_tok % tile_n) != 0; + const char *pipeline_name = direct_rhs ? + (tile_n == 64u ? + "kernel_mul_mm_q8_0_f32_mpp_direct_rhs_n64" : + "kernel_mul_mm_q8_0_f32_mpp_direct_rhs") : + (tile_n == 64u ? + "kernel_mul_mm_q8_0_f32_mpp_n64" : + "kernel_mul_mm_q8_0_f32_mpp"); id pipeline = - ds4_gpu_get_mul_mm_pipeline("kernel_mul_mm_q8_0_f32_mpp", bc_inp, bc_out); + ds4_gpu_get_mul_mm_pipeline(pipeline_name, bc_inp, bc_out); if (!pipeline) return 0; int owned = 0; @@ -5420,8 +6241,8 @@ int ds4_gpu_matmul_q8_0_mpp_tensor( [enc setBuffer:wbuf offset:(NSUInteger)inner_offset atIndex:1]; [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:2]; [enc setBuffer:outbuf offset:ds4_gpu_tensor_offset(out) atIndex:3]; - [enc setThreadgroupMemoryLength:4096u atIndex:0]; - [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u, + [enc setThreadgroupMemoryLength:(direct_rhs ? 4096u : (tile_n == 64 ? 
8192u : 6144u)) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + (NSUInteger)tile_n - 1u) / (NSUInteger)tile_n, ((NSUInteger)out_dim + 63u) / 64u, 1) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; @@ -5624,11 +6445,20 @@ int ds4_gpu_matmul_f16_tensor( const bool bc_inp = (in_dim % 32u) != 0; const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0; - /* Keep MPP F16 limited to the exact-safe ratio-2 compressor shape. */ - if (in_dim == 4096u && out_dim == 128u && !bc_inp && + const bool mpp_f16_shape = + in_dim == 4096u && !bc_inp && + (out_dim == 128u || + (ds4_gpu_mpp_f16_wide_matmul() && (out_dim % 64u) == 0)); + /* Keep wider compressor MPP opt-in until full-model drift and speed are measured. */ + if (mpp_f16_shape && ds4_gpu_use_mpp_f16_compressor_matmul()) { + const bool direct_rhs = ds4_gpu_mpp_f16_direct_rhs(); id pipeline = - ds4_gpu_get_mul_mm_pipeline("kernel_mul_mm_f16_f32_mpp", false, bc_out); + ds4_gpu_get_mul_mm_pipeline(direct_rhs ? + "kernel_mul_mm_f16_f32_mpp_direct_rhs" : + "kernel_mul_mm_f16_f32_mpp", + false, + bc_out); if (pipeline) { ds4_gpu_mul_mm_args args = ds4_gpu_make_mm_args(in_dim, out_dim, n_tok, row_bytes); @@ -5638,7 +6468,7 @@ int ds4_gpu_matmul_f16_tensor( [enc setBuffer:wbuf offset:(NSUInteger)inner_offset atIndex:1]; [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:2]; [enc setBuffer:outbuf offset:ds4_gpu_tensor_offset(out) atIndex:3]; - [enc setThreadgroupMemoryLength:4096u atIndex:0]; + [enc setThreadgroupMemoryLength:(direct_rhs ? 
4096u : 6144u) atIndex:0]; [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u, ((NSUInteger)out_dim + 63u) / 64u, 1) @@ -5687,12 +6517,93 @@ int ds4_gpu_matmul_f16_pair_tensor( const ds4_gpu_tensor *x, uint64_t n_tok) { if (!g_initialized && !ds4_gpu_init()) return 0; - if (in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok != 1 || (in_dim & 3u) != 0) return 0; + if (in_dim > UINT32_MAX || out_dim > UINT32_MAX || n_tok == 0 || (in_dim & 3u) != 0) return 0; @autoreleasepool { id xbuf = ds4_gpu_tensor_buffer(x); id outabuf = ds4_gpu_tensor_buffer(out_a); id outbbuf = ds4_gpu_tensor_buffer(out_b); + if (n_tok != 1) { + const bool use_wide_mpp_pair = ds4_gpu_mpp_f16_wide_matmul(); + const bool pair_shape = + in_dim == 4096u && (out_dim % 64u) == 0; + if (n_tok <= 8 || + !pair_shape || + !ds4_gpu_mpp_f16_pair_matmul() || + !ds4_gpu_use_mpp_f16_compressor_matmul()) { + return 0; + } + + const uint64_t x_bytes = n_tok * in_dim * sizeof(float); + const uint64_t out_bytes = n_tok * out_dim * sizeof(float); + if (!xbuf || !outabuf || !outbbuf || + ds4_gpu_tensor_bytes(x) < x_bytes || + ds4_gpu_tensor_bytes(out_a) < out_bytes || + ds4_gpu_tensor_bytes(out_b) < out_bytes) { + fprintf(stderr, "ds4: Metal F16 paired MPP matmul received undersized activation buffers\n"); + return 0; + } + + const uint64_t row_bytes = in_dim * sizeof(uint16_t); + const uint64_t weight_bytes = row_bytes * out_dim; + if (weight_a_offset > model_size || weight_bytes > model_size - weight_a_offset || + weight_b_offset > model_size || weight_bytes > model_size - weight_b_offset) { + fprintf(stderr, "ds4: Metal F16 paired MPP matmul range is outside the mapped model\n"); + return 0; + } + + uint64_t inner_a = 0; + uint64_t inner_b = 0; + id wabuf = ds4_gpu_wrap_model_range(model_map, model_size, + weight_a_offset, weight_bytes, + &inner_a); + id wbbuf = ds4_gpu_wrap_model_range(model_map, model_size, + weight_b_offset, weight_bytes, + &inner_b); + if (!wabuf || !wbbuf) return 0; + + 
const bool bc_out = (out_dim % 64u) != 0 || (n_tok % 32u) != 0; + id pipeline = + ds4_gpu_get_mul_mm_pipeline(use_wide_mpp_pair ? + "kernel_mul_mm_f16_f32_pair_mpp" : + "kernel_mul_mm_f16_f32_pair", + false, + bc_out); + if (!pipeline) return 0; + if (!g_mpp_f16_pair_reported) { + fprintf(stderr, "ds4: Metal paired F16 compressor matmul enabled%s\n", + use_wide_mpp_pair ? " with MPP wide route" : ""); + g_mpp_f16_pair_reported = 1; + } + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + ds4_gpu_mul_mm_args args = ds4_gpu_make_mm_args(in_dim, out_dim, n_tok, row_bytes); + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:wabuf offset:(NSUInteger)inner_a atIndex:1]; + [enc setBuffer:wbbuf offset:(NSUInteger)inner_b atIndex:2]; + [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:3]; + [enc setBuffer:outabuf offset:ds4_gpu_tensor_offset(out_a) atIndex:4]; + [enc setBuffer:outbbuf offset:ds4_gpu_tensor_offset(out_b) atIndex:5]; + const NSUInteger smem = use_wide_mpp_pair ? 
+ (NSUInteger)((64u * 32u * 2u + 32u * 32u) * sizeof(uint16_t)) : + (NSUInteger)12288u; + [enc setThreadgroupMemoryLength:smem atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)n_tok + 31u) / 32u, + ((NSUInteger)out_dim + 63u) / 64u, + 1) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal F16 paired matmul")) return 0; + return 1; + } + const uint64_t x_bytes = in_dim * sizeof(float); const uint64_t out_bytes = out_dim * sizeof(float); if (!xbuf || !outabuf || !outbbuf || @@ -8358,6 +9269,73 @@ static int ds4_gpu_encode_fill_f32_rows( return 1; } +static void ds4_gpu_mpp_compare_attn_out_low( + id cb, + const ds4_gpu_mul_mm_id_args *mm_args, + id out_a_buf, + NSUInteger out_a_inner, + const ds4_gpu_tensor *heads, + ds4_gpu_tensor *low, + uint32_t group_dim, + uint32_t rank, + uint32_t n_groups, + uint32_t n_tokens) { + if (!ds4_gpu_mpp_compare_route_matches("attn_out")) return; + const NSUInteger ids_bytes = (NSUInteger)n_tokens * (NSUInteger)n_groups * sizeof(int32_t); + id ids_buffer = ds4_gpu_new_transient_buffer(ids_bytes, "attention output compare group ids"); + ds4_gpu_tensor *ref = ds4_gpu_tensor_alloc((uint64_t)n_tokens * n_groups * rank * sizeof(float)); + ds4_gpu_tensor *cand = ds4_gpu_mpp_compare_snapshot_buffer(ds4_gpu_tensor_buffer(low), + ds4_gpu_tensor_offset(low), + (uint64_t)n_tokens * n_groups * rank * sizeof(float)); + if (!ids_buffer || !ref || !cand) { + ds4_gpu_tensor_free(ref); + ds4_gpu_tensor_free(cand); + return; + } + int32_t *ids = (int32_t *)[ids_buffer contents]; + for (uint32_t t = 0; t < n_tokens; t++) { + for (uint32_t group = 0; group < n_groups; group++) { + ids[(uint64_t)t * n_groups + group] = (int32_t)group; + } + } + + ds4_gpu_mul_mm_id_map_args map_args = + ds4_gpu_make_mul_mm_id_map_args(group_dim, + n_groups, + n_groups, + n_groups, + n_tokens); + id map_pipeline = + 
ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_groups)); + id legacy_pipeline = + ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q8_0_f32", false, false); + if (map_pipeline && legacy_pipeline && + ds4_gpu_encode_mul_mm_id(cb, + map_pipeline, + legacy_pipeline, + &map_args, + mm_args, + out_a_buf, + out_a_inner, + ds4_gpu_tensor_buffer(heads), + ds4_gpu_tensor_offset(heads), + ds4_gpu_tensor_buffer(ref), + ds4_gpu_tensor_offset(ref), + ids_buffer, + 0)) { + ds4_gpu_mpp_compare_register("attn_out", + "attn_out_low", + ref, + cand, + (uint64_t)n_tokens * n_groups * rank, + n_tokens, + (uint64_t)n_groups * rank, + group_dim); + } + ds4_gpu_tensor_free(cand); + ds4_gpu_tensor_free(ref); +} + int ds4_gpu_attention_output_q8_batch_tensor( ds4_gpu_tensor *out, ds4_gpu_tensor *low, @@ -8497,8 +9475,21 @@ int ds4_gpu_attention_output_q8_batch_tensor( n_groups, n_groups, n_tokens); + const uint32_t attn_out_tile_n = ds4_gpu_mpp_attn_out_tile_n(); + const bool attn_out_direct_rhs = + (attn_out_tile_n == 32u || attn_out_tile_n == 64u) && + ds4_gpu_mpp_attn_out_direct_rhs(); + const char *attn_out_pipeline_name = attn_out_direct_rhs ? + (attn_out_tile_n == 64u ? + "kernel_attn_out_low_q8_0_mpp_direct_rhs_n64" : + "kernel_attn_out_low_q8_0_mpp_direct_rhs") : + (attn_out_tile_n == 64u ? 
+ "kernel_attn_out_low_q8_0_mpp_n64" : + "kernel_attn_out_low_q8_0_mpp"); id mm_pipeline = - ds4_gpu_get_mul_mm_id_pipeline("kernel_attn_out_low_q8_0_mpp", false, false); + ds4_gpu_get_mul_mm_id_pipeline(attn_out_pipeline_name, + false, + false); ok = ds4_gpu_encode_attn_out_low_q8_mpp(cb, mm_pipeline, &mm_args, @@ -8508,6 +9499,18 @@ int ds4_gpu_attention_output_q8_batch_tensor( ds4_gpu_tensor_offset(heads), ds4_gpu_tensor_buffer(low), ds4_gpu_tensor_offset(low)) != 0; + if (ok) { + ds4_gpu_mpp_compare_attn_out_low(cb, + &mm_args, + out_a_buf, + (NSUInteger)out_a_inner, + heads, + low, + (uint32_t)group_dim, + (uint32_t)rank, + n_groups, + n_tokens); + } if (!ok) { ds4_gpu_warn_mpp_fallback(); if (ds4_gpu_mul_mm_id_map0_name(n_groups) != NULL) { @@ -12071,31 +13074,139 @@ static NSUInteger ds4_gpu_routed_mv_smem(uint32_t type) { } static id ds4_gpu_routed_mm_pipeline(uint32_t type, bool use_mpp) { + const bool tile_n64 = use_mpp && ds4_gpu_mpp_moe_tile_n() == 64; + const bool fast_layout = use_mpp && !tile_n64 && ds4_gpu_mpp_moe_fast_layout(); switch (type) { case DS4_METAL_TENSOR_IQ2_XXS: - return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32", false, use_mpp); + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_iq2_xxs_f32_fast_mpp" : + tile_n64 ? + "kernel_mul_mm_id_iq2_xxs_f32_n64" : + "kernel_mul_mm_id_iq2_xxs_f32", + false, + use_mpp); case DS4_METAL_TENSOR_Q2_K: - return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32", false, use_mpp); + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_q2_K_f32_fast_mpp" : + tile_n64 ? + "kernel_mul_mm_id_q2_K_f32_n64" : + "kernel_mul_mm_id_q2_K_f32", + false, + use_mpp); case DS4_METAL_TENSOR_Q4_K: - return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32", false, use_mpp); + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_q4_K_f32_fast_mpp" : + tile_n64 ? 
+ "kernel_mul_mm_id_q4_K_f32_n64" : + "kernel_mul_mm_id_q4_K_f32", + false, + use_mpp); + default: + return nil; + } +} + +static id ds4_gpu_routed_mm_pair_mpp_pipeline(uint32_t type) { + switch (type) { + case DS4_METAL_TENSOR_IQ2_XXS: + return ds4_gpu_get_pipeline("kernel_mul_mm_id_iq2_xxs_f32_pair_mpp"); + case DS4_METAL_TENSOR_Q2_K: + return ds4_gpu_get_pipeline("kernel_mul_mm_id_q2_K_f32_pair_mpp"); + case DS4_METAL_TENSOR_Q4_K: + return ds4_gpu_get_pipeline("kernel_mul_mm_id_q4_K_f32_pair_mpp"); default: return nil; } } static id ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type, bool use_mpp) { + const bool tile_n64 = use_mpp && ds4_gpu_mpp_moe_tile_n() == 64; + const bool fast_layout = use_mpp && !tile_n64 && ds4_gpu_mpp_moe_fast_layout(); switch (type) { case DS4_METAL_TENSOR_IQ2_XXS: - return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16", false, use_mpp); + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_iq2_xxs_f16_fast_mpp" : + tile_n64 ? + "kernel_mul_mm_id_iq2_xxs_f16_n64" : + "kernel_mul_mm_id_iq2_xxs_f16", + false, + use_mpp); case DS4_METAL_TENSOR_Q2_K: - return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16", false, use_mpp); + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_q2_K_f16_fast_mpp" : + tile_n64 ? + "kernel_mul_mm_id_q2_K_f16_n64" : + "kernel_mul_mm_id_q2_K_f16", + false, + use_mpp); case DS4_METAL_TENSOR_Q4_K: - return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16", false, use_mpp); + return ds4_gpu_get_mul_mm_id_pipeline(fast_layout ? + "kernel_mul_mm_id_q4_K_f16_fast_mpp" : + tile_n64 ? 
+ "kernel_mul_mm_id_q4_K_f16_n64" : + "kernel_mul_mm_id_q4_K_f16", + false, + use_mpp); default: return nil; } } +static void ds4_gpu_mpp_compare_moe_mm( + const char *route, + const char *stage, + uint32_t type, + bool f16_rhs, + id cb, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0, + NSUInteger src0_off, + id src1, + NSUInteger src1_off, + id cand, + NSUInteger cand_off, + uint64_t elements, + uint64_t dim0, + uint64_t dim1, + uint64_t dim2) { + if (!ds4_gpu_mpp_compare_route_matches(route)) return; + if (elements == 0) return; + ds4_gpu_tensor *ref = ds4_gpu_tensor_alloc(elements * sizeof(float)); + ds4_gpu_tensor *cand_snapshot = ds4_gpu_mpp_compare_snapshot_buffer(cand, + cand_off, + elements * sizeof(float)); + if (!ref || !cand_snapshot) { + ds4_gpu_tensor_free(ref); + ds4_gpu_tensor_free(cand_snapshot); + return; + } + + id legacy_pipeline = f16_rhs ? + ds4_gpu_routed_mm_f16_rhs_pipeline(type, false) : + ds4_gpu_routed_mm_pipeline(type, false); + if (legacy_pipeline && + ds4_gpu_encode_mul_mm_id_mapped(cb, + legacy_pipeline, + mm_args, + src0, + src0_off, + src1, + src1_off, + ds4_gpu_tensor_buffer(ref), + ds4_gpu_tensor_offset(ref))) { + ds4_gpu_mpp_compare_register(route, + stage, + ref, + cand_snapshot, + elements, + dim0, + dim1, + dim2); + } + ds4_gpu_tensor_free(cand_snapshot); + ds4_gpu_tensor_free(ref); +} + static int ds4_gpu_encode_mul_mv_id( id cb, id pipeline, @@ -12387,7 +13498,7 @@ static int ds4_gpu_encode_mul_mm_id_map( return 1; } -static int ds4_gpu_encode_mul_mm_id_mapped( +static int ds4_gpu_encode_mul_mm_id_mapped_tile( id cb, id mm_pipeline, const ds4_gpu_mul_mm_id_args *mm_args, @@ -12396,13 +13507,15 @@ static int ds4_gpu_encode_mul_mm_id_mapped( id src1, NSUInteger src1_off, id dst, - NSUInteger dst_off) { + NSUInteger dst_off, + uint32_t tile_n) { if (!cb || !mm_pipeline || !mm_args || !src0 || !src1 || !dst || !g_moe_id_map_buffer || mm_args->ne00 <= 0 || mm_args->ne0 <= 0 || mm_args->ne20 <= 0 || mm_args->ne21 <= 0 || 
mm_args->ne02 <= 0) { return 0; } + if (tile_n != 64u) tile_n = 32u; const NSUInteger tpe_bytes = (NSUInteger)mm_args->ne02 * sizeof(int32_t); const NSUInteger hids_bytes = (NSUInteger)mm_args->ne02 * (NSUInteger)mm_args->ne21 * sizeof(int32_t); @@ -12419,6 +13532,53 @@ static int ds4_gpu_encode_mul_mm_id_mapped( [enc setBuffer:g_moe_id_map_buffer offset:0 atIndex:3]; [enc setBuffer:g_moe_id_map_buffer offset:tpe_bytes atIndex:4]; [enc setBuffer:dst offset:dst_off atIndex:5]; + [enc setThreadgroupMemoryLength:(tile_n == 64u ? 16384u : 8192u) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + (NSUInteger)tile_n - 1u) / (NSUInteger)tile_n, + ((NSUInteger)mm_args->ne0 + 63u) / 64u, + (NSUInteger)mm_args->ne02) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + return 1; +} + +static int ds4_gpu_encode_mul_mm_id_pair_mpp( + id cb, + id pipeline, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0_gate, + NSUInteger src0_gate_off, + id src0_up, + NSUInteger src0_up_off, + id src1, + NSUInteger src1_off, + id dst_gate, + NSUInteger dst_gate_off, + id dst_up, + NSUInteger dst_up_off) { + if (!cb || !pipeline || !mm_args || !src0_gate || !src0_up || !src1 || + !dst_gate || !dst_up || !g_moe_id_map_buffer || + mm_args->ne00 <= 0 || mm_args->ne0 <= 0 || + mm_args->ne20 <= 0 || mm_args->ne21 <= 0 || mm_args->ne02 <= 0) { + return 0; + } + + const NSUInteger tpe_bytes = (NSUInteger)mm_args->ne02 * sizeof(int32_t); + const NSUInteger hids_bytes = (NSUInteger)mm_args->ne02 * (NSUInteger)mm_args->ne21 * sizeof(int32_t); + if (tpe_bytes > NSUIntegerMax - hids_bytes || + g_moe_id_map_bytes < tpe_bytes + hids_bytes) { + return 0; + } + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:mm_args length:sizeof(*mm_args) atIndex:0]; + [enc setBuffer:src0_gate offset:src0_gate_off atIndex:1]; + [enc setBuffer:src0_up offset:src0_up_off atIndex:2]; + [enc setBuffer:src1 
offset:src1_off atIndex:3]; + [enc setBuffer:g_moe_id_map_buffer offset:0 atIndex:4]; + [enc setBuffer:g_moe_id_map_buffer offset:tpe_bytes atIndex:5]; + [enc setBuffer:dst_gate offset:dst_gate_off atIndex:6]; + [enc setBuffer:dst_up offset:dst_up_off atIndex:7]; [enc setThreadgroupMemoryLength:8192u atIndex:0]; [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + 31u) / 32u, ((NSUInteger)mm_args->ne0 + 63u) / 64u, @@ -12428,6 +13588,28 @@ static int ds4_gpu_encode_mul_mm_id_mapped( return 1; } +static int ds4_gpu_encode_mul_mm_id_mapped( + id cb, + id mm_pipeline, + const ds4_gpu_mul_mm_id_args *mm_args, + id src0, + NSUInteger src0_off, + id src1, + NSUInteger src1_off, + id dst, + NSUInteger dst_off) { + return ds4_gpu_encode_mul_mm_id_mapped_tile(cb, + mm_pipeline, + mm_args, + src0, + src0_off, + src1, + src1_off, + dst, + dst_off, + 32u); +} + static int ds4_gpu_encode_attn_out_low_q8_mpp( id cb, id pipeline, @@ -12444,14 +13626,19 @@ static int ds4_gpu_encode_attn_out_low_q8_mpp( return 0; } + const uint32_t tile_n = ds4_gpu_mpp_attn_out_tile_n(); + const bool direct_rhs = + (tile_n == 32u || tile_n == 64u) && + ds4_gpu_mpp_attn_out_direct_rhs(); + id enc = ds4_gpu_compute_encoder(cb); [enc setComputePipelineState:pipeline]; [enc setBytes:mm_args length:sizeof(*mm_args) atIndex:0]; [enc setBuffer:src0 offset:src0_off atIndex:1]; [enc setBuffer:src1 offset:src1_off atIndex:2]; [enc setBuffer:dst offset:dst_off atIndex:3]; - [enc setThreadgroupMemoryLength:4096u atIndex:0]; - [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + 31u) / 32u, + [enc setThreadgroupMemoryLength:(direct_rhs ? 4096u : (tile_n == 64 ? 
8192u : 6144u)) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + (NSUInteger)tile_n - 1u) / (NSUInteger)tile_n, ((NSUInteger)mm_args->ne0 + 63u) / 64u, (NSUInteger)mm_args->ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; @@ -13679,6 +14866,7 @@ int ds4_gpu_routed_moe_batch_tensor( id down_mv_pipeline = ds4_gpu_routed_mv_pipeline(down_type); id gate_mm_pipeline = nil; id up_mm_pipeline = nil; + id gate_up_pair_mm_pipeline = nil; id down_mm_pipeline = nil; if (gate_nr0 == 0 || down_nr0 == 0 || !gate_mv_pipeline || !down_mv_pipeline) { fprintf(stderr, "ds4: unsupported Metal routed batch MoE quant types gate=%u down=%u\n", @@ -13725,6 +14913,19 @@ int ds4_gpu_routed_moe_batch_tensor( */ const bool request_mid_f16 = !g_quality_mode && getenv("DS4_METAL_MOE_MID_F32") == NULL; + const uint32_t moe_mpp_tile_n = ds4_gpu_mpp_moe_tile_n(); + const uint32_t gate_mm_tile_n = + (moe_mpp_mask & DS4_METAL_MOE_MPP_GATE) != 0 ? moe_mpp_tile_n : 32u; + const uint32_t up_mm_tile_n = + (moe_mpp_mask & DS4_METAL_MOE_MPP_UP) != 0 ? moe_mpp_tile_n : 32u; + const uint32_t down_mm_tile_n = + (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0 ? moe_mpp_tile_n : 32u; + const bool use_gate_up_pair_mpp = + ds4_gpu_mpp_moe_pair_gate_up() && + (moe_mpp_mask & (DS4_METAL_MOE_MPP_GATE | DS4_METAL_MOE_MPP_UP)) == + (DS4_METAL_MOE_MPP_GATE | DS4_METAL_MOE_MPP_UP) && + gate_mm_tile_n == 32u && + up_mm_tile_n == 32u; if (use_mm_id) { gate_map_args = ds4_gpu_make_mul_mm_id_map_args(expert_in_dim, 256, 1, n_expert, n_tokens); @@ -13739,16 +14940,22 @@ int ds4_gpu_routed_moe_batch_tensor( request_mid_f16 ? 
sizeof(uint16_t) : sizeof(float)); map_pipeline = ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_expert)); - gate_mm_pipeline = ds4_gpu_routed_mm_pipeline( - gate_type, - (moe_mpp_mask & DS4_METAL_MOE_MPP_GATE) != 0); - up_mm_pipeline = ds4_gpu_routed_mm_pipeline( - gate_type, - (moe_mpp_mask & DS4_METAL_MOE_MPP_UP) != 0); + if (use_gate_up_pair_mpp) { + gate_up_pair_mm_pipeline = ds4_gpu_routed_mm_pair_mpp_pipeline(gate_type); + } else { + gate_mm_pipeline = ds4_gpu_routed_mm_pipeline( + gate_type, + (moe_mpp_mask & DS4_METAL_MOE_MPP_GATE) != 0); + up_mm_pipeline = ds4_gpu_routed_mm_pipeline( + gate_type, + (moe_mpp_mask & DS4_METAL_MOE_MPP_UP) != 0); + } down_mm_pipeline = request_mid_f16 ? ds4_gpu_routed_mm_f16_rhs_pipeline(down_type, (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0) : ds4_gpu_routed_mm_pipeline(down_type, (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0); - if (!map_pipeline || !gate_mm_pipeline || !up_mm_pipeline || !down_mm_pipeline) { + if (!map_pipeline || + (use_gate_up_pair_mpp ? 
!gate_up_pair_mm_pipeline : (!gate_mm_pipeline || !up_mm_pipeline)) || + !down_mm_pipeline) { return 0; } } @@ -13815,8 +15022,57 @@ int ds4_gpu_routed_moe_batch_tensor( selectedbuf, ds4_gpu_tensor_offset(selected)); DS4_METAL_PROFILE_MOE_STAGE("map"); - if (ok) { - ok = ds4_gpu_encode_mul_mm_id_mapped(cb, + if (ok && use_gate_up_pair_mpp) { + ok = ds4_gpu_encode_mul_mm_id_pair_mpp(cb, + gate_up_pair_mm_pipeline, + &gate_mm_args, + gate_buf, + (NSUInteger)gate_inner, + up_buf, + (NSUInteger)up_inner, + xbuf, + ds4_gpu_tensor_offset(x), + gatebuf, + ds4_gpu_tensor_offset(gate), + upbuf, + ds4_gpu_tensor_offset(up)); + if (ok) { + ds4_gpu_mpp_compare_moe_mm("moe_gate", + "moe_gate", + gate_type, + false, + cb, + &gate_mm_args, + gate_buf, + (NSUInteger)gate_inner, + xbuf, + ds4_gpu_tensor_offset(x), + gatebuf, + ds4_gpu_tensor_offset(gate), + (uint64_t)pair_rows * expert_mid_dim, + n_tokens, + (uint64_t)n_expert * expert_mid_dim, + expert_in_dim); + ds4_gpu_mpp_compare_moe_mm("moe_up", + "moe_up", + gate_type, + false, + cb, + &gate_mm_args, + up_buf, + (NSUInteger)up_inner, + xbuf, + ds4_gpu_tensor_offset(x), + upbuf, + ds4_gpu_tensor_offset(up), + (uint64_t)pair_rows * expert_mid_dim, + n_tokens, + (uint64_t)n_expert * expert_mid_dim, + expert_in_dim); + } + DS4_METAL_PROFILE_MOE_STAGE("gate_up_pair"); + } else if (ok) { + ok = ds4_gpu_encode_mul_mm_id_mapped_tile(cb, gate_mm_pipeline, &gate_mm_args, gate_buf, @@ -13824,11 +15080,30 @@ int ds4_gpu_routed_moe_batch_tensor( xbuf, ds4_gpu_tensor_offset(x), gatebuf, - ds4_gpu_tensor_offset(gate)); + ds4_gpu_tensor_offset(gate), + gate_mm_tile_n); + if (ok && (moe_mpp_mask & DS4_METAL_MOE_MPP_GATE) != 0) { + ds4_gpu_mpp_compare_moe_mm("moe_gate", + "moe_gate", + gate_type, + false, + cb, + &gate_mm_args, + gate_buf, + (NSUInteger)gate_inner, + xbuf, + ds4_gpu_tensor_offset(x), + gatebuf, + ds4_gpu_tensor_offset(gate), + (uint64_t)pair_rows * expert_mid_dim, + n_tokens, + (uint64_t)n_expert * expert_mid_dim, + 
expert_in_dim); + } DS4_METAL_PROFILE_MOE_STAGE("gate"); } - if (ok) { - ok = ds4_gpu_encode_mul_mm_id_mapped(cb, + if (ok && !use_gate_up_pair_mpp) { + ok = ds4_gpu_encode_mul_mm_id_mapped_tile(cb, up_mm_pipeline, &gate_mm_args, up_buf, @@ -13836,7 +15111,26 @@ int ds4_gpu_routed_moe_batch_tensor( xbuf, ds4_gpu_tensor_offset(x), upbuf, - ds4_gpu_tensor_offset(up)); + ds4_gpu_tensor_offset(up), + up_mm_tile_n); + if (ok && (moe_mpp_mask & DS4_METAL_MOE_MPP_UP) != 0) { + ds4_gpu_mpp_compare_moe_mm("moe_up", + "moe_up", + gate_type, + false, + cb, + &gate_mm_args, + up_buf, + (NSUInteger)up_inner, + xbuf, + ds4_gpu_tensor_offset(x), + upbuf, + ds4_gpu_tensor_offset(up), + (uint64_t)pair_rows * expert_mid_dim, + n_tokens, + (uint64_t)n_expert * expert_mid_dim, + expert_in_dim); + } DS4_METAL_PROFILE_MOE_STAGE("up"); } } else if (use_tiny_pair_mv) { @@ -14008,7 +15302,7 @@ int ds4_gpu_routed_moe_batch_tensor( down_smem, 2); } else if (use_mm_id) { - ok = ds4_gpu_encode_mul_mm_id_mapped(cb, + ok = ds4_gpu_encode_mul_mm_id_mapped_tile(cb, down_mm_pipeline, &down_mm_args, down_buf, @@ -14016,7 +15310,26 @@ int ds4_gpu_routed_moe_batch_tensor( midbuf, ds4_gpu_tensor_offset(mid), down_dst, - down_dst_off); + down_dst_off, + down_mm_tile_n); + if (ok && (moe_mpp_mask & DS4_METAL_MOE_MPP_DOWN) != 0) { + ds4_gpu_mpp_compare_moe_mm("moe_down", + "moe_down", + down_type, + request_mid_f16, + cb, + &down_mm_args, + down_buf, + (NSUInteger)down_inner, + midbuf, + ds4_gpu_tensor_offset(mid), + down_dst, + down_dst_off, + (uint64_t)pair_rows * out_dim, + n_tokens, + (uint64_t)n_expert * out_dim, + expert_mid_dim); + } } else { ok = ds4_gpu_encode_mul_mv_id(cb, down_mv_pipeline, diff --git a/ds4_server.c b/ds4_server.c index bc8abbbd..8fcdd627 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -7840,6 +7840,15 @@ static float parse_float_arg(const char *s, const char *opt, float minv, float m return v; } +static ds4_mpp_mode parse_mpp_mode_arg(const char *s) { + if (!strcmp(s, "auto")) 
return DS4_MPP_AUTO; + if (!strcmp(s, "on")) return DS4_MPP_ON; + if (!strcmp(s, "off")) return DS4_MPP_OFF; + server_log(DS4_LOG_DEFAULT, "ds4-server: invalid MPP mode: %s", s); + server_log(DS4_LOG_DEFAULT, "ds4-server: valid MPP modes are: auto, on, off"); + exit(2); +} + static const char *need_arg(int *i, int argc, char **argv, const char *opt) { if (*i + 1 >= argc) { server_log(DS4_LOG_DEFAULT, "ds4-server: missing value for %s", opt); @@ -7897,7 +7906,9 @@ static void usage(FILE *fp) { " -t, --threads N\n" " CPU helper threads for lightweight host-side work.\n" " --quality\n" - " Prefer exact kernels where faster approximate paths exist; MTP uses strict verification.\n" + " Prefer exact kernels where faster approximate paths exist; disables Metal 4 MPP routes; MTP uses strict verification.\n" + " --mpp MODE\n" + " Metal 4 MPP policy: auto, on, or off. Default: auto. Auto enables validated safe routes; 'on' is a route diagnostic and may change output.\n" " --dir-steering-file FILE\n" " Load one f32 direction vector per layer for directional steering.\n" " --dir-steering-ffn F\n" @@ -8020,6 +8031,8 @@ static server_config parse_options(int argc, char **argv) { c.default_tokens = parse_int_arg(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "-t") || !strcmp(arg, "--threads")) { c.engine.n_threads = parse_int_arg(need_arg(&i, argc, argv, arg), arg); + } else if (!strcmp(arg, "--mpp")) { + c.engine.mpp_mode = parse_mpp_mode_arg(need_arg(&i, argc, argv, arg)); } else if (!strcmp(arg, "--host")) { c.host = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--port")) { diff --git a/metal/dense.metal b/metal/dense.metal index 0d7af3ba..6400c69d 100644 --- a/metal/dense.metal +++ b/metal/dense.metal @@ -912,6 +912,7 @@ constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; #ifdef DS4_METAL_HAS_TENSOR template< + short NR0, short NR1, typename SA, typename SA_4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q 
*, short, thread SA_4x4 &), typename T0, typename T0_4x4, typename T1> @@ -926,6 +927,125 @@ kernel void kernel_mul_mm_mpp( ushort sgitg [[simdgroup_index_in_threadgroup]]) { (void) sgitg; + constexpr int NK = 32; + constexpr int NL = NK/16; + constexpr int NUM_THREADS = 128; + + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + const int im = tgpig.z; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + threadgroup SA *sa = (threadgroup SA *)shmem; + threadgroup SA *sb = sa + NR0*NK; + auto tA = tensor(sa, dextents(NK, NR0)); + auto tB = tensor(sb, dextents(NK, NR1)); + + device const T1 *ptrB = (device const T1 *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11/sizeof(T1); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, false, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto cT = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) { + cT[i] = 0.0f; + } + } + + for (int loop_k = 0; loop_k < K; loop_k += NK) { + for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { + const int row = work/NL; + const int k_chunk = work%NL; + const int k_pos = loop_k + k_chunk*16; + const short k_base = k_chunk*16; + + if (!FC_mul_mm_bc_out || r0 + row < M) { + if (is_same::value && FC_mul_mm_bc_inp) { + device const T0 *row_ptr = (device const T0 *)(srcA + args.nb01*(r0 + row) + offset0); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (k_pos + i < K) ? 
(SA)row_ptr[k_pos + i] : (SA)0; + } + } else { + const int block_idx = k_pos/(16*nl); + const short il = (k_pos/16)%nl; + device const block_q *row_ptr = (device const block_q *)(srcA + args.nb01*(r0 + row) + offset0); + + SA_4x4 temp_a; + dequantize_func(row_ptr + block_idx, il, temp_a); + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0; + } + } + } else { + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row*NK + k_base + i] = (SA)0; + } + } + } + for (int work = tiitg; work < NK*NR1; work += NUM_THREADS) { + const int col = work/NK; + const int k = work%NK; + if ((!FC_mul_mm_bc_out && !FC_mul_mm_bc_inp) || + (r1 + col < N && loop_k + k < K)) { + sb[col*NK + k] = (SA)ptrB[(uint64_t)(r1 + col)*strideB + loop_k + k]; + } else { + sb[col*NK + k] = (SA)0; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto mA = tA.slice(0, 0); + auto mB = tB.slice(0, 0); + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float *dst_batch = (device float *)dst + im*N*M; + if (!FC_mul_mm_bc_out) { + device float *dst_tile = dst_batch + r0 + (uint64_t)r1*M; + auto tD = tensor(dst_tile, dextents(NR0, NR1), array({1, M})); + cT.store(tD); + } else { + auto tD = tensor(dst_batch, dextents(M, N), array({1, M})); + auto mD = tD.slice(r0, r1); + cT.store(mD); + } +} + +typedef decltype(kernel_mul_mm_mpp<64, 32, half, half4x4, float4x4, 1, dequantize_f32, float, float4x4, float>) mul_mm_mpp_t; +typedef decltype(kernel_mul_mm_mpp<64, 64, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>) mul_mm_mpp_q8_n64_t; + +template [[host_name("kernel_mul_mm_f16_f32_mpp")]] kernel mul_mm_mpp_t kernel_mul_mm_mpp<64, 32, half, half4x4, half4x4, 1, dequantize_f16, half, half4x4, float>; +template [[host_name("kernel_mul_mm_q8_0_f32_mpp")]] kernel mul_mm_mpp_t kernel_mul_mm_mpp<64, 32, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>; +template 
[[host_name("kernel_mul_mm_q8_0_f32_mpp_n64")]] kernel mul_mm_mpp_q8_n64_t kernel_mul_mm_mpp<64, 64, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>; + +kernel void kernel_mul_mm_f16_f32_pair_mpp( + constant ds4_metal_args_mul_mm & args, + device const char * srcA0, + device const char * srcA1, + device const char * srcB, + device char * dst0, + device char * dst1, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + constexpr int NR0 = 64; constexpr int NR1 = 32; constexpr int NK = 32; @@ -943,6 +1063,126 @@ kernel void kernel_mul_mm_mpp( const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + threadgroup half *sa0 = (threadgroup half *)shmem; + threadgroup half *sa1 = sa0 + NR0*NK; + threadgroup half *sb = sa1 + NR0*NK; + auto tA0 = tensor(sa0, dextents(NK, NR0)); + auto tA1 = tensor(sa1, dextents(NK, NR0)); + auto tB = tensor(sb, dextents(NK, NR1)); + + device const float *ptrB = (device const float *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11/sizeof(float); + + matmul2d< + matmul2d_descriptor(NR1, NR0, NK, false, true, false, + matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; + + auto c0 = mm.template get_destination_cooperative_tensor(); + auto c1 = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < c0.get_capacity(); ++i) { + if (c0.is_valid_element(i)) { + c0[i] = 0.0f; + c1[i] = 0.0f; + } + } + + for (int loop_k = 0; loop_k < K; loop_k += NK) { + for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { + const int row = work/NL; + const int k_chunk = work%NL; + const int k_pos = loop_k + k_chunk*16; + const short k_base = k_chunk*16; + + if (!FC_mul_mm_bc_out || r0 + row < M) { + device const half *row0 = (device const half *)(srcA0 + args.nb01*(r0 
+ row) + offset0); + device const half *row1 = (device const half *)(srcA1 + args.nb01*(r0 + row) + offset0); + FOR_UNROLL (short i = 0; i < 16; i++) { + const bool in_bounds = k_pos + i < K; + sa0[row*NK + k_base + i] = in_bounds ? row0[k_pos + i] : (half)0; + sa1[row*NK + k_base + i] = in_bounds ? row1[k_pos + i] : (half)0; + } + } else { + FOR_UNROLL (short i = 0; i < 16; i++) { + sa0[row*NK + k_base + i] = (half)0; + sa1[row*NK + k_base + i] = (half)0; + } + } + } + for (int work = tiitg; work < NK*NR1; work += NUM_THREADS) { + const int col = work/NK; + const int k = work%NK; + if (!FC_mul_mm_bc_out || (r1 + col < N && loop_k + k < K)) { + sb[col*NK + k] = (half)ptrB[(uint64_t)(r1 + col)*strideB + loop_k + k]; + } else { + sb[col*NK + k] = (half)0; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto mA0 = tA0.slice(0, 0); + auto mA1 = tA1.slice(0, 0); + auto mB = tB.slice(0, 0); + mm.run(mB, mA0, c0); + mm.run(mB, mA1, c1); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float *dst0_batch = (device float *)dst0 + im*N*M; + device float *dst1_batch = (device float *)dst1 + im*N*M; + if (!FC_mul_mm_bc_out) { + device float *dst0_tile = dst0_batch + r0 + (uint64_t)r1*M; + device float *dst1_tile = dst1_batch + r0 + (uint64_t)r1*M; + auto tD0 = tensor(dst0_tile, dextents(NR0, NR1), array({1, M})); + auto tD1 = tensor(dst1_tile, dextents(NR0, NR1), array({1, M})); + c0.store(tD0); + c1.store(tD1); + } else { + auto tD0 = tensor(dst0_batch, dextents(M, N), array({1, M})); + auto tD1 = tensor(dst1_batch, dextents(M, N), array({1, M})); + auto mD0 = tD0.slice(r0, r1); + auto mD1 = tD1.slice(r0, r1); + c0.store(mD0); + c1.store(mD1); + } +} + +template< + short NR1, + typename SA, typename SA_4x4, typename block_q, short nl, + void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &), + typename T0, typename T0_4x4, typename T1> +kernel void kernel_mul_mm_mpp_direct_rhs( + constant ds4_metal_args_mul_mm & args, + 
device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + constexpr int NR0 = 64; + constexpr int NK = 32; + constexpr int NL = NK/16; + constexpr int NUM_THREADS = 128; + + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + const int im = tgpig.z; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + threadgroup SA *sa = (threadgroup SA *)shmem; auto tA = tensor(sa, dextents(NK, NR0)); @@ -955,7 +1195,14 @@ kernel void kernel_mul_mm_mpp( matmul2d_descriptor::mode::multiply_accumulate), execution_simdgroups<4>> mm; - auto cT = mm.get_destination_cooperative_tensor(); + auto cT = mm.template get_destination_cooperative_tensor(); + + #pragma unroll + for (uint16_t i = 0; i < cT.get_capacity(); ++i) { + if (cT.is_valid_element(i)) { + cT[i] = 0.0f; + } + } for (int loop_k = 0; loop_k < K; loop_k += NK) { for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) { @@ -1003,10 +1250,12 @@ kernel void kernel_mul_mm_mpp( cT.store(mD); } -typedef decltype(kernel_mul_mm_mpp) mul_mm_mpp_t; +typedef decltype(kernel_mul_mm_mpp_direct_rhs<32, half, half4x4, float4x4, 1, dequantize_f32, float, float4x4, float>) mul_mm_mpp_direct_rhs_t; +typedef decltype(kernel_mul_mm_mpp_direct_rhs<64, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>) mul_mm_mpp_direct_rhs_q8_n64_t; -template [[host_name("kernel_mul_mm_f16_f32_mpp")]] kernel mul_mm_mpp_t kernel_mul_mm_mpp; -template [[host_name("kernel_mul_mm_q8_0_f32_mpp")]] kernel mul_mm_mpp_t kernel_mul_mm_mpp; +template [[host_name("kernel_mul_mm_f16_f32_mpp_direct_rhs")]] kernel mul_mm_mpp_direct_rhs_t 
kernel_mul_mm_mpp_direct_rhs<32, half, half4x4, half4x4, 1, dequantize_f16, half, half4x4, float>; +template [[host_name("kernel_mul_mm_q8_0_f32_mpp_direct_rhs")]] kernel mul_mm_mpp_direct_rhs_t kernel_mul_mm_mpp_direct_rhs<32, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>; +template [[host_name("kernel_mul_mm_q8_0_f32_mpp_direct_rhs_n64")]] kernel mul_mm_mpp_direct_rhs_q8_n64_t kernel_mul_mm_mpp_direct_rhs<64, half, half4x4, block_q8_0, 2, dequantize_q8_0, float, float4x4, float>; #endif // Tiled matrix-matrix kernel used for prompt batches larger than 8. DS4 uses @@ -1213,6 +1462,242 @@ kernel void kernel_mul_mm( } } +kernel void kernel_mul_mm_f16_f32_pair( + constant ds4_metal_args_mul_mm & args, + device const char * src0_a, + device const char * src0_b, + device const char * src1, + device char * dst_a, + device char * dst_b, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + threadgroup half * sa_a = (threadgroup half *)(shmem); + threadgroup half * sa_b = (threadgroup half *)(shmem + 4096); + threadgroup half * sb = (threadgroup half *)(shmem + 8192); + + constexpr int NR0 = 64; + constexpr int NR1 = 32; + constexpr int NK = 32; + constexpr int NL0 = NK/16; + constexpr int NL1 = NK/8; + + const int im = tgpig.z; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; + + const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0; + const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1; + + const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; + const short lr1 = ((short)tiitg/NL1) < nr1 ? 
((short)tiitg/NL1) : nr1 - 1; + + const short il0 = (tiitg % NL0); + short il = il0; + + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il0; + + device const half4x4 * xa = (device const half4x4 *)(src0_a + args.nb01*(r0 + lr0) + offset0) + offset1; + device const half4x4 * xb = (device const half4x4 *)(src0_b + args.nb01*(r0 + lr0) + offset0) + offset1; + + const short iy = 8*(tiitg % NL1); + + device const float * y = (device const float *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*(r1 + lr1) + + args.nb10*iy); + + simdgroup_half8x8 ma[4]; + simdgroup_half8x8 mb[2]; + + simdgroup_float8x8 mc_a[8]; + simdgroup_float8x8 mc_b[8]; + + for (short i = 0; i < 8; i++) { + mc_a[i] = make_filled_simdgroup_matrix(0.f); + mc_b[i] = make_filled_simdgroup_matrix(0.f); + } + + for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { + half4x4 temp_a; + half4x4 temp_b; + dequantize_f16(xa, il, temp_a); + dequantize_f16(xb, il, temp_b); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + const short lx = (tiitg/NL0)%8; + const short ly = i%8; + + const short ib = 8*sx + sy; + + *(sa_a + 64*ib + 8*ly + lx) = temp_a[i/4][i%4]; + *(sa_b + 64*ib + 8*ly + lx) = temp_b[i/4][i%4]; + } + + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short lx = i; + const short ly = (tiitg/NL1)%8; + + const short ib = 4*sx + sy; + + *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? 
(half) *((device float *) y + i) : 0; + } + } else { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short ly = (tiitg/NL1)%8; + + const short ib = 4*sx + sy; + + *(threadgroup half2x4 *)(sb + 64*ib + 8*ly) = (half2x4)(*((device float2x4 *) y)); + } + + il = (il + 2 < 1) ? il + 2 : il % 2; + xa = (il < 2) ? xa + 2 : xa; + xb = (il < 2) ? xb + 2 : xb; + + y += NK; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup const half * lsma_a = (sa_a + 4*64*(sgitg%2)); + threadgroup const half * lsma_b = (sa_b + 4*64*(sgitg%2)); + threadgroup const half * lsmb = (sb + 2*64*(sgitg/2)); + + FOR_UNROLL (short ik = 0; ik < NK/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma_a + 64*i, 8, 0, false); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 8; i++) { + simdgroup_multiply_accumulate(mc_a[i], mb[i/4], ma[i%4], mc_a[i]); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma_b + 64*i, 8, 0, false); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 8; i++) { + simdgroup_multiply_accumulate(mc_b[i], mb[i/4], ma[i%4], mc_b[i]); + } + + lsma_a += 8*64; + lsma_b += 8*64; + lsmb += 4*64; + } + } + + if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { + device float * C_a = (device float *) dst_a + + (r0 + 32*(sgitg & 1)) + + (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + device float * C_b = (device float *) dst_b + + (r0 + 32*(sgitg & 1)) + + (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc_a[i], C_a + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); + simdgroup_store(mc_b[i], C_b + 8*(i%4) + 
8*args.ne0*(i/4), args.ne0, 0, false);
+        }
+    } else {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        threadgroup float * temp_str = (threadgroup float *) shmem;
+
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc_a[i],
+                temp_str + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0 + 8*(i%4) + 8*NR0*(i/4),
+                NR0,
+                0,
+                false);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (sgitg == 0) {
+            for (int j = tiitg; j < nr1; j += NR1) {
+                device float  * D  = (device float *) dst_a + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
+                device float4 * D4 = (device float4 *) D;
+
+                threadgroup float  * C  = temp_str + (j*NR0);
+                threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+                int i = 0;
+                for (; i < nr0/4; i++) {
+                    *(D4 + i) = *(C4 + i);
+                }
+
+                i *= 4;
+                for (; i < nr0; i++) {
+                    *(D + i) = *(C + i);
+                }
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc_b[i],
+                temp_str + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0 + 8*(i%4) + 8*NR0*(i/4),
+                NR0,
+                0,
+                false);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (sgitg == 0) {
+            for (int j = tiitg; j < nr1; j += NR1) {
+                device float  * D  = (device float *) dst_b + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
+                device float4 * D4 = (device float4 *) D;
+
+                threadgroup float  * C  = temp_str + (j*NR0);
+                threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+                int i = 0;
+                for (; i < nr0/4; i++) {
+                    *(D4 + i) = *(C4 + i);
+                }
+
+                i *= 4;
+                for (; i < nr0; i++) {
+                    *(D + i) = *(C + i);
+                }
+            }
+        }
+    }
+}
+
 typedef decltype(kernel_mul_mm) mul_mm_t;
 
 // Host-visible prefill matmul variants for F16 and Q8_0 weights.
diff --git a/metal/moe.metal b/metal/moe.metal
index 0cfd31ce..a4360fe6 100644
--- a/metal/moe.metal
+++ b/metal/moe.metal
@@ -1549,7 +1549,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_
 // Batched routed-expert matmul. It reads the expert-major map produced above,
 // loads selected expert weights, and writes results back to token-major slots
 // so the DS4 FFN can apply SwiGLU, weighting, and the down projection.
-template
+template
 kernel void kernel_mul_mm_id(
         constant ds4_metal_args_mul_mm_id & args,
         device const char * src0,
@@ -1569,7 +1569,6 @@ kernel void kernel_mul_mm_id(
 #endif
 
     constexpr int NR0 = 64;
-    constexpr int NR1 = 32;
     constexpr int NK  = 32;
 
     constexpr int NL0 = NK/16;
@@ -1590,6 +1589,7 @@ kernel void kernel_mul_mm_id(
 
     const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
     const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
+    const bool full_mpp_tile = nr0 == NR0 && nr1 == NR1 && (args.ne00 % NK) == 0;
 
     const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1;
     const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1;
@@ -1627,14 +1627,21 @@ kernel void kernel_mul_mm_id(
     }
 #ifdef DS4_METAL_HAS_TENSOR
     auto tA = tensor(sa, dextents(NK, NR0));
-    auto tB = tensor(sb, dextents(NR1, NK));
+    auto tB = tensor(sb, dextents(NK, NR1));
 
     matmul2d<
         matmul2d_descriptor(NR1, NR0, NK, false, true, false,
                             matmul2d_descriptor::mode::multiply_accumulate),
         execution_simdgroups<4>> mm;
 
-    auto cT = mm.get_destination_cooperative_tensor();
+    auto cT = mm.template get_destination_cooperative_tensor();
+
+    #pragma unroll
+    for (uint16_t i = 0; i < cT.get_capacity(); ++i) {
+        if (cT.is_valid_element(i)) {
+            cT[i] = 0.0f;
+        }
+    }
 #endif
 
     for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
@@ -1650,7 +1657,8 @@ kernel void kernel_mul_mm_id(
                 const short lx = i%8;
                 const short ly = (tiitg/NL0)%8;
 
-                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) =
+                    full_mpp_tile || loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
             } else
 #endif
             {
@@ -1692,6 +1700,32 @@ kernel void kernel_mul_mm_id(
         }
 
         if (FC_mul_mm_bc_inp) {
+#ifdef DS4_METAL_HAS_TENSOR
+            if (FC_mul_mm_id_mpp) {
+                for (short tile_row = 0; tile_row < NR1; tile_row += 32) {
+                    const short t   = (short)tiitg + tile_row*4;
+                    const short row = t/NL1;
+                    const short sx  = t%NL1;
+                    const short sy  = row/8;
+                    const short lx  = 0;
+                    const short ly  = row%8;
+                    const int   idb = (full_mpp_tile || row < nr1) ? ids_i32[im*args.ne21 + r1 + row] : 0;
+                    const short i11b = (idb % args.ne20) % args.ne11;
+                    const short i12b = (idb / args.ne20);
+                    device const T1 *yb = (device const T1 *)(src1 +
+                        args.nb13*i13 +
+                        args.nb12*i12b +
+                        args.nb11*i11b +
+                        args.nb10*(loop_k + 8*sx));
+
+                    FOR_UNROLL (short i = 0; i < 8; ++i) {
+                        *(sb + NK*(8*sy + ly) + 8*sx + lx + i) =
+                            full_mpp_tile || (row < nr1 && loop_k + 8*sx + i < args.ne00) ? (S1) *(yb + i) : 0;
+                    }
+                }
+            } else
+#endif
+            {
             for (short i = 0; i < 8; ++i) {
                 const short sx = (tiitg%NL1);
                 const short sy = (tiitg/NL1)/8;
@@ -1699,29 +1733,44 @@ kernel void kernel_mul_mm_id(
                 const short lx = i;
                 const short ly = (tiitg/NL1)%8;
 
-#ifdef DS4_METAL_HAS_TENSOR
-                if (FC_mul_mm_id_mpp) {
-                    *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
-                } else
-#endif
-                {
                 const short ib = 4*sx + sy;
 
                 *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
-                }
+            }
             }
         } else {
-            const short sx = (tiitg%NL1);
-            const short sy = (tiitg/NL1)/8;
-
-            const short ly = (tiitg/NL1)%8;
-
 #ifdef DS4_METAL_HAS_TENSOR
             if (FC_mul_mm_id_mpp) {
-                *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
+                for (short tile_row = 0; tile_row < NR1; tile_row += 32) {
+                    const short t   = (short)tiitg + tile_row*4;
+                    const short row = t/NL1;
+                    const short sx  = t%NL1;
+                    const short sy  = row/8;
+                    const short ly  = row%8;
+                    const int   idb = (full_mpp_tile || row < nr1) ? ids_i32[im*args.ne21 + r1 + row] : 0;
+                    const short i11b = (idb % args.ne20) % args.ne11;
+                    const short i12b = (idb / args.ne20);
+                    device const T1 *yb = (device const T1 *)(src1 +
+                        args.nb13*i13 +
+                        args.nb12*i12b +
+                        args.nb11*i11b +
+                        args.nb10*loop_k);
+
+                    if (full_mpp_tile || row < nr1) {
+                        *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) =
+                            (S1_2x4)(*((device T1_2x4 *) yb + sx));
+                    } else {
+                        *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(0);
+                    }
+                }
             } else
 #endif
             {
+                const short sx = (tiitg%NL1);
+                const short sy = (tiitg/NL1)/8;
+
+                const short ly = (tiitg/NL1)%8;
+
                 const short ib = 4*sx + sy;
 
                 *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
@@ -1813,20 +1862,405 @@ kernel void kernel_mul_mm_id(
     }
 }
 
-typedef decltype(kernel_mul_mm_id) mul_mm_id;
-typedef decltype(kernel_mul_mm_id) mul_mm_id_f16_rhs;
+#ifdef DS4_METAL_HAS_TENSOR
+template
+kernel void kernel_mul_mm_id_pair_mpp(
+        constant ds4_metal_args_mul_mm_id & args,
+        device const char * src0_gate,
+        device const char * src0_up,
+        device const char * src1,
+        device const char * htpe,
+        device const char * hids,
+        device       char * dst_gate,
+        device       char * dst_up,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    threadgroup S0 * sa = (threadgroup S0 *)(shmem);
+    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
+    threadgroup float *sc = (threadgroup float *)shmem;
+
+    constexpr int NR0 = 64;
+    constexpr int NR1 = 32;
+    constexpr int NK  = 32;
+    constexpr int NL0 = NK/16;
+    constexpr int NL1 = NK/8;
+
+    const int im = tgpig.z;
+    const int r0 = tgpig.y*NR0;
+    const int r1 = tgpig.x*NR1;
+
+    device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
+    device const int32_t  * ids_i32 = (device const int32_t  *) (hids);
+    const int32_t neh1 = tpe_u32[im];
+    if (r1 >= neh1) {
+        return;
+    }
+
+    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
+    const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
+    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1;
+    const short il0 = (tiitg % NL0);
+    short il = il0;
+
+    const int i13 = 0;
+    const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
+    const short    offset1 = il0/nl;
+    device const block_q * x_gate =
+        (device const block_q *)(src0_gate + args.nb01*(r0 + lr0) + offset0) + offset1;
+    device const block_q * x_up =
+        (device const block_q *)(src0_up   + args.nb01*(r0 + lr0) + offset0) + offset1;
+
+    auto tA = tensor(sa, dextents(NK, NR0));
+    auto tB = tensor(sb, dextents(NK, NR1));
+    matmul2d<
+        matmul2d_descriptor(NR1, NR0, NK, false, true, false,
+                            matmul2d_descriptor::mode::multiply_accumulate),
+        execution_simdgroups<4>> mm;
+
+    auto cGate = mm.template get_destination_cooperative_tensor();
+    auto cUp   = mm.template get_destination_cooperative_tensor();
+
+    #pragma unroll
+    for (uint16_t i = 0; i < cGate.get_capacity(); ++i) {
+        if (cGate.is_valid_element(i)) cGate[i] = 0.0f;
+        if (cUp.is_valid_element(i))   cUp[i]   = 0.0f;
+    }
+
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
+        S0_4x4 temp_gate;
+        dequantize_func(x_gate, il, temp_gate);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        FOR_UNROLL (short i = 0; i < 16; i++) {
+            const short sx = 2*il0 + i/8;
+            const short sy = (tiitg/NL0)/8;
+            const short lx = i%8;
+            const short ly = (tiitg/NL0)%8;
+            *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_gate[i/4][i%4];
+        }
+
+        const short row = ((short)tiitg)/NL1;
+        const short sx  = ((short)tiitg)%NL1;
+        const short sy  = row/8;
+        const short ly  = row%8;
+        const int   idb = row < nr1 ? ids_i32[im*args.ne21 + r1 + row] : 0;
+        const short i11b = (idb % args.ne20) % args.ne11;
+        const short i12b = (idb / args.ne20);
+        device const T1 *yb = (device const T1 *)(src1 +
+            args.nb13*i13 +
+            args.nb12*i12b +
+            args.nb11*i11b +
+            args.nb10*loop_k);
+
+        if (row < nr1) {
+            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) =
+                (S1_2x4)(*((device T1_2x4 *) yb + sx));
+        } else {
+            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(0);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        auto sA = tA.slice(0, 0);
+        auto sB = tB.slice(0, 0);
+        mm.run(sB, sA, cGate);
+
+        S0_4x4 temp_up;
+        dequantize_func(x_up, il, temp_up);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        FOR_UNROLL (short i = 0; i < 16; i++) {
+            const short ax  = 2*il0 + i/8;
+            const short ay  = (tiitg/NL0)/8;
+            const short lx  = i%8;
+            const short ly2 = (tiitg/NL0)%8;
+            *(sa + NK*(8*ay + ly2) + 8*ax + lx) = temp_up[i/4][i%4];
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sA = tA.slice(0, 0);
+        sB = tB.slice(0, 0);
+        mm.run(sB, sA, cUp);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x_gate = (il < 2) ? x_gate + (2 + nl - 1)/nl : x_gate;
+        x_up   = (il < 2) ? x_up   + (2 + nl - 1)/nl : x_up;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    auto tC = tensor(sc, dextents(NR0, NR1));
+    cGate.store(tC);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (short j = sgitg; j < nr1; j += 4) {
+        const int   idj = ids_i32[im*args.ne21 + r1 + j];
+        const short ide = idj % args.ne20;
+        const short idt = idj / args.ne20;
+        device float  * D  = (device float *) dst_gate + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
+        device float4 * D4 = (device float4 *) D;
+        threadgroup float  * C  = (threadgroup float *) shmem + j*NR0;
+        threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+        int i = tiisg;
+        for (; i < nr0/4; i += 32) *(D4 + i) = *(C4 + i);
+        i = (4*(nr0/4)) + tiisg;
+        for (; i < nr0; i += 32) *(D + i) = *(C + i);
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    cUp.store(tC);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (short j = sgitg; j < nr1; j += 4) {
+        const int   idj = ids_i32[im*args.ne21 + r1 + j];
+        const short ide = idj % args.ne20;
+        const short idt = idj / args.ne20;
+        device float  * D  = (device float *) dst_up + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
+        device float4 * D4 = (device float4 *) D;
+        threadgroup float  * C  = (threadgroup float *) shmem + j*NR0;
+        threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+        int i = tiisg;
+        for (; i < nr0/4; i += 32) *(D4 + i) = *(C4 + i);
+        i = (4*(nr0/4)) + tiisg;
+        for (; i < nr0; i += 32) *(D + i) = *(C + i);
+    }
+}
+#endif
+
+typedef decltype(kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>) mul_mm_id;
+typedef decltype(kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>) mul_mm_id_n64;
+typedef decltype(kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>) mul_mm_id_f16_rhs;
+typedef decltype(kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>) mul_mm_id_f16_rhs_n64;
+
+#ifdef DS4_METAL_HAS_TENSOR
+// Diagnostic-only old MPP tensor layout from the first Metal 4 PR. It is kept
+// behind DS4_METAL_MPP_MOE_FAST_LAYOUT so we can measure whether the old kernel
+// shape can be recovered for routes that already pass full-model equivalence.
+template
+kernel void kernel_mul_mm_id_mpp_fast_layout(
+        constant ds4_metal_args_mul_mm_id & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * htpe,
+        device const char * hids,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    (void)sgitg;
+
+    threadgroup S0 * sa = (threadgroup S0 *)(shmem);
+    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
+    threadgroup float *sc = (threadgroup float *)shmem;
+
+    constexpr int NR0 = 64;
+    constexpr int NR1 = 32;
+    constexpr int NK  = 32;
+    constexpr int NL0 = NK/16;
+    constexpr int NL1 = NK/8;
+
+    const int im = tgpig.z;
+    const int r0 = tgpig.y*NR0;
+    const int r1 = tgpig.x*NR1;
+
+    device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
+    device const int32_t  * ids_i32 = (device const int32_t  *) (hids);
+
+    const int32_t neh1 = tpe_u32[im];
+
+    if (r1 >= neh1) {
+        return;
+    }
+
+    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
+    const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
+
+    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1;
+    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1;
+
+    const short il0 = (tiitg % NL0);
+    short il = il0;
+
+    const int id = ids_i32[im*args.ne21 + r1 + lr1];
+
+    const short i11 = (id % args.ne20) % args.ne11;
+    const short i12 = (id / args.ne20);
+    const short i13 = 0;
+
+    const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
+    const short    offset1 = il0/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
+
+    const short iy = 8*(tiitg % NL1);
+
+    device const T1 * y = (device const T1 *)(src1 +
+        args.nb13*i13 +
+        args.nb12*i12 +
+        args.nb11*i11 +
+        args.nb10*iy);
+
+    auto tA = tensor(sa, dextents(NK, NR0));
+    auto tB = tensor(sb, dextents(NR1, NK));
+
+    matmul2d<
+        matmul2d_descriptor(NR1, NR0, NK, false, true, false,
+                            matmul2d_descriptor::mode::multiply_accumulate),
+        execution_simdgroups<4>> mm;
+
+    auto cT = mm.template get_destination_cooperative_tensor();
+
+    #pragma unroll
+    for (uint16_t i = 0; i < cT.get_capacity(); ++i) {
+        if (cT.is_valid_element(i)) {
+            cT[i] = 0.0f;
+        }
+    }
+
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
+        if (is_same::value && FC_mul_mm_bc_inp) {
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            for (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) =
+                    loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
+            }
+        } else {
+            S0_4x4 temp_a;
+            dequantize_func(x, il, temp_a);
+
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            FOR_UNROLL (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
+            }
+        }
+
+        if (FC_mul_mm_bc_inp) {
+            for (short i = 0; i < 8; ++i) {
+                const short sx = (tiitg%NL1);
+                const short sy = (tiitg/NL1)/8;
+                const short lx = i;
+                const short ly = (tiitg/NL1)%8;
+
+                *(sb + NK*(8*sy + ly) + 8*sx + lx) =
+                    loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
+            }
+        } else {
+            const short sx = (tiitg%NL1);
+            const short sy = (tiitg/NL1)/8;
+            const short ly = (tiitg/NL1)%8;
+
+            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) =
+                (S1_2x4)(*((device T1_2x4 *) y));
+        }
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
+
+        y += NK;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        auto sA = tA.slice(0, 0);
+        auto sB = tB.slice(0, 0);
+        mm.run(sB, sA, cT);
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    auto tC = tensor(sc, dextents(NR0, NR1));
+    cT.store(tC);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (short j = tiitg/32; j < nr1; j += 4) {
+        const int idj = ids_i32[im*args.ne21 + r1 + j];
+
+        const short ide = idj % args.ne20;
+        const short idt = idj / args.ne20;
+
+        device float  * D  = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
+        device float4 * D4 = (device float4 *) D;
+
+        threadgroup float  * C  = (threadgroup float *) shmem + j*NR0;
+        threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+        int i = tiisg;
+        for (; i < nr0/4; i += 32) {
+            *(D4 + i) = *(C4 + i);
+        }
+
+        i = (4*(nr0/4)) + tiisg;
+        for (; i < nr0; i += 32) {
+            *(D + i) = *(C + i);
+        }
+    }
+}
+
+typedef decltype(kernel_mul_mm_id_mpp_fast_layout) mul_mm_id_fast_layout;
+typedef decltype(kernel_mul_mm_id_mpp_fast_layout) mul_mm_id_fast_layout_f16_rhs;
+typedef decltype(kernel_mul_mm_id_pair_mpp) mul_mm_id_pair_mpp_t;
+#endif
 
 // Host-visible batched MoE matmul variants for the DS4 quant formats.
-template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_q8_0_f32_n64")]] kernel mul_mm_id_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32_n64")]] kernel mul_mm_id_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32_n64")]] kernel mul_mm_id_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_n64")]] kernel mul_mm_id_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q8_0_f16_n64")]] kernel mul_mm_id_f16_rhs_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_q2_K_f16_n64")]] kernel mul_mm_id_f16_rhs_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_q4_K_f16_n64")]] kernel mul_mm_id_f16_rhs_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n64")]] kernel mul_mm_id_f16_rhs_n64 kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;
+#ifdef DS4_METAL_HAS_TENSOR
+template [[host_name("kernel_mul_mm_id_q8_0_f32_fast_mpp")]] kernel mul_mm_id_fast_layout kernel_mul_mm_id_mpp_fast_layout;
+template [[host_name("kernel_mul_mm_id_q2_K_f32_fast_mpp")]] kernel mul_mm_id_fast_layout kernel_mul_mm_id_mpp_fast_layout;
+template [[host_name("kernel_mul_mm_id_q4_K_f32_fast_mpp")]] kernel mul_mm_id_fast_layout kernel_mul_mm_id_mpp_fast_layout;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_fast_mpp")]] kernel mul_mm_id_fast_layout kernel_mul_mm_id_mpp_fast_layout;
+template [[host_name("kernel_mul_mm_id_q8_0_f16_fast_mpp")]] kernel mul_mm_id_fast_layout_f16_rhs kernel_mul_mm_id_mpp_fast_layout;
+template [[host_name("kernel_mul_mm_id_q2_K_f16_fast_mpp")]] kernel mul_mm_id_fast_layout_f16_rhs kernel_mul_mm_id_mpp_fast_layout;
+template [[host_name("kernel_mul_mm_id_q4_K_f16_fast_mpp")]] kernel mul_mm_id_fast_layout_f16_rhs kernel_mul_mm_id_mpp_fast_layout;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_fast_mpp")]] kernel mul_mm_id_fast_layout_f16_rhs kernel_mul_mm_id_mpp_fast_layout;
+
+template [[host_name("kernel_mul_mm_id_q8_0_f32_pair_mpp")]] kernel mul_mm_id_pair_mpp_t kernel_mul_mm_id_pair_mpp;
+template [[host_name("kernel_mul_mm_id_q2_K_f32_pair_mpp")]] kernel mul_mm_id_pair_mpp_t kernel_mul_mm_id_pair_mpp;
+template [[host_name("kernel_mul_mm_id_q4_K_f32_pair_mpp")]] kernel mul_mm_id_pair_mpp_t kernel_mul_mm_id_pair_mpp;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_pair_mpp")]] kernel mul_mm_id_pair_mpp_t kernel_mul_mm_id_pair_mpp;
+#endif
 
 #ifdef DS4_METAL_HAS_TENSOR
+template
 kernel void kernel_attn_out_low_q8_0_mpp(
         constant ds4_metal_args_mul_mm_id & args,
         device const char * srcA,
@@ -1839,7 +2273,6 @@ kernel void kernel_attn_out_low_q8_0_mpp(
     (void) sgitg;
 
     constexpr int NR0 = 64;
-    constexpr int NR1 = 32;
     constexpr int NK  = 32;
     constexpr int NL  = NK/16;
     constexpr int NUM_THREADS = 128;
@@ -1851,6 +2284,115 @@ kernel void kernel_attn_out_low_q8_0_mpp(
     const int group = tgpig.z;
     const int r0 = tgpig.y*NR0;
     const int r1 = tgpig.x*NR1;
+    const bool full_tile = r0 + NR0 <= M && r1 + NR1 <= N && (K % NK) == 0;
+
+    threadgroup half *sa = (threadgroup half *)shmem;
+    threadgroup half *sb = sa + NR0*NK;
+    auto tA = tensor(sa, dextents(NK, NR0));
+    auto tB = tensor(sb, dextents(NK, NR1));
+
+    device const float *ptrB = (device const float *)(srcB + args.nb11*group);
+    const int strideB = args.nb12/sizeof(float);
+
+    matmul2d<
+        matmul2d_descriptor(NR1, NR0, NK, false, true, false,
+                            matmul2d_descriptor::mode::multiply_accumulate),
+        execution_simdgroups<4>> mm;
+
+    auto cT = mm.template get_destination_cooperative_tensor();
+
+    #pragma unroll
+    for (uint16_t i = 0; i < cT.get_capacity(); ++i) {
+        if (cT.is_valid_element(i)) {
+            cT[i] = 0.0f;
+        }
+    }
+
+    for (int loop_k = 0; loop_k < K; loop_k += NK) {
+        for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) {
+            const int   row     = work/NL;
+            const int   k_chunk = work%NL;
+            const int   k_pos   = loop_k + k_chunk*16;
+            const short k_base  = k_chunk*16;
+
+            if (full_tile || r0 + row < M) {
+                const int   block_idx = k_pos/32;
+                const short il        = (k_pos/16)%2;
+                device const block_q8_0 *row_ptr =
+                    (device const block_q8_0 *)(srcA + args.nb01*(r0 + row) + group*args.nb02);
+
+                half4x4 temp_a;
+                dequantize_q8_0(row_ptr + block_idx, il, temp_a);
+                FOR_UNROLL (short i = 0; i < 16; i++) {
+                    sa[row*NK + k_base + i] = (full_tile || k_pos + i < K) ? temp_a[i/4][i%4] : (half)0;
+                }
+            } else {
+                FOR_UNROLL (short i = 0; i < 16; i++) {
+                    sa[row*NK + k_base + i] = (half)0;
+                }
+            }
+        }
+        for (int work = tiitg; work < NK*NR1; work += NUM_THREADS) {
+            const int col = work/NK;
+            const int k   = work%NK;
+            if (full_tile || (r1 + col < N && loop_k + k < K)) {
+                sb[col*NK + k] = (half)ptrB[(uint64_t)(r1 + col)*strideB + loop_k + k];
+            } else {
+                sb[col*NK + k] = (half)0;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        auto mA = tA.slice(0, 0);
+        auto mB = tB.slice(0, 0);
+        mm.run(mB, mA, cT);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    device float *dst_group = (device float *)dst + group*M;
+    if (full_tile) {
+        device float *dst_tile = dst_group + r0 + (uint64_t)r1*G*M;
+        auto tD = tensor(dst_tile, dextents(NR0, NR1), array({1, G*M}));
+        cT.store(tD);
+    } else {
+        auto tD = tensor(dst_group, dextents(M, N), array({1, G*M}));
+        auto mD = tD.slice(r0, r1);
+        cT.store(mD);
+    }
+}
+
+typedef decltype(kernel_attn_out_low_q8_0_mpp<32>) attn_out_low_q8_0_mpp_t;
+
+template [[host_name("kernel_attn_out_low_q8_0_mpp")]] kernel attn_out_low_q8_0_mpp_t kernel_attn_out_low_q8_0_mpp<32>;
+template [[host_name("kernel_attn_out_low_q8_0_mpp_n64")]] kernel attn_out_low_q8_0_mpp_t kernel_attn_out_low_q8_0_mpp<64>;
+
+template
+kernel void kernel_attn_out_low_q8_0_mpp_direct_rhs(
+        constant ds4_metal_args_mul_mm_id & args,
+        device const char * srcA,
+        device const char * srcB,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig [[threadgroup_position_in_grid]],
+        ushort tiitg [[thread_index_in_threadgroup]],
+        ushort sgitg [[simdgroup_index_in_threadgroup]]) {
+    (void) sgitg;
+
+    constexpr int NR0 = 64;
+    constexpr int NK  = 32;
+    constexpr int NL  = NK/16;
+    constexpr int NUM_THREADS = 128;
+
+    const int K = args.ne00;
+    const int M = args.ne0;
+    const int N = args.ne21;
+    const int G = args.ne1;
+    const int group = tgpig.z;
+    const int r0 = tgpig.y*NR0;
+    const int r1 = tgpig.x*NR1;
+    const bool full_tile = r0 + NR0 <= M && r1 + NR1 <= N && (K % NK) == 0;
 
     threadgroup half *sa = (threadgroup half *)shmem;
     auto tA = tensor(sa, dextents(NK, NR0));
@@ -1864,7 +2406,14 @@ kernel void kernel_attn_out_low_q8_0_mpp(
                             matmul2d_descriptor::mode::multiply_accumulate),
         execution_simdgroups<4>> mm;
 
-    auto cT = mm.get_destination_cooperative_tensor();
+    auto cT = mm.template get_destination_cooperative_tensor();
+
+    #pragma unroll
+    for (uint16_t i = 0; i < cT.get_capacity(); ++i) {
+        if (cT.is_valid_element(i)) {
+            cT[i] = 0.0f;
+        }
+    }
 
     for (int loop_k = 0; loop_k < K; loop_k += NK) {
         for (int work = tiitg; work < NR0*NL; work += NUM_THREADS) {
@@ -1873,7 +2422,7 @@ kernel void kernel_attn_out_low_q8_0_mpp(
             const int   k_pos   = loop_k + k_chunk*16;
             const short k_base  = k_chunk*16;
 
-            if (r0 + row < M) {
+            if (full_tile || r0 + row < M) {
                 const int   block_idx = k_pos/32;
                 const short il        = (k_pos/16)%2;
                 device const block_q8_0 *row_ptr =
@@ -1882,7 +2431,7 @@ kernel void kernel_attn_out_low_q8_0_mpp(
                 half4x4 temp_a;
                 dequantize_q8_0(row_ptr + block_idx, il, temp_a);
                 FOR_UNROLL (short i = 0; i < 16; i++) {
-                    sa[row*NK + k_base + i] = (k_pos + i < K) ? temp_a[i/4][i%4] : (half)0;
+                    sa[row*NK + k_base + i] = (full_tile || k_pos + i < K) ? temp_a[i/4][i%4] : (half)0;
                 }
             } else {
                 FOR_UNROLL (short i = 0; i < 16; i++) {
@@ -1901,10 +2450,23 @@ kernel void kernel_attn_out_low_q8_0_mpp(
     }
 
     device float *dst_group = (device float *)dst + group*M;
-    auto tD = tensor(dst_group, dextents(M, N), array({1, G*M}));
-    auto mD = tD.slice(r0, r1);
-    cT.store(mD);
+    if (full_tile) {
+        device float *dst_tile = dst_group + r0 + (uint64_t)r1*G*M;
+        auto tD = tensor(dst_tile, dextents(NR0, NR1), array({1, G*M}));
+        cT.store(tD);
+    } else {
+        auto tD = tensor(dst_group, dextents(M, N), array({1, G*M}));
+        auto mD = tD.slice(r0, r1);
+        cT.store(mD);
+    }
 }
+
+typedef decltype(kernel_attn_out_low_q8_0_mpp_direct_rhs<32>) attn_out_low_q8_0_mpp_direct_rhs_t;
+typedef decltype(kernel_attn_out_low_q8_0_mpp_direct_rhs<64>) attn_out_low_q8_0_mpp_direct_rhs_n64_t;
+
+template [[host_name("kernel_attn_out_low_q8_0_mpp_direct_rhs")]] kernel attn_out_low_q8_0_mpp_direct_rhs_t kernel_attn_out_low_q8_0_mpp_direct_rhs<32>;
+template [[host_name("kernel_attn_out_low_q8_0_mpp_direct_rhs_n64")]] kernel attn_out_low_q8_0_mpp_direct_rhs_n64_t kernel_attn_out_low_q8_0_mpp_direct_rhs<64>;
+
 #endif
 
 #undef QK_NL
diff --git a/tests/ds4_test.c b/tests/ds4_test.c
index dd45ba78..0c9fd1cf 100644
--- a/tests/ds4_test.c
+++ b/tests/ds4_test.c
@@ -150,10 +150,10 @@ static void test_metal_f16_matvec_fast_nr0_4(void) {
     free(weights_raw);
 }
 
-static void test_metal_q8_0_mpp_matmul(void) {
-    const uint32_t in_dim = 128;
-    const uint32_t out_dim = 96;
-    const uint32_t n_tok = 48;
+static void test_metal_q8_0_mpp_matmul_case(const char *label,
+                                            uint32_t in_dim,
+                                            uint32_t out_dim,
+                                            uint32_t n_tok) {
     const uint64_t blocks = in_dim / 32;
     const uint64_t row_bytes = blocks * 34;
     const uint64_t weight_bytes = (uint64_t)out_dim * row_bytes;
@@ -226,7 +226,8 @@ static void test_metal_q8_0_mpp_matmul(void) {
     int have_mpp = ds4_gpu_matmul_q8_0_mpp_tensor(
         out_mpp, weights_raw, weight_alloc, 0, in_dim, out_dim, x, n_tok);
     if (!have_mpp) {
-        fprintf(stderr, "ds4-test: skipping MPP Q8_0 matmul; Metal 4 tensor API unavailable\n");
+        fprintf(stderr, "ds4-test: skipping MPP Q8_0 matmul %s; Metal 4 tensor API unavailable\n",
+                label);
         free(x_host);
         free(ref_host);
         free(mpp_host);
@@ -241,17 +242,21 @@ static void test_metal_q8_0_mpp_matmul(void) {
     TEST_ASSERT(ds4_gpu_tensor_read(out_mpp, 0, mpp_host, out_bytes) != 0);
 
     float max_abs = 0.0f;
+    double sumsq = 0.0;
     uint64_t max_index = 0;
     for (uint64_t i = 0; i < (uint64_t)n_tok * out_dim; i++) {
-        float err = fabsf(mpp_host[i] - ref_host[i]);
+        const float err = fabsf(mpp_host[i] - ref_host[i]);
+        sumsq += (double)err * (double)err;
         if (err > max_abs) {
             max_abs = err;
             max_index = i;
         }
    }
+    const float rms = (float)sqrt(sumsq / (double)((uint64_t)n_tok * out_dim));
    if (max_abs >= 0.10f) {
-        fprintf(stderr, "ds4-test: MPP Q8_0 matmul max_abs=%f at token=%llu out=%llu ref=%f mpp=%f\n",
-                max_abs,
+        fprintf(stderr,
+                "ds4-test: MPP Q8_0 matmul %s in=%u out=%u tok=%u max_abs=%f rms=%f at token=%llu out=%llu ref=%f mpp=%f\n",
+                label, in_dim, out_dim, n_tok, max_abs, rms,
                (unsigned long long)(max_index / out_dim),
                (unsigned long long)(max_index % out_dim),
                ref_host[max_index],
@@ -268,6 +273,13 @@ static void test_metal_q8_0_mpp_matmul(void) {
    free(weights_raw);
 }
 
+static void test_metal_q8_0_mpp_matmul(void) {
+    test_metal_q8_0_mpp_matmul_case("small_partial48", 128, 96, 48);
+    test_metal_q8_0_mpp_matmul_case("medium_partial48", 512, 256, 48);
+    test_metal_q8_0_mpp_matmul_case("modelish_full32", 4096, 256, 32);
+    test_metal_q8_0_mpp_matmul_case("modelish_partial48", 4096, 256, 48);
+}
+
 static void test_metal_kernel_group(void) {
     test_metal_f16_matvec_fast_nr0_4();
     test_metal_q8_0_mpp_matmul();
@@ -669,6 +681,563 @@ static void test_official_logprob_vectors(void) {
     fclose(fp);
 }
 
+#define TEST_MPP_EQ_MAX_CASES 8
+#define TEST_MPP_EQ_TOPK 20
+#define TEST_MPP_EQ_TOP5 5
+#define TEST_MPP_EQ_DELTAS 5
+
+typedef struct {
+    char id[96];
+    int ctx;
+    int vocab_size;
+    int gen_steps;
+    ds4_tokens prompt;
+    float *ref_logits;
+    int ref_gen[TEST_VEC_MAX_STEPS];
+    int ref_gen_len;
+} test_mpp_eq_case;
+
+typedef struct {
+    int ref_top1;
+    int cand_top1;
+    int overlap;
+    int top5_overlap;
+    int max_rank_delta;
+    int nonfinite;
+    float rms;
+    float max_abs;
+    float top20_max_abs;
+    bool same_top1;
+    bool pass;
+} test_mpp_eq_result;
+
+typedef struct {
+    const char *label;
+    int cases;
+    int capture_failures;
+    int logits_failures;
+    int greedy_failures;
+    int top1_mismatches;
+    int min_overlap;
+    int min_top5_overlap;
+    int worst_rank_delta;
+    float worst_rms;
+    float worst_max_abs;
+    float worst_top20_max_abs;
+} test_mpp_eq_summary;
+
+static void test_mpp_eq_case_free(test_mpp_eq_case *tc) {
+    if (!tc) return;
+    ds4_tokens_free(&tc->prompt);
+    free(tc->ref_logits);
+    memset(tc, 0, sizeof(*tc));
+}
+
+static void test_logits_topk(const float *logits, int n, int *out, int k) {
+    for (int i = 0; i < k; i++) out[i] = -1;
+    for (int id = 0; id < n; id++) {
+        const float v = logits[id];
+        if (!isfinite(v)) continue;
+        for (int j = 0; j < k; j++) {
+            if (out[j] < 0 || v > logits[out[j]]) {
+                for (int l = k - 1; l > j; l--) out[l] = out[l - 1];
+                out[j] = id;
+                break;
+            }
+        }
+    }
+}
+
+static bool test_topk_contains(const int *top, int k, int id) {
+    for (int i = 0; i < k; i++) {
+        if (top[i] == id) return true;
+    }
+    return false;
+}
+
+static int test_topk_rank(const int *top, int k, int id) {
+    for (int i = 0; i < k; i++) {
+        if (top[i] == id) return i;
+    }
+    return -1;
+}
+
+static void test_note_delta(int *ids, float *ref_vals, float *cand_vals,
+                            float *abs_vals, int id, float ref, float cand) {
+    const float abs_delta = fabsf(cand - ref);
+    for (int i = 0; i < TEST_MPP_EQ_DELTAS; i++) {
+        if (ids[i] < 0 || abs_delta > abs_vals[i]) {
+            for (int j = TEST_MPP_EQ_DELTAS - 1; j > i; j--) {
+                ids[j] = ids[j - 1];
+                ref_vals[j] = ref_vals[j - 1];
+                cand_vals[j] = cand_vals[j - 1];
+                abs_vals[j] = abs_vals[j - 1];
+            }
+            ids[i] = id;
+            ref_vals[i] = ref;
+            cand_vals[i] = cand;
+            abs_vals[i] = abs_delta;
+            return;
+        }
+    }
+}
+
+static float test_top_union_max_abs(const float *ref, const float *cand,
+                                    const int *ref_top, const int *cand_top, int k) {
+    float max_abs = 0.0f;
+    for (int i = 0; i < k; i++) {
+        if (ref_top[i] >= 0) {
+            const float d = fabsf(cand[ref_top[i]] - ref[ref_top[i]]);
+            if (d > max_abs) max_abs = d;
+        }
+        if (cand_top[i] >= 0 && !test_topk_contains(ref_top, k, cand_top[i])) {
+            const float d = fabsf(cand[cand_top[i]] - ref[cand_top[i]]);
+            if (d > max_abs) max_abs = d;
+        }
+    }
+    return max_abs;
+}
+
+static test_mpp_eq_result test_compare_mpp_logits(const test_mpp_eq_case *tc,
+                                                  const float *cand_logits,
+                                                  bool assert_thresholds) {
+    int ref_top[TEST_MPP_EQ_TOPK];
+    int cand_top[TEST_MPP_EQ_TOPK];
+    test_logits_topk(tc->ref_logits, tc->vocab_size, ref_top, TEST_MPP_EQ_TOPK);
+    test_logits_topk(cand_logits, tc->vocab_size, cand_top, TEST_MPP_EQ_TOPK);
+
+    int overlap = 0;
+    int top5_overlap = 0;
+    int max_rank_delta = 0;
+    for (int i = 0; i < TEST_MPP_EQ_TOPK; i++) {
+        const int cand_rank = test_topk_rank(cand_top, TEST_MPP_EQ_TOPK, ref_top[i]);
+        if (ref_top[i] >= 0 && cand_rank >= 0) {
+            overlap++;
+            const int rank_delta = abs(cand_rank - i);
+            if (rank_delta > max_rank_delta) max_rank_delta = rank_delta;
+        }
+        if (i < TEST_MPP_EQ_TOP5 &&
+            ref_top[i] >= 0 &&
+            test_topk_contains(cand_top, TEST_MPP_EQ_TOP5, ref_top[i])) {
+            top5_overlap++;
+        }
+    }
+
+    double sumsq = 0.0;
+    float max_abs = 0.0f;
+    int nonfinite = 0;
+    int delta_ids[TEST_MPP_EQ_DELTAS];
+    float delta_ref[TEST_MPP_EQ_DELTAS];
+    float delta_cand[TEST_MPP_EQ_DELTAS];
+    float delta_abs[TEST_MPP_EQ_DELTAS];
+    for (int i = 0; i < TEST_MPP_EQ_DELTAS; i++) {
+        delta_ids[i] = -1;
+        delta_ref[i] = 0.0f;
+        delta_cand[i] = 0.0f;
+        delta_abs[i] = 0.0f;
+    }
+
+    for (int i = 0; i < tc->vocab_size; i++) {
+        if (!isfinite(tc->ref_logits[i]) || !isfinite(cand_logits[i])) {
+            nonfinite++;
+            continue;
+        }
+        const float delta = cand_logits[i] - tc->ref_logits[i];
+        const float abs_delta = fabsf(delta);
+        if (abs_delta > max_abs) max_abs = abs_delta;
+        sumsq += (double)delta * (double)delta;
+        test_note_delta(delta_ids, delta_ref, delta_cand, delta_abs,
+                        (int)i, tc->ref_logits[i], cand_logits[i]);
+    }
+
+    const float rms = (float)sqrt(sumsq / (double)tc->vocab_size);
+    const float top_abs = test_top_union_max_abs(tc->ref_logits, cand_logits,
+                                                 ref_top, cand_top, TEST_MPP_EQ_TOPK);
+    const bool same_top1 = ref_top[0] >= 0 && ref_top[0] == cand_top[0];
+    test_mpp_eq_result result = {
+        .ref_top1 = ref_top[0],
+        .cand_top1 = cand_top[0],
+        .overlap = overlap,
+        .top5_overlap = top5_overlap,
+        .max_rank_delta = max_rank_delta,
+        .nonfinite = nonfinite,
+        .rms = rms,
+        .max_abs = max_abs,
+        .top20_max_abs = top_abs,
+        .same_top1 = same_top1,
+        .pass = nonfinite == 0 && same_top1,
+    };
+
+    fprintf(stderr,
+            "ds4-test: MPP equivalence %s top1 ref=%d cand=%d top5_overlap=%d/%d overlap=%d/%d max_rank_delta=%d rms=%g max_abs=%g top20_max_abs=%g\n",
+            tc->id, ref_top[0], cand_top[0],
+            top5_overlap, TEST_MPP_EQ_TOP5,
+            overlap, TEST_MPP_EQ_TOPK,
+            max_rank_delta, rms, max_abs, top_abs);
+    fprintf(stderr, "ds4-test: MPP equivalence %s largest deltas:", tc->id);
+    for (int i = 0; i < TEST_MPP_EQ_DELTAS && delta_ids[i] >= 0; i++) {
+        fprintf(stderr, " id=%d ref=%g cand=%g abs=%g",
+                delta_ids[i], delta_ref[i], delta_cand[i], delta_abs[i]);
+    }
+    fputc('\n', stderr);
+
+    if (assert_thresholds) {
+        TEST_ASSERT(nonfinite == 0);
+        TEST_ASSERT(same_top1);
+    }
+    return result;
+}
+
+static bool test_mpp_capture(ds4_engine *engine, const test_mpp_eq_case *tc,
+                             float *logits, int *gen, int *gen_len) {
+    ds4_session *session = NULL;
+    TEST_ASSERT(ds4_session_create(&session, engine, tc->ctx) == 0);
+    if (!session) return false;
+
+    char err[160];
+    bool ok = ds4_session_sync(session, &tc->prompt, err, sizeof(err)) == 0;
+    TEST_ASSERT(ok);
+    if (ok) {
+        ok = ds4_session_copy_logits(session, logits, tc->vocab_size) == tc->vocab_size;
+        TEST_ASSERT(ok);
+    }
+
+    int n = 0;
+    while (ok && n < tc->gen_steps) {
+        const int token = ds4_session_argmax(session);
+        gen[n++] = token;
+        if (n < tc->gen_steps && ds4_session_eval(session, token, err, sizeof(err)) != 0) {
+            ok = false;
+            TEST_ASSERT(false);
+        }
+    }
+    *gen_len = n;
+
+    ds4_session_free(session);
+    return ok;
+}
+
+static bool test_mpp_eq_case_selected(const char *id) {
+    const char *filter = getenv("DS4_TEST_MPP_EQ_CASE");
+    if (!filter || !filter[0]) return true;
+
+    char buf[256];
+    snprintf(buf, sizeof(buf), "%s", filter);
+    for (char *tok = strtok(buf, ","); tok; tok = strtok(NULL, ",")) {
+        tok = test_trim_line(tok);
+        if (tok[0] && strstr(id, tok)) return true;
+    }
+    return false;
+}
+
+static int test_load_mpp_cases(ds4_engine *engine, test_mpp_eq_case *cases, int cap) {
+    const char *path = getenv("DS4_TEST_VECTOR_FILE");
+    if (!path || !path[0]) path = "tests/test-vectors/official.vec";
+    FILE *fp = fopen(path, "rb");
+    TEST_ASSERT(fp != NULL);
+    if (!fp) return 0;
+
+    int ncase = 0;
+    test_vec_case vc;
+    while (ncase < cap && test_read_vector_case(fp, &vc)) {
+        if (!test_fill_vector_case(fp, &vc)) break;
+        if (!test_mpp_eq_case_selected(vc.id)) continue;
+        char *prompt_text = test_read_file(vc.prompt_path);
+        TEST_ASSERT(prompt_text != NULL);
+        if (!prompt_text) continue;
+
+        test_mpp_eq_case *tc = &cases[ncase++];
+        snprintf(tc->id, sizeof(tc->id), "%s", vc.id);
+        tc->ctx = vc.ctx;
+        tc->vocab_size = ds4_engine_vocab_size(engine);
+        tc->gen_steps = vc.nsteps < TEST_VEC_MAX_STEPS ?
vc.nsteps : TEST_VEC_MAX_STEPS; + ds4_encode_chat_prompt(engine, "", prompt_text, DS4_THINK_NONE, &tc->prompt); + free(prompt_text); + TEST_ASSERT(tc->prompt.len > 0); + } + fclose(fp); + return ncase; +} + +static ds4_engine *test_open_mpp_engine(ds4_mpp_mode mode) { + ds4_engine *engine = NULL; + ds4_engine_options opt = { + .model_path = test_model_path(), + .backend = DS4_BACKEND_METAL, + .mpp_mode = mode, + }; + TEST_ASSERT(ds4_engine_open(&engine, &opt) == 0); + return engine; +} + +static void test_mpp_summary_init(test_mpp_eq_summary *summary, const char *label) { + memset(summary, 0, sizeof(*summary)); + summary->label = label; + summary->min_overlap = TEST_MPP_EQ_TOPK; + summary->min_top5_overlap = TEST_MPP_EQ_TOP5; +} + +static void test_mpp_summary_note_logits(test_mpp_eq_summary *summary, + const test_mpp_eq_result *result) { + if (!result->pass) summary->logits_failures++; + if (!result->same_top1) summary->top1_mismatches++; + if (result->overlap < summary->min_overlap) summary->min_overlap = result->overlap; + if (result->top5_overlap < summary->min_top5_overlap) { + summary->min_top5_overlap = result->top5_overlap; + } + if (result->max_rank_delta > summary->worst_rank_delta) { + summary->worst_rank_delta = result->max_rank_delta; + } + if (result->rms > summary->worst_rms) summary->worst_rms = result->rms; + if (result->max_abs > summary->worst_max_abs) summary->worst_max_abs = result->max_abs; + if (result->top20_max_abs > summary->worst_top20_max_abs) { + summary->worst_top20_max_abs = result->top20_max_abs; + } +} + +static void test_mpp_summary_print(const test_mpp_eq_summary *summary) { + fprintf(stderr, + "ds4-test: MPP summary route=%s cases=%d capture_fail=%d logits_fail=%d greedy_fail=%d top1_mismatch=%d min_top5_overlap=%d/%d min_overlap=%d/%d worst_rank_delta=%d worst_rms=%g worst_max_abs=%g worst_top20_max_abs=%g\n", + summary->label, + summary->cases, + summary->capture_failures, + summary->logits_failures, + summary->greedy_failures, 
+ summary->top1_mismatches, + summary->min_top5_overlap, + TEST_MPP_EQ_TOP5, + summary->min_overlap, + TEST_MPP_EQ_TOPK, + summary->worst_rank_delta, + summary->worst_rms, + summary->worst_max_abs, + summary->worst_top20_max_abs); +} + +static void test_run_mpp_candidate(const char *label, + ds4_mpp_mode mode, + test_mpp_eq_case *cases, + int ncase) { + fprintf(stderr, "ds4-test: MPP equivalence candidate route=%s mode=%s\n", + label, ds4_mpp_mode_name(mode)); + test_mpp_eq_summary summary; + test_mpp_summary_init(&summary, label); + ds4_engine *cand_engine = test_open_mpp_engine(mode); + if (cand_engine) { + const int vocab_size = ncase > 0 ? cases[0].vocab_size : 0; + float *cand_logits = malloc((size_t)vocab_size * sizeof(cand_logits[0])); + TEST_ASSERT(cand_logits != NULL); + if (cand_logits) { + for (int i = 0; i < ncase; i++) { + test_mpp_eq_case *tc = &cases[i]; + if (!tc->ref_logits) continue; + int cand_gen[TEST_VEC_MAX_STEPS] = {0}; + int cand_gen_len = 0; + if (!test_mpp_capture(cand_engine, tc, cand_logits, cand_gen, &cand_gen_len)) { + summary.capture_failures++; + continue; + } + summary.cases++; + test_mpp_eq_result result = test_compare_mpp_logits(tc, cand_logits, true); + test_mpp_summary_note_logits(&summary, &result); + TEST_ASSERT(cand_gen_len == tc->ref_gen_len); + if (cand_gen_len != tc->ref_gen_len) summary.greedy_failures++; + for (int j = 0; j < tc->ref_gen_len && j < cand_gen_len; j++) { + if (cand_gen[j] != tc->ref_gen[j]) { + fprintf(stderr, + "ds4-test: MPP equivalence %s greedy token mismatch step=%d ref=%d cand=%d\n", + tc->id, j, tc->ref_gen[j], cand_gen[j]); + summary.greedy_failures++; + } + TEST_ASSERT(cand_gen[j] == tc->ref_gen[j]); + } + } + free(cand_logits); + } + ds4_engine_close(cand_engine); + } + test_mpp_summary_print(&summary); +} + +static const char *const test_mpp_route_envs[] = { + "DS4_METAL_MPP_ENABLE", + "DS4_METAL_MPP_DISABLE", + "DS4_METAL_MPP_FAST", + "DS4_METAL_MPP_DIRECT_RHS", + "DS4_METAL_MPP_Q8_0_ENABLE", + 
"DS4_METAL_MPP_Q8_0_DISABLE", + "DS4_METAL_MPP_Q8_0_DIRECT_RHS", + "DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE", + "DS4_METAL_MPP_Q8_0_FILTER", + "DS4_METAL_MPP_Q8_0_TILE_N", + "DS4_METAL_MPP_F16_ENABLE", + "DS4_METAL_MPP_F16_DISABLE", + "DS4_METAL_MPP_F16_DIRECT_RHS", + "DS4_METAL_MPP_F16_WIDE", + "DS4_METAL_MPP_F16_PAIR", + "DS4_METAL_MPP_ATTN_OUT_ENABLE", + "DS4_METAL_MPP_ATTN_OUT_DISABLE", + "DS4_METAL_MPP_ATTN_OUT_DIRECT_RHS", + "DS4_METAL_MPP_ATTN_OUT_FILTER", + "DS4_METAL_MPP_ATTN_OUT_TILE_N", + "DS4_METAL_MPP_MOE_ENABLE", + "DS4_METAL_MPP_MOE_DISABLE", + "DS4_METAL_MPP_MOE_FILTER", + "DS4_METAL_MPP_MOE_TILE_N", + "DS4_METAL_MPP_MOE_FAST_LAYOUT", + "DS4_METAL_MPP_MOE_PAIR_GATE_UP", + "DS4_METAL_MPP_MOE_START_LAYER", + "DS4_METAL_MPP_MOE_GATE_ENABLE", + "DS4_METAL_MPP_MOE_GATE_DISABLE", + "DS4_METAL_MPP_MOE_GATE_FILTER", + "DS4_METAL_MPP_MOE_GATE_START_LAYER", + "DS4_METAL_MPP_MOE_UP_ENABLE", + "DS4_METAL_MPP_MOE_UP_DISABLE", + "DS4_METAL_MPP_MOE_UP_FILTER", + "DS4_METAL_MPP_MOE_UP_START_LAYER", + "DS4_METAL_MPP_MOE_DOWN_ENABLE", + "DS4_METAL_MPP_MOE_DOWN_DISABLE", + "DS4_METAL_MPP_MOE_DOWN_FILTER", + "DS4_METAL_MPP_MOE_DOWN_START_LAYER", +}; + +typedef struct { + const char *name; + char *value; + bool had_value; +} test_mpp_saved_env; + +static void test_mpp_save_envs(test_mpp_saved_env *saved, int n) { + for (int i = 0; i < n; i++) { + saved[i].name = test_mpp_route_envs[i]; + const char *v = getenv(saved[i].name); + saved[i].had_value = v != NULL; + saved[i].value = v ? strdup(v) : NULL; + } +} + +static void test_mpp_restore_envs(test_mpp_saved_env *saved, int n) { + for (int i = 0; i < n; i++) { + if (saved[i].had_value) { + setenv(saved[i].name, saved[i].value ? 
saved[i].value : "", 1); + } else { + unsetenv(saved[i].name); + } + free(saved[i].value); + saved[i].value = NULL; + } +} + +static void test_mpp_clear_route_envs(void) { + for (size_t i = 0; i < sizeof(test_mpp_route_envs) / sizeof(test_mpp_route_envs[0]); i++) { + unsetenv(test_mpp_route_envs[i]); + } +} + +typedef struct { + const char *label; + ds4_mpp_mode mode; + const char *set_envs[8]; +} test_mpp_matrix_config; + +static void test_mpp_apply_matrix_config(const test_mpp_matrix_config *cfg) { + test_mpp_clear_route_envs(); + for (int i = 0; cfg->set_envs[i]; i++) { + setenv(cfg->set_envs[i], "1", 1); + } +} + +static void test_run_mpp_matrix(test_mpp_eq_case *cases, int ncase) { + const test_mpp_matrix_config configs[] = { + { "auto", DS4_MPP_AUTO, { NULL } }, + { "fast_profile", DS4_MPP_AUTO, { + "DS4_METAL_MPP_FAST", + NULL + } }, + { "q8_only", DS4_MPP_ON, { + "DS4_METAL_MPP_F16_DISABLE", + "DS4_METAL_MPP_ATTN_OUT_DISABLE", + "DS4_METAL_MPP_MOE_DISABLE", + NULL + } }, + { "attn_out_only", DS4_MPP_ON, { + "DS4_METAL_MPP_Q8_0_DISABLE", + "DS4_METAL_MPP_F16_DISABLE", + "DS4_METAL_MPP_MOE_DISABLE", + NULL + } }, + { "moe_gate_only", DS4_MPP_ON, { + "DS4_METAL_MPP_Q8_0_DISABLE", + "DS4_METAL_MPP_F16_DISABLE", + "DS4_METAL_MPP_ATTN_OUT_DISABLE", + "DS4_METAL_MPP_MOE_UP_DISABLE", + "DS4_METAL_MPP_MOE_DOWN_DISABLE", + NULL + } }, + { "moe_up_only", DS4_MPP_ON, { + "DS4_METAL_MPP_Q8_0_DISABLE", + "DS4_METAL_MPP_F16_DISABLE", + "DS4_METAL_MPP_ATTN_OUT_DISABLE", + "DS4_METAL_MPP_MOE_GATE_DISABLE", + "DS4_METAL_MPP_MOE_DOWN_DISABLE", + NULL + } }, + { "moe_down_only", DS4_MPP_ON, { + "DS4_METAL_MPP_Q8_0_DISABLE", + "DS4_METAL_MPP_F16_DISABLE", + "DS4_METAL_MPP_ATTN_OUT_DISABLE", + "DS4_METAL_MPP_MOE_GATE_DISABLE", + "DS4_METAL_MPP_MOE_UP_DISABLE", + NULL + } }, + { "full_forced", DS4_MPP_ON, { NULL } }, + }; + + test_mpp_saved_env saved[sizeof(test_mpp_route_envs) / sizeof(test_mpp_route_envs[0])]; + test_mpp_save_envs(saved, (int)(sizeof(saved) / 
sizeof(saved[0]))); + for (size_t i = 0; i < sizeof(configs) / sizeof(configs[0]); i++) { + test_mpp_apply_matrix_config(&configs[i]); + test_run_mpp_candidate(configs[i].label, configs[i].mode, cases, ncase); + } + test_mpp_restore_envs(saved, (int)(sizeof(saved) / sizeof(saved[0]))); +} + +static void test_metal_mpp_equivalence(void) { + test_close_engines(); + + test_mpp_eq_case cases[TEST_MPP_EQ_MAX_CASES]; + memset(cases, 0, sizeof(cases)); + + ds4_engine *ref_engine = test_open_mpp_engine(DS4_MPP_OFF); + if (!ref_engine) return; + + const int ncase = test_load_mpp_cases(ref_engine, cases, TEST_MPP_EQ_MAX_CASES); + TEST_ASSERT(ncase > 0); + for (int i = 0; i < ncase; i++) { + test_mpp_eq_case *tc = &cases[i]; + tc->ref_logits = malloc((size_t)tc->vocab_size * sizeof(tc->ref_logits[0])); + TEST_ASSERT(tc->ref_logits != NULL); + if (!tc->ref_logits) continue; + TEST_ASSERT(test_mpp_capture(ref_engine, tc, + tc->ref_logits, + tc->ref_gen, + &tc->ref_gen_len)); + } + ds4_engine_close(ref_engine); + + if (getenv("DS4_TEST_MPP_EQ_MATRIX") != NULL) { + test_run_mpp_matrix(cases, ncase); + } else { + const bool force_on = getenv("DS4_TEST_MPP_EQ_FORCE_ON") != NULL; + test_run_mpp_candidate(force_on ? "forced" : "auto", + force_on ? 
DS4_MPP_ON : DS4_MPP_AUTO, + cases, + ncase); + } + + for (int i = 0; i < ncase; i++) test_mpp_eq_case_free(&cases[i]); +} + static const char *test_tool_call_request_json(void) { return "{" @@ -774,6 +1343,7 @@ static const ds4_test_entry test_entries[] = { {"--tool-call-quality", "tool-call-quality", "model emits valid DSML tool calls", test_tool_call_quality}, {"--logprob-vectors", "logprob-vectors", "official API top-logprob vector comparison", test_official_logprob_vectors}, {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_kernel_group}, + {"--metal-mpp-equivalence", "metal-mpp-equivalence", "Metal MPP off/on prompt-logit and greedy equivalence", test_metal_mpp_equivalence}, #endif {"--server", "server", "server parser/rendering/cache unit tests", test_server_unit_group}, }; @@ -794,6 +1364,9 @@ static void test_print_help(const char *prog) { puts(" DS4_TEST_MODEL=FILE Model path. Default: ds4flash.gguf"); puts(" DS4_TEST_LONG_PROMPT=FILE Rendered long-context story fact prompt."); puts(" DS4_TEST_VECTOR_FILE=FILE Simple official-vector fixture."); + puts(" DS4_TEST_MPP_EQ_CASE=NAME Run only MPP equivalence cases whose id contains NAME."); + puts(" DS4_TEST_MPP_EQ_FORCE_ON=1 Compare --mpp off against forced --mpp on instead of auto."); + puts(" DS4_TEST_MPP_EQ_MATRIX=1 Run auto and isolated forced MPP route rows."); } static const ds4_test_entry *test_find_entry(const char *arg) { From e823fe2a26faf11f5a047044187c792e00ae9cae Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Mon, 11 May 2026 18:25:09 +0200 Subject: [PATCH 03/16] Tune Metal MPP defaults and thinking checkpoints --- README.md | 71 +++++++++++++++++++++++++---------------------------- ds4_metal.m | 24 ++++++++++-------- 2 files changed, 47 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index 3667471d..dbe63e9e 100644 --- a/README.md +++ b/README.md @@ -231,38 +231,37 @@ remain opt-in diagnostics. 
The environment controls
`DS4_METAL_MPP_ENABLE` and `DS4_METAL_MPP_DISABLE` accept `1/true/yes/on` and
`0/false/no/off`; `DS4_METAL_MPP_ENABLE=0` disables MPP instead of enabling it
by mere presence. Passing `--quality` also disables MPP routes so strict/debug
runs stay on the legacy Metal kernels. Set `DS4_METAL_MPP_FAST=1` to opt into
the current same-top1/same-greedy fast profile: it widens Q8_0 and
-attention-output MPP to all layers, enables Q8_0 partial token tiles, and uses
-earlier routed-MoE MPP windows. This profile is not the default because its
-whole-vocab and top-k drift are much larger than the correctness-first auto
-profile.
-Set `DS4_METAL_MPP_DIRECT_RHS=1` only for diagnostics of the first-PR MPP
-direct-RHS tensor layout; it is not part of the correctness-first default. Q8_0
-and attention-output direct-RHS diagnostics support both 32-token and 64-token
-MPP tiles, so they can be combined with `DS4_METAL_MPP_Q8_0_TILE_N=64` and
-`DS4_METAL_MPP_ATTN_OUT_TILE_N=64` for M5 throughput experiments. The
+attention-output MPP to all layers and uses earlier routed-MoE MPP windows.
+This profile is not the default because its whole-vocab and top-k drift are
+much larger than the correctness-first auto profile.
+The default safe-window policy uses the direct-RHS tensor layout for MPP routes;
+set `DS4_METAL_MPP_DIRECT_RHS=0` to compare against the older staged-RHS
+layout. Q8_0 and attention-output direct-RHS routes support both 32-token and
+64-token MPP tiles. Auto defaults those two routes to 64-token tiles for M5
+throughput; set `DS4_METAL_MPP_Q8_0_TILE_N=32` or
+`DS4_METAL_MPP_ATTN_OUT_TILE_N=32` to compare the narrower layout. The
 route-specific `DS4_METAL_MPP_Q8_0_DIRECT_RHS=1`,
 `DS4_METAL_MPP_F16_DIRECT_RHS=1`, and
-`DS4_METAL_MPP_ATTN_OUT_DIRECT_RHS=1` switches isolate that diagnostic layout
-without turning on every direct-RHS route at once.
+`DS4_METAL_MPP_ATTN_OUT_DIRECT_RHS=1` switches isolate that layout without
+turning on every direct-RHS route at once when the global
+`DS4_METAL_MPP_DIRECT_RHS=0` override is set.
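The documented boolean-control semantics (accepted spellings, and `=0` disabling a route rather than enabling it by mere presence) can be sketched as a tri-state helper. This is an illustrative sketch only: the helper name and the `-1/0/1` return convention are our assumptions, chosen to mirror the tri-state shape that `ds4_gpu_env_bool` appears to follow in `ds4_metal.m`, not the actual ds4 implementation.

```c
#include <stdlib.h>
#include <strings.h>

/* Hypothetical sketch of the tri-state env-bool semantics described above:
 * -1 = variable unset or unrecognized (caller applies its route default),
 *  0 = explicitly disabled, 1 = explicitly enabled.
 * Not the actual ds4 helper; spellings match the README's documented set. */
static int env_bool_tristate(const char *name) {
    const char *v = getenv(name);
    if (!v || !v[0]) return -1;
    if (!strcasecmp(v, "1") || !strcasecmp(v, "true") ||
        !strcasecmp(v, "yes") || !strcasecmp(v, "on")) return 1;
    if (!strcasecmp(v, "0") || !strcasecmp(v, "false") ||
        !strcasecmp(v, "no") || !strcasecmp(v, "off")) return 0;
    return -1; /* unrecognized spelling: fall back to the route default */
}
```

With this shape, `DS4_METAL_MPP_ENABLE=0` yields an explicit `0` rather than a truthy "variable present" signal, which is the behavior the paragraph above calls out.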
The Q8_0 prefill MPP route can be isolated with `DS4_METAL_MPP_Q8_0_ENABLE=1` or `DS4_METAL_MPP_Q8_0_DISABLE=1`. It only affects prompt batches larger than eight tokens and is limited by default to the late full-model-safe layer window 38..42, plus the `attn_q_b` projection in -layers 32..37. It uses only full 32-token tiles by default and falls back to the -legacy kernel for partial token tiles or when the Metal 4 tensor path is -unavailable. Set -`DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=1` to reproduce or localize partial-tile -drift while debugging. Set `DS4_METAL_MPP_Q8_0_FILTER=all` to reproduce the +layers 32..37. It uses 64-token tiles by default, accepts partial token tails, +and falls back to the legacy kernel when the Metal 4 tensor path is unavailable. +Set `DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=0` to force the old partial-tail +fallback while debugging. Set `DS4_METAL_MPP_Q8_0_FILTER=all` to reproduce the unsafe all-layer Q8 route, `DS4_METAL_MPP_Q8_0_FILTER=late_safe` to request the default safe window explicitly, or `DS4_METAL_MPP_Q8_0_FILTER=` to force named full-graph Q8 modules such as `attn_q_a`, `attn_kv`, `attn_q_b`, `attn_out`, `shared_gate`, `shared_up`, or `shared_down`. Use `@layer=A..B` to test one module family only in a layer window, for -example `shared_up@layer=30..37`. Set -`DS4_METAL_MPP_Q8_0_TILE_N=64` to test the experimental wider MPP token tile -for performance against the default `32`. The isolated +example `shared_up@layer=30..37`. Set `DS4_METAL_MPP_Q8_0_TILE_N=32` to +compare against the narrower MPP token tile. The isolated `./ds4_test --metal-kernels` regression reports small/medium/model-ish kernel deltas; the full-model `./ds4_test --metal-mpp-equivalence` diagnostic compares default auto against @@ -296,24 +295,19 @@ layers can amplify small local differences through normalization/attention enough to fail prompt-logit equivalence. 
The `attn_q_b` 32..37 extension is kept because it is query-side only for full prompt tiles in the current validation path, passes prompt-logit equivalence, and improves prefill -throughput. The F16 compressor route did not introduce measurable drift in the -current prompt set. +throughput. The current auto policy also uses Q8_0 partial tails, direct-RHS MPP +inputs, and 64-token tiles for Q8_0 and attention-output low projections; on +M5 Max the long-code audit prompt sampled around `395 t/s` in a run where MPP +off sampled around `354 t/s`, with visible desktop-load variance. The F16 +compressor route did not introduce measurable drift in the current prompt set. The `DS4_METAL_MPP_FAST=1` profile is the measured high-throughput diagnostic profile under the relaxed same-top1/same-greedy gate. In the current prompt suite it keeps top-1 and greedy continuations stable, but reports much larger distribution drift than auto (`worst_rms ~= 0.761`, -`worst_top20_max_abs ~= 2.28`, minimum top-20 overlap `18/20`). On the -long-code prefill benchmark it sampled around `360 t/s` in the same window -where auto sampled around `318 t/s`; benchmark variance is high when the -desktop is active. The more aggressive direct-RHS 64-token diagnostic -(`DS4_METAL_MPP_FAST=1 DS4_METAL_MPP_DIRECT_RHS=1 -DS4_METAL_MPP_Q8_0_TILE_N=64 DS4_METAL_MPP_ATTN_OUT_TILE_N=64`) passed the -relaxed top-1/greedy gate and `--logprob-vectors`, and in Automatic power mode -sampled around `324 t/s` versus `289 t/s` for auto in the same short benchmark -window. It remains diagnostic-only because its full-suite drift is higher -(`worst_rms ~= 0.846`, `worst_top20_max_abs ~= 2.07`, minimum top-20 overlap -`16/20`). +`worst_top20_max_abs ~= 2.28`, minimum top-20 overlap `18/20`). It remains +diagnostic-only because it widens the route windows that produce the largest +full-suite drift. 
The routed-MoE MPP projections are staged when forced and are limited to a late full-model-safe layer window by default: gate/down start at layer 28, and @@ -347,17 +341,18 @@ outputs are summed with a single Metal kernel instead of five chained add passes. Set `DS4_METAL_MOE_SUM6_DISABLE=1` to compare or temporarily disable that fused sum route. -The attention-output low-projection MPP route applies to full 32-token tiles -in the default safe window, falling back to the existing indexed simdgroup -kernel for partial tiles. Attention-output MPP is limited to the measured -full-model-safe layer window 32..42 by default. Set +The attention-output low-projection MPP route applies to full 32-token multiples +in the default safe window, using a 64-token MPP tile by default and falling +back to the existing indexed simdgroup kernel for shorter or non-32-multiple +tails. Attention-output MPP is limited to the measured full-model-safe layer +window 32..42 by default. Set `DS4_METAL_MPP_ATTN_OUT_ENABLE=1` or `DS4_METAL_MPP_ATTN_OUT_DISABLE=1` to isolate this route. Set `DS4_METAL_MPP_ATTN_OUT_FILTER=all`, `late_safe`, `none`, or a comma-separated list of full-graph context substrings such as `layer=42` to localize full-model-safe layer windows. Layer filters are exact, and `layer=A..B` matches an inclusive range. Set -`DS4_METAL_MPP_ATTN_OUT_TILE_N=64` to test the experimental wider MPP token -tile for performance against the default `32`. The all-layer +`DS4_METAL_MPP_ATTN_OUT_TILE_N=32` to compare against the narrower MPP token +tile. The all-layer attention-output MPP route still fails long-prompt full-model equivalence despite per-layer low-projection differences below the current kernel target. 
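The attention-output tiling rule above (MPP covers full 32-token multiples, the remainder falls back to the indexed simdgroup kernel) amounts to a simple batch split. The names below are illustrative, not ds4's internal representation.

```c
/* Illustrative split of a prompt batch between the attention-output MPP
 * route (full 32-token multiples) and the legacy-kernel tail, per the
 * fallback policy described above. */
typedef struct { int mpp_tokens; int legacy_tail; } batch_split;

static batch_split split_attn_out_batch(int batch_tokens) {
    batch_split s;
    s.mpp_tokens = batch_tokens - (batch_tokens % 32); /* MPP-covered prefix */
    s.legacy_tail = batch_tokens % 32;                 /* indexed-kernel tail */
    return s;
}
```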
The ratio-2 F16 compressor route can similarly be controlled with diff --git a/ds4_metal.m b/ds4_metal.m index 741dc515..ec863e0b 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -1080,33 +1080,35 @@ static int ds4_gpu_use_mpp_q8_0_matmul(void) { static int ds4_gpu_mpp_q8_0_partial_tiles_enabled(void) { if (ds4_gpu_mpp_fast_profile()) return 1; - return ds4_gpu_env_bool("DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE") > 0; + const int enabled = ds4_gpu_env_bool("DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE"); + if (enabled >= 0) return enabled > 0; + return 1; } -static uint32_t ds4_gpu_mpp_tile_n_env(const char *name) { +static uint32_t ds4_gpu_mpp_tile_n_env(const char *name, uint32_t fallback) { const char *env = getenv(name); - if (!env || !env[0]) return 32; + if (!env || !env[0]) return fallback; char *end = NULL; long v = strtol(env, &end, 10); while (end && isspace((unsigned char)*end)) end++; if (end && *end == '\0' && v == 64) return 64; if (end && *end == '\0' && v == 32) return 32; fprintf(stderr, - "ds4: invalid %s=%s; expected 32 or 64, using 32\n", - name, env); - return 32; + "ds4: invalid %s=%s; expected 32 or 64, using %u\n", + name, env, fallback); + return fallback; } static uint32_t ds4_gpu_mpp_q8_0_tile_n(void) { - return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_Q8_0_TILE_N"); + return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_Q8_0_TILE_N", 64); } static uint32_t ds4_gpu_mpp_attn_out_tile_n(void) { - return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_ATTN_OUT_TILE_N"); + return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_ATTN_OUT_TILE_N", 64); } static uint32_t ds4_gpu_mpp_moe_tile_n(void) { - return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_MOE_TILE_N"); + return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_MOE_TILE_N", 32); } static int ds4_gpu_mpp_moe_fast_layout(void) { @@ -1118,7 +1120,9 @@ static int ds4_gpu_mpp_moe_pair_gate_up(void) { } static int ds4_gpu_mpp_direct_rhs(void) { - return ds4_gpu_env_bool("DS4_METAL_MPP_DIRECT_RHS") > 0; + const int enabled = 
ds4_gpu_env_bool("DS4_METAL_MPP_DIRECT_RHS"); + if (enabled >= 0) return enabled > 0; + return 1; } static int ds4_gpu_mpp_q8_0_direct_rhs(void) { From f5363ab14c4794487d2ac54a0af2d6aab39c1970 Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Tue, 12 May 2026 00:36:51 +0200 Subject: [PATCH 04/16] Improve Metal MPP prefill throughput Raise the default Metal prefill chunk to 4096 and reuse the range-capable layer-major prefill graph for chunked ranges. Enable the guarded Q8_0 attn_q_b MPP route for <=2048-token prompt batches, dynamic Q8_0 tile width, the routed-MoE fast layout from layer 0, and the RB16 indexed decode path. M5 Max post-patch ds4-bench profile with 64 generated tokens: prompt 443/459/522/486/465 t/s and generation 38.6/38.2/37.6/34.0/33.6 t/s at 0.5k/1k/2k/4k/8k. Tests: make all ds4_test; make test; git diff --check. --- README.md | 118 ++++++++++------ ds4.c | 303 ++++++++++++++++++++---------------------- ds4_metal.m | 66 ++++++--- metal/dsv4_misc.metal | 133 +++++++++++++++++- metal/moe.metal | 5 +- 5 files changed, 402 insertions(+), 223 deletions(-) diff --git a/README.md b/README.md index dbe63e9e..c769abcd 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,15 @@ exponential sweeps. Output is CSV with one row per frontier: latest prefill interval tokens/sec, generation tokens/sec at that frontier, and `kvcache_bytes`. +Sessions prefill long prompts in 4096-token chunks by default. Set +`DS4_METAL_PREFILL_CHUNK=N` to compare another chunk size, for example `2048` +to reduce transient memory, or `DS4_METAL_PREFILL_CHUNK=0` to prefill a prompt +as one whole batch when memory allows. Changing the chunk changes the KV +checkpoint shape, so compare it as an explicit run configuration. +Chunked Metal prefill reuses the same range-capable layer-major graph for each +chunk, preserving absolute compressor/indexer boundaries while avoiding the old +per-layer chunk dispatch path. 
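The chunked-prefill loop described above (fixed-size chunks, default 4096, with `DS4_METAL_PREFILL_CHUNK=0` meaning one whole batch, and absolute token positions preserved across chunks) can be sketched as follows. This is a sketch of the loop structure only; the function name and callback signature are our assumptions, not the ds4 dispatch code.

```c
/* Sketch of chunked prefill: walk a prompt in fixed-size chunks, keeping
 * token positions absolute so compressor/indexer boundaries line up with
 * whole-batch prefill. chunk <= 0 means prefill the prompt as one batch.
 * run_range (optional) would dispatch one range-capable graph per chunk. */
static int prefill_chunks(int prompt_len, int chunk,
                          void (*run_range)(int start, int count)) {
    int nchunks = 0;
    if (chunk <= 0) chunk = prompt_len; /* whole-batch prefill */
    for (int start = 0; start < prompt_len; start += chunk) {
        const int count =
            (prompt_len - start) < chunk ? (prompt_len - start) : chunk;
        if (run_range) run_range(start, count); /* absolute [start, start+count) */
        nchunks++;
    }
    return nchunks;
}
```

Under this shape, a 10000-token prompt with the default 4096 chunk runs three graph dispatches; since the chunk size also changes the KV checkpoint shape, the README treats it as an explicit run configuration rather than a free knob.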
+ ## Metal 4 and M5 Neural Accelerators The current production path is still hand-written Metal compute kernels over @@ -220,26 +229,29 @@ tensor matmul probe before it lets the main Metal shader source see `DS4_METAL_HAS_TENSOR`, so unsupported SDK/device combinations fall back to the legacy kernels. -MPP policy is explicit and correctness-first. Use `--mpp auto` for the default +MPP policy is explicit and guarded. Use `--mpp auto` for the default route policy, `--mpp on` to force MPP routes where the Metal 4 tensor path is available, and `--mpp off` for the legacy Metal reference path. Auto currently -enables only the validated late-layer safe windows that pass full-model -equivalence and clear the benchmark gate; early-layer and all-layer MPP routes -remain opt-in diagnostics. The environment controls +keeps attention-output MPP in the validated late-layer window, extends the +Q8_0 `attn_q_b` projection for small prompt batches, and runs routed-MoE MPP +from layer 0 for prefill throughput while preserving same-top1/same-greedy +agreement. Unguarded Q8_0 and attention-output all-layer MPP routes remain +opt-in diagnostics. The environment controls `DS4_METAL_MPP_ENABLE` and `DS4_METAL_MPP_DISABLE` accept `1/true/yes/on` and `0/false/no/off`; `DS4_METAL_MPP_ENABLE=0` disables MPP instead of enabling it by mere presence. Passing `--quality` also disables MPP routes so strict/debug runs stay on the legacy Metal kernels. Set `DS4_METAL_MPP_FAST=1` to opt into the current same-top1/same-greedy fast profile: it widens Q8_0 and -attention-output MPP to all layers and uses earlier routed-MoE MPP windows. -This profile is not the default because its whole-vocab and top-k drift are -much larger than the correctness-first auto profile. +attention-output MPP to all layers while keeping the routed-MoE all-layer +default. This profile is not the default because its top-k overlap is weaker +than auto in the current full-model suite. 
The default safe-window policy uses the direct-RHS tensor layout for MPP routes; set `DS4_METAL_MPP_DIRECT_RHS=0` to compare against the older staged-RHS layout. Q8_0 and attention-output direct-RHS routes support both 32-token and -64-token MPP tiles. Auto defaults those two routes to 64-token tiles for M5 -throughput; set `DS4_METAL_MPP_Q8_0_TILE_N=32` or -`DS4_METAL_MPP_ATTN_OUT_TILE_N=32` to compare the narrower layout. The +64-token MPP tiles. Auto defaults attention-output to 64-token tiles, while +Q8_0 uses 64-token tiles below 4096-token batches and 32-token tiles for larger +prompt batches on M5. Set `DS4_METAL_MPP_Q8_0_TILE_N=32` or +`DS4_METAL_MPP_ATTN_OUT_TILE_N=32` to force the narrower layout. The route-specific `DS4_METAL_MPP_Q8_0_DIRECT_RHS=1`, `DS4_METAL_MPP_F16_DIRECT_RHS=1`, and `DS4_METAL_MPP_ATTN_OUT_DIRECT_RHS=1` switches isolate that layout without @@ -248,14 +260,16 @@ turning on every direct-RHS route at once when the global The Q8_0 prefill MPP route can be isolated with `DS4_METAL_MPP_Q8_0_ENABLE=1` or `DS4_METAL_MPP_Q8_0_DISABLE=1`. It only -affects prompt batches larger than eight tokens and is limited by default to -the late full-model-safe layer window 38..42, plus the `attn_q_b` projection in -layers 32..37. It uses 64-token tiles by default, accepts partial token tails, -and falls back to the legacy kernel when the Metal 4 tensor path is unavailable. +affects prompt batches larger than eight tokens. By default, batches up to 2048 +tokens use MPP for `attn_q_b` across layers, while larger batches use the +late full-model-safe layer window 38..42 plus `attn_q_b` in layers 32..37. It +uses 64-token tiles below 4096-token batches and 32-token tiles for larger +prompt batches on M5, accepts partial token tails, and falls back to the legacy +kernel when the Metal 4 tensor path is unavailable. Set `DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=0` to force the old partial-tail fallback while debugging. 
Set `DS4_METAL_MPP_Q8_0_FILTER=all` to reproduce the unsafe all-layer Q8 route, `DS4_METAL_MPP_Q8_0_FILTER=late_safe` to request the -default safe window explicitly, or +older conservative late window explicitly, or `DS4_METAL_MPP_Q8_0_FILTER=` to force named full-graph Q8 modules such as `attn_q_a`, `attn_kv`, `attn_q_b`, `attn_out`, `shared_gate`, `shared_up`, or `shared_down`. Use @@ -282,36 +296,44 @@ first comparison that exceeds the kernel target, including module/layer context, shape, max absolute error, RMS, and the largest element deltas. Set `DS4_METAL_MPP_COMPARE_VERBOSE=1` to print passing comparisons as well. -Current MPP route status is intentionally conservative: `auto` enables Q8_0 -prefill, F16 compressor, attention-output low projection, and routed-MoE MPP -only in the full-model-safe windows. Attention-output low projection now uses -layers 32..42 by default, while Q8_0 keeps one narrower `attn_q_b` extension -for layers 32..37. The Q8_0 and attention-output low MPP +Current MPP route status balances drift with prefill throughput: `auto` enables +Q8_0 prefill, F16 compressor, attention-output low projection, and routed-MoE +MPP. Attention-output low projection now uses layers 32..42 by default, while +Q8_0 uses `attn_q_b` across layers for <=2048-token prompt batches and keeps +the narrower `attn_q_b` 32..37 plus all-Q8 38..42 window for larger batches. +Routed-MoE MPP now covers gate/up/down from layer 0 by default to favor prefill +throughput on M5-class systems; it still preserves greedy agreement in the MPP +equivalence suite, but it carries larger logit drift than the previous +layer-20/22 conservative window. The current auto suite reports +same-top1/same-greedy agreement with minimum top-5 overlap `4/5`, minimum +top-20 overlap `17/20`, `worst_rms ~= 0.942`, and +`worst_top20_max_abs ~= 3.06`. 
The Q8_0 and attention-output low MPP kernels stage activation tiles through half to match the legacy Metal matmul input path, which brings the isolated model-ish Q8_0 regression under the strict kernel target and removes the first attention-output comparator breach. Most Q8_0 projection families stay restricted to layers 38..42 because earlier layers can amplify small local differences through normalization/attention -enough to fail prompt-logit equivalence. The `attn_q_b` 32..37 extension is -kept because it is query-side only for full prompt tiles in the current -validation path, passes prompt-logit equivalence, and improves prefill -throughput. The current auto policy also uses Q8_0 partial tails, direct-RHS MPP -inputs, and 64-token tiles for Q8_0 and attention-output low projections; on -M5 Max the long-code audit prompt sampled around `395 t/s` in a run where MPP -off sampled around `354 t/s`, with visible desktop-load variance. The F16 +enough to fail long-context generation. The guarded `attn_q_b` extension is +kept because it is query-side only, passes prompt-logit and long-context gates +when limited to <=2048-token batches, and improves prefill throughput. The +current auto policy also uses Q8_0 partial tails, direct-RHS MPP inputs, dynamic +Q8_0 tile width, and 64-token tiles for attention-output low projections. In a +local M5 Max `ds4-bench` sweep with 64 generated tokens, auto sampled about +`443/459/522/486/465` prompt tokens/sec and +`38.6/38.2/37.6/34.0/33.6` generation tokens/sec at the +`0.5k/1k/2k/4k/8k` frontiers, with visible desktop-load variance. The F16 compressor route did not introduce measurable drift in the current prompt set. The `DS4_METAL_MPP_FAST=1` profile is the measured high-throughput diagnostic profile under the relaxed same-top1/same-greedy gate. 
In the current prompt -suite it keeps top-1 and greedy continuations stable, but reports much larger -distribution drift than auto (`worst_rms ~= 0.761`, -`worst_top20_max_abs ~= 2.28`, minimum top-20 overlap `18/20`). It remains -diagnostic-only because it widens the route windows that produce the largest -full-suite drift. - -The routed-MoE MPP projections are staged when forced and are limited to a -late full-model-safe layer window by default: gate/down start at layer 28, and -up starts at layer 30. For route isolation, use +suite it keeps top-1 and greedy continuations stable, but reports weaker top-k +overlap than auto (`worst_rms ~= 0.951`, `worst_top20_max_abs ~= 4.03`, +minimum top-20 overlap `16/20`). It remains diagnostic-only because it widens +the Q8_0 and attention-output route windows that produce the largest full-suite +drift. + +The routed-MoE MPP projections are enabled from layer 0 by default for prefill +speed. For route isolation, use `DS4_METAL_MPP_MOE_GATE_ENABLE/DISABLE`, `DS4_METAL_MPP_MOE_UP_ENABLE/DISABLE`, and `DS4_METAL_MPP_MOE_DOWN_ENABLE/DISABLE`; `DS4_METAL_MPP_MOE_DISABLE=1` @@ -324,14 +346,15 @@ Use `layer=N` for an exact layer match or `layer=A..B` for an inclusive layer range when testing sparse MPP windows. The same `@layer=A..B` syntax can restrict a context substring to a layer window. Set `DS4_METAL_MPP_MOE_TILE_N=64` to test the experimental wider routed-MoE -MPP token tile for performance against the default `32`. Set -`DS4_METAL_MPP_MOE_FAST_LAYOUT=1` to test the old first-PR routed-MoE MPP -threadgroup tensor layout as an explicit performance diagnostic. Set +MPP token tile for performance against the default `32`. The routed-MoE MPP +path uses the faster first-PR threadgroup tensor layout by default inside the +active routed-MoE windows; set `DS4_METAL_MPP_MOE_FAST_LAYOUT=0` to compare +against the newer staged layout. 
Set `DS4_METAL_MPP_MOE_START_LAYER=N`, or the route-specific `DS4_METAL_MPP_MOE_GATE_START_LAYER`, `DS4_METAL_MPP_MOE_UP_START_LAYER`, and -`DS4_METAL_MPP_MOE_DOWN_START_LAYER`, to test earlier routed-MoE MPP start -layers before changing the conservative defaults. Set +`DS4_METAL_MPP_MOE_DOWN_START_LAYER`, to test routed-MoE MPP start layers; the +resolved start layer also defines the route's default `late_safe` filter. Set `DS4_METAL_MPP_MOE_PAIR_GATE_UP=1` only to profile the experimental fused gate/up MPP dispatch; it passes the current equivalence gate but is not a default path because it is slower than separate gate and up dispatches. @@ -341,6 +364,19 @@ outputs are summed with a single Metal kernel instead of five chained add passes. Set `DS4_METAL_MOE_SUM6_DISABLE=1` to compare or temporarily disable that fused sum route. +Long-context decode uses the indexed mixed-attention kernel once ratio-4 +compressed rows exceed the dense-attention window. The default decode +specialization stages sixteen selected rows per threadgroup block; set +`DS4_METAL_INDEXED_ATTN_RB4=1` to compare the older four-row staging variant. +Set `DS4_METAL_DECODE_INDEXER_TOP_K=64`, `128`, `256`, or `512` to cap the +decode indexer candidate count for speed/quality diagnostics. The normal +non-quality decode path keeps the legacy dense-attention window until there are +more than `1024` compressed rows, then selects `256` rows in sparse indexed +attention. Set `DS4_METAL_DECODE_INDEXER_SPARSE_THRESHOLD` to `64`, `128`, +`256`, `512`, `1024`, `2048`, or `4096` to tune the sparse-decode crossover +separately. `--quality` keeps the full `512` candidate path unless this +environment override is set explicitly. 
+ The attention-output low-projection MPP route applies to full 32-token multiples in the default safe window, using a 64-token MPP tile by default and falling back to the existing indexed simdgroup kernel for shorter or non-32-multiple diff --git a/ds4.c b/ds4.c index 64aec52b..0182acd2 100644 --- a/ds4.c +++ b/ds4.c @@ -6111,8 +6111,8 @@ static uint32_t ds4_default_prefill_cap_for_prompt(int prompt_len) { if (v <= 0) return cap; cap = (uint32_t)v; } - } else if (prompt_len > 2048) { - cap = 2048u; + } else if (prompt_len > 4096) { + cap = 4096u; } if (cap == 0) cap = 1; @@ -8911,9 +8911,81 @@ static bool metal_graph_capture_prefix1_index_state(ds4_gpu_graph *g, uint32_t i g->layer_index_state_score[il], 0, bytes) != 0; } +static bool metal_graph_decode_indexer_top_k_override(uint32_t *value) { + static int parsed = -1; + static uint32_t cached = 0; + if (parsed >= 0) { + if (parsed > 0 && value) *value = cached; + return parsed > 0; + } + + parsed = 0; + const char *env = getenv("DS4_METAL_DECODE_INDEXER_TOP_K"); + if (env && env[0]) { + char *end = NULL; + unsigned long v = strtoul(env, &end, 10); + while (end && isspace((unsigned char)*end)) end++; + if (end != env && end && *end == '\0' && + (v == 64ul || v == 128ul || v == 256ul || v == 512ul) && + v <= DS4_N_INDEXER_TOP_K) { + cached = (uint32_t)v; + parsed = 1; + } else { + fprintf(stderr, + "ds4: invalid DS4_METAL_DECODE_INDEXER_TOP_K=%s; " + "expected 64, 128, 256, or 512\n", + env); + } + } + if (parsed > 0 && value) *value = cached; + return parsed > 0; +} + static uint32_t metal_graph_decode_indexer_top_k(const ds4_gpu_graph *g) { + uint32_t value = 0; + if (metal_graph_decode_indexer_top_k_override(&value)) return value; + + const uint32_t speed_default = + DS4_N_INDEXER_TOP_K < 256u ? DS4_N_INDEXER_TOP_K : 256u; + return (g && g->quality) ? 
DS4_N_INDEXER_TOP_K : speed_default; +} + +static uint32_t metal_graph_decode_indexer_sparse_threshold(const ds4_gpu_graph *g) { (void)g; - return DS4_N_INDEXER_TOP_K; + static int parsed = -1; + static uint32_t cached = 0; + if (parsed < 0) { + parsed = 0; + const char *env = getenv("DS4_METAL_DECODE_INDEXER_SPARSE_THRESHOLD"); + if (env && env[0]) { + char *end = NULL; + unsigned long v = strtoul(env, &end, 10); + while (end && isspace((unsigned char)*end)) end++; + if (end != env && end && *end == '\0' && + (v == 64ul || v == 128ul || v == 256ul || v == 512ul || + v == 1024ul || v == 2048ul || v == 4096ul)) { + cached = (uint32_t)v; + parsed = 1; + } else { + fprintf(stderr, + "ds4: invalid DS4_METAL_DECODE_INDEXER_SPARSE_THRESHOLD=%s; " + "expected 64, 128, 256, 512, 1024, 2048, or 4096\n", + env); + } + } + } + if (parsed > 0) return cached; + + uint32_t value = 0; + if (metal_graph_decode_indexer_top_k_override(&value)) return value; + + /* Keep dense attention longer than the legacy 512-row window by default. + * Around the 2K frontier the sparse path's score/top-k setup dominates + * the smaller attention scan, while larger contexts benefit from sparse + * indexed attention. The speed default + * selects fewer rows only after decode has enough compressed rows for the + * sparse indexed path to pay for its score/top-k overhead. 
*/ + return 1024u; } /* ========================================================================= @@ -9388,7 +9460,9 @@ static bool metal_graph_encode_decode_layer( DS4_RMS_EPS) != 0; if (ok && emit) g->layer_n_index_comp[il]++; const uint32_t decode_top_k = metal_graph_decode_indexer_top_k(g); - if (ok && g->layer_n_comp[il] > decode_top_k) { + const uint32_t decode_sparse_threshold = + metal_graph_decode_indexer_sparse_threshold(g); + if (ok && g->layer_n_comp[il] > decode_sparse_threshold) { const uint64_t indexer_q_dim = (uint64_t)DS4_N_INDEXER_HEAD * DS4_N_INDEXER_HEAD_DIM; if (!layer->indexer_attn_q_b || layer->indexer_attn_q_b->type != DS4_TENSOR_F16 || @@ -13152,16 +13226,19 @@ static bool metal_graph_prefill_layer_major( const ds4_model *model, const ds4_weights *weights, const token_vec *prompt, - int n_tokens, + uint32_t start, + uint32_t n_tokens, float *logits, bool show_progress, ds4_imatrix_collector *imatrix) { - if (n_tokens <= 0 || n_tokens > prompt->len || (uint32_t)n_tokens > g->prefill_cap) return false; + if (n_tokens == 0 || n_tokens > g->prefill_cap) return false; + if (start > (uint32_t)prompt->len) return false; + if (n_tokens > (uint32_t)prompt->len - start) return false; - bool ok = metal_graph_upload_prompt_tokens(g->prefill_tokens, prompt, 0, (uint32_t)n_tokens); + bool ok = metal_graph_upload_prompt_tokens(g->prefill_tokens, prompt, start, n_tokens); if (!ok) return false; - if (!metal_graph_warmup_prefill_kernels(g, model, weights, (uint32_t)n_tokens)) return false; + if (!metal_graph_warmup_prefill_kernels(g, model, weights, n_tokens)) return false; const bool split_profile = getenv("DS4_METAL_GRAPH_PREFILL_SPLIT_PROFILE") != NULL; /* @@ -13182,16 +13259,16 @@ static bool metal_graph_prefill_layer_major( model, weights, prompt, - 0, - (uint32_t)n_tokens); + start, + n_tokens); if (ok) ok = ds4_gpu_begin_commands() != 0; for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { ok = metal_graph_encode_layer_batch(g, model, 
&weights->layer[il], il, - 0, - (uint32_t)n_tokens); + start, + n_tokens); if (show_progress) { fprintf(stderr, "ds4: gpu prefill layer %u/%u\r", il + 1, (uint32_t)DS4_N_LAYER); fflush(stderr); @@ -13209,13 +13286,13 @@ static bool metal_graph_prefill_layer_major( output_row = (uint32_t)v; } } - ds4_gpu_tensor *last_hc = NULL; ds4_gpu_tensor *saved_cur = g->cur_hc; - if (ok) { + ds4_gpu_tensor *last_hc = NULL; + if (ok && logits) { last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, output_row, hc_dim); ok = last_hc != NULL; } - if (ok) { + if (ok && logits) { g->cur_hc = last_hc; ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); g->cur_hc = saved_cur; @@ -13240,7 +13317,7 @@ static bool metal_graph_prefill_layer_major( if (profile) { const double t_read = now_sec(); fprintf(stderr, - "ds4: gpu graph prefill total tokens=%d encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", + "ds4: gpu graph prefill total tokens=%u encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", n_tokens, (t_encoded - t0) * 1000.0, (t_done - t_encoded) * 1000.0, @@ -13256,8 +13333,8 @@ static bool metal_graph_prefill_layer_major( model, weights, prompt, - 0, - (uint32_t)n_tokens); + start, + n_tokens); const double t_embed_encoded = profile ? now_sec() : 0.0; const double t_embed_done = profile ? 
now_sec() : 0.0; if (profile) { @@ -13285,8 +13362,8 @@ static bool metal_graph_prefill_layer_major( model, &weights->layer[il], il, - 0, - (uint32_t)n_tokens); + start, + n_tokens); const double t_attn_encoded = now_sec(); if (ok) ok = ds4_gpu_end_commands() != 0; const double t_attn_done = now_sec(); @@ -13297,8 +13374,8 @@ static bool metal_graph_prefill_layer_major( model, &weights->layer[il], il, - 0, - (uint32_t)n_tokens); + start, + n_tokens); if (ok) { ds4_gpu_tensor *tmp = g->batch_cur_hc; g->batch_cur_hc = g->batch_next_hc; @@ -13325,8 +13402,8 @@ static bool metal_graph_prefill_layer_major( model, &weights->layer[il], il, - 0, - (uint32_t)n_tokens); + start, + n_tokens); const double t_encoded = profile ? now_sec() : 0.0; if (ok) ok = ds4_gpu_end_commands() != 0; const double t_done = profile ? now_sec() : 0.0; @@ -13364,21 +13441,26 @@ static bool metal_graph_prefill_layer_major( output_row = (uint32_t)v; } } - ds4_gpu_tensor *last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, - output_row, - hc_dim); - if (!last_hc) return false; ds4_gpu_tensor *saved_cur = g->cur_hc; - g->cur_hc = last_hc; + ds4_gpu_tensor *last_hc = NULL; const double t_head0 = profile ? now_sec() : 0.0; - ok = ds4_gpu_begin_commands() != 0; - if (ok) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); + if (logits) { + last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, + output_row, + hc_dim); + ok = last_hc != NULL; + } + if (ok && logits) { + g->cur_hc = last_hc; + ok = ds4_gpu_begin_commands() != 0; + } + if (ok && logits) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); const double t_head_encoded = profile ? now_sec() : 0.0; - if (ok) ok = ds4_gpu_end_commands() != 0; + if (ok && logits) ok = ds4_gpu_end_commands() != 0; const double t_head_done = profile ? 
now_sec() : 0.0; g->cur_hc = saved_cur; - ds4_gpu_tensor_free(last_hc); + if (last_hc) ds4_gpu_tensor_free(last_hc); if (!ok) return false; const double t_before_read = profile ? now_sec() : 0.0; @@ -13396,7 +13478,7 @@ static bool metal_graph_prefill_layer_major( (t_head_done - t_head_encoded) * 1000.0); } fprintf(stderr, - "ds4: gpu layer-major prefill total tokens=%d encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", + "ds4: gpu layer-major prefill total tokens=%u encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", n_tokens, encode_s * 1000.0, execute_s * 1000.0, @@ -13416,32 +13498,15 @@ static bool metal_graph_prefill_raw_swa( bool show_progress) { if (n_tokens <= 0 || n_tokens > prompt->len) return false; if ((uint32_t)n_tokens > g->prefill_cap) return false; - return metal_graph_prefill_layer_major(g, model, weights, prompt, n_tokens, logits, show_progress, NULL); -} - -static bool metal_graph_prefill_batch_row_logits( - ds4_gpu_graph *g, - const ds4_model *model, - const ds4_weights *weights, - uint32_t batch_row, - float *logits) { - if (!logits) return true; - const uint64_t hc_dim = (uint64_t)DS4_N_HC * DS4_N_EMBD; - ds4_gpu_tensor *last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, - batch_row, - hc_dim); - if (!last_hc) return false; - ds4_gpu_tensor *saved_cur = g->cur_hc; - g->cur_hc = last_hc; - bool ok = ds4_gpu_begin_commands() != 0; - if (ok) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); - if (ok) ok = ds4_gpu_end_commands() != 0; - else (void)ds4_gpu_synchronize(); - g->cur_hc = saved_cur; - ds4_gpu_tensor_free(last_hc); - if (!ok) return false; - return ds4_gpu_tensor_read(g->logits, 0, logits, - (uint64_t)DS4_N_VOCAB * sizeof(float)) != 0; + return metal_graph_prefill_layer_major(g, + model, + weights, + prompt, + 0, + (uint32_t)n_tokens, + logits, + show_progress, + NULL); } /* Prefill a contiguous token range in fixed-size chunks. 
@@ -13472,21 +13537,8 @@ static bool metal_graph_prefill_chunked_range( if (start != 0 && chunk_cap > g->raw_cap) chunk_cap = g->raw_cap; if (chunk_cap == 0) return false; - uint32_t first_chunk = n_tokens < chunk_cap ? n_tokens : chunk_cap; - if (start != 0 && g->prefill_cap != 0) { - const uint32_t mod = start % g->prefill_cap; - if (mod != 0) { - const uint32_t to_boundary = g->prefill_cap - mod; - if (to_boundary < first_chunk) first_chunk = to_boundary; - } - } - if (!metal_graph_warmup_prefill_kernels(g, model, weights, first_chunk)) return false; - const bool profile = getenv("DS4_METAL_GRAPH_PREFILL_PROFILE") != NULL; const double t0 = profile ? now_sec() : 0.0; - double encode_s = 0.0; - double execute_s = 0.0; - uint32_t last_chunk_tokens = 0; const uint32_t end = start + n_tokens; if (progress) { @@ -13504,109 +13556,39 @@ static bool metal_graph_prefill_chunked_range( } } const uint32_t chunk = remaining < local_cap ? remaining : local_cap; - last_chunk_tokens = chunk; - - bool ok = metal_graph_upload_prompt_tokens(g->prefill_tokens, prompt, pos0, chunk); - if (ok) ok = metal_graph_upload_prompt_embeddings_hc(g->batch_cur_hc, - g->prefill_tokens, - model, - weights, - prompt, - pos0, - chunk); - if (!ok) return false; - - for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) { - const double t_layer0 = profile ? now_sec() : 0.0; - ok = ds4_gpu_begin_commands() != 0; - if (ok) ok = metal_graph_encode_layer_batch(g, - model, - &weights->layer[il], - il, - pos0, - chunk); - const double t_encoded = profile ? now_sec() : 0.0; - if (ok) ok = ds4_gpu_end_commands() != 0; - const double t_done = profile ? 
now_sec() : 0.0; - if (ok && imatrix) ok = imatrix_collect_layer_batch(imatrix, g, il, chunk); - if (profile) { - encode_s += t_encoded - t_layer0; - execute_s += t_done - t_encoded; - fprintf(stderr, - "ds4: gpu chunked prefill pos=%u tokens=%u layer %u encode=%.3f ms execute=%.3f ms\n", - pos0, - chunk, - il, - (t_encoded - t_layer0) * 1000.0, - (t_done - t_encoded) * 1000.0); - } - if (show_progress) { - fprintf(stderr, - "ds4: gpu prefill token %u/%u layer %u/%u\r", - pos0 + chunk, - (uint32_t)prompt->len, - il + 1, - (uint32_t)DS4_N_LAYER); - fflush(stderr); - } - } + const uint32_t chunk_end = pos0 + chunk; + float *chunk_logits = (progress || chunk_end == end) ? logits : NULL; + bool ok = metal_graph_prefill_layer_major(g, + model, + weights, + prompt, + pos0, + chunk, + chunk_logits, + show_progress, + imatrix); if (!ok) { if (ds4_gpu_synchronize() == 0) { fprintf(stderr, "ds4: Metal synchronize after chunked prefill failure also failed\n"); } return false; } - if (progress && !metal_graph_prefill_batch_row_logits(g, model, weights, - chunk - 1u, - logits)) - { - return false; - } if (progress) { - progress(progress_ud, "prefill_chunk", (int)(pos0 + chunk), prompt->len); + progress(progress_ud, "prefill_chunk", (int)chunk_end, prompt->len); } - pos0 += chunk; + pos0 = chunk_end; } if (show_progress) fputc('\n', stderr); - if (last_chunk_tokens == 0) return false; - - const uint64_t hc_dim = (uint64_t)DS4_N_HC * DS4_N_EMBD; - ds4_gpu_tensor *last_hc = metal_graph_tensor_row_view(g->batch_cur_hc, - last_chunk_tokens - 1u, - hc_dim); - if (!last_hc) return false; - ds4_gpu_tensor *saved_cur = g->cur_hc; - g->cur_hc = last_hc; - - const double t_head0 = profile ? now_sec() : 0.0; - bool ok = ds4_gpu_begin_commands() != 0; - if (ok) ok = metal_graph_encode_output_head(g, model, weights, weights->output->dim[1]); - const double t_head_encoded = profile ? now_sec() : 0.0; - if (ok) ok = ds4_gpu_end_commands() != 0; - const double t_head_done = profile ? 
now_sec() : 0.0; - g->cur_hc = saved_cur; - ds4_gpu_tensor_free(last_hc); - if (!ok) return false; - - const double t_before_read = profile ? now_sec() : 0.0; - if (logits) { - ok = ds4_gpu_tensor_read(g->logits, 0, logits, (uint64_t)DS4_N_VOCAB * sizeof(float)) != 0; - } if (profile) { const double t_read = now_sec(); - encode_s += t_head_encoded - t_head0; - execute_s += t_head_done - t_head_encoded; fprintf(stderr, - "ds4: gpu chunked prefill start=%u tokens=%u chunk=%u encode=%.3f ms execute=%.3f ms read=%.3f ms total=%.3f ms\n", + "ds4: gpu chunked prefill start=%u tokens=%u chunk=%u total=%.3f ms\n", start, n_tokens, chunk_cap, - encode_s * 1000.0, - execute_s * 1000.0, - (t_read - t_before_read) * 1000.0, (t_read - t0) * 1000.0); } - return ok; + return true; } /* Long prompts are prefetched in fixed-size chunks. Chunks bound transient @@ -13904,7 +13886,7 @@ static uint32_t metal_graph_raw_cap_for_context(int ctx_size, uint32_t prefill_c } /* Choose the prefill ubatch size. Whole-batch is fastest for normal prompts; - * long prompts default to 2048-token chunks. */ + * long prompts default to 4096-token chunks. 
*/ static uint32_t metal_graph_prefill_cap_for_prompt(int prompt_len) { return ds4_default_prefill_cap_for_prompt(prompt_len); } @@ -16810,7 +16792,8 @@ int ds4_engine_collect_imatrix(ds4_engine *e, &collector); } else { ok = metal_graph_prefill_layer_major(&g, model, weights, - &prompt, prompt.len, + &prompt, 0, + (uint32_t)prompt.len, NULL, false, &collector); } diff --git a/ds4_metal.m b/ds4_metal.m index ec863e0b..aa484366 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -96,6 +96,7 @@ static id g_dsv4_sort_i32_rows_asc_pipeline; static id g_dsv4_indexed_attention_heads8_pipeline; static id g_dsv4_indexed_attention_heads8_rb4_pipeline; +static id g_dsv4_indexed_attention_heads8_rb16_pipeline; static id g_dsv4_softplus_sqrt_pipeline; static id g_dsv4_router_finalize_one_pipeline; static id g_dsv4_router_weights_one_pipeline; @@ -1007,6 +1008,14 @@ static int ds4_gpu_env_bool(const char *name) { return 1; } +static int ds4_gpu_use_indexed_attention_rb4(void) { + static int enabled = -1; + if (enabled < 0) { + enabled = ds4_gpu_env_bool("DS4_METAL_INDEXED_ATTN_RB4") > 0; + } + return enabled; +} + typedef enum { DS4_METAL_MPP_GLOBAL_OFF, DS4_METAL_MPP_GLOBAL_AUTO, @@ -1103,6 +1112,12 @@ static uint32_t ds4_gpu_mpp_q8_0_tile_n(void) { return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_Q8_0_TILE_N", 64); } +static uint32_t ds4_gpu_mpp_q8_0_tile_n_for_tokens(uint64_t n_tok) { + const char *env = getenv("DS4_METAL_MPP_Q8_0_TILE_N"); + if (env && env[0]) return ds4_gpu_mpp_q8_0_tile_n(); + return n_tok >= 4096u ? 
32u : 64u; +} + static uint32_t ds4_gpu_mpp_attn_out_tile_n(void) { return ds4_gpu_mpp_tile_n_env("DS4_METAL_MPP_ATTN_OUT_TILE_N", 64); } @@ -1112,7 +1127,9 @@ static uint32_t ds4_gpu_mpp_moe_tile_n(void) { } static int ds4_gpu_mpp_moe_fast_layout(void) { - return ds4_gpu_env_bool("DS4_METAL_MPP_MOE_FAST_LAYOUT") > 0; + const int enabled = ds4_gpu_env_bool("DS4_METAL_MPP_MOE_FAST_LAYOUT"); + if (enabled >= 0) return enabled > 0; + return 1; } static int ds4_gpu_mpp_moe_pair_gate_up(void) { @@ -1183,6 +1200,14 @@ static int ds4_gpu_mpp_q8_0_late_safe_context(void) { return 0; } +static int ds4_gpu_mpp_q8_0_default_context(uint64_t n_tok) { + if (strstr(g_mpp_compare_context, "attn_q_b") != NULL && + n_tok <= 2048u) { + return 1; + } + return ds4_gpu_mpp_q8_0_late_safe_context(); +} + static int ds4_gpu_mpp_attn_out_late_safe_context(void) { return ds4_gpu_mpp_late_safe_context_range(32); } @@ -1280,10 +1305,10 @@ static int ds4_gpu_mpp_context_matches_filter( return 0; } -static int ds4_gpu_mpp_q8_0_context_matches_filter(void) { +static int ds4_gpu_mpp_q8_0_context_matches_filter(uint64_t n_tok) { const int default_match = ds4_gpu_mpp_fast_profile() ? 
1 - : ds4_gpu_mpp_q8_0_late_safe_context(); + : ds4_gpu_mpp_q8_0_default_context(n_tok); return ds4_gpu_mpp_context_matches_filter("DS4_METAL_MPP_Q8_0_FILTER", default_match, ds4_gpu_mpp_q8_0_late_safe_context()); @@ -1292,7 +1317,7 @@ static int ds4_gpu_mpp_q8_0_context_matches_filter(void) { static int ds4_gpu_can_use_mpp_q8_0_matmul(uint64_t n_tok) { if (n_tok <= 8) return 0; if (!ds4_gpu_use_mpp_q8_0_matmul()) return 0; - if (!ds4_gpu_mpp_q8_0_context_matches_filter()) return 0; + if (!ds4_gpu_mpp_q8_0_context_matches_filter(n_tok)) return 0; if ((n_tok % 32u) == 0 || ds4_gpu_mpp_q8_0_partial_tiles_enabled()) return 1; if (!g_mpp_q8_partial_skip_reported) { @@ -1340,12 +1365,12 @@ static int ds4_gpu_use_mpp_attn_out_low_matmul(void) { DS4_METAL_MOE_MPP_UP = 1 << 1, DS4_METAL_MOE_MPP_DOWN = 1 << 2, - DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER = 28, - DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER = 30, - DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER = 28, - DS4_METAL_MOE_MPP_FAST_GATE_LAYER = 13, - DS4_METAL_MOE_MPP_FAST_UP_LAYER = 13, - DS4_METAL_MOE_MPP_FAST_DOWN_LAYER = 2, + DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER = 0, + DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER = 0, + DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER = 0, + DS4_METAL_MOE_MPP_FAST_GATE_LAYER = 0, + DS4_METAL_MOE_MPP_FAST_UP_LAYER = 0, + DS4_METAL_MOE_MPP_FAST_DOWN_LAYER = 0, }; static int ds4_gpu_mpp_routed_moe_default_target(void) { @@ -1458,17 +1483,17 @@ static int ds4_gpu_mpp_routed_moe_mask_for_layer(uint32_t layer_index) { if ((int)layer_index >= gate_start) mask |= DS4_METAL_MOE_MPP_GATE; if ((mask & DS4_METAL_MOE_MPP_DOWN) && !ds4_gpu_mpp_moe_context_matches_filter("DS4_METAL_MPP_MOE_DOWN_FILTER", - DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER)) { + down_start)) { mask &= ~DS4_METAL_MOE_MPP_DOWN; } if ((mask & DS4_METAL_MOE_MPP_UP) && !ds4_gpu_mpp_moe_context_matches_filter("DS4_METAL_MPP_MOE_UP_FILTER", - DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER)) { + up_start)) { mask &= ~DS4_METAL_MOE_MPP_UP; } if ((mask & DS4_METAL_MOE_MPP_GATE) && 
!ds4_gpu_mpp_moe_context_matches_filter("DS4_METAL_MPP_MOE_GATE_FILTER", - DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER)) { + gate_start)) { mask &= ~DS4_METAL_MOE_MPP_GATE; } return mask & requested_mask; @@ -4785,6 +4810,8 @@ int ds4_gpu_init(void) { ds4_gpu_get_pipeline("kernel_dsv4_indexed_mixed_attention_heads8"); g_dsv4_indexed_attention_heads8_rb4_pipeline = ds4_gpu_get_pipeline("kernel_dsv4_indexed_mixed_attention_heads8_rb4"); + g_dsv4_indexed_attention_heads8_rb16_pipeline = + ds4_gpu_get_pipeline("kernel_dsv4_indexed_mixed_attention_heads8_rb16"); g_dsv4_softplus_sqrt_pipeline = ds4_gpu_get_pipeline("kernel_dsv4_softplus_sqrt_f32_4"); g_dsv4_router_finalize_one_pipeline = @@ -4798,6 +4825,7 @@ int ds4_gpu_init(void) { !g_dsv4_sort_i32_rows_asc_pipeline || !g_dsv4_indexed_attention_heads8_pipeline || !g_dsv4_indexed_attention_heads8_rb4_pipeline || + !g_dsv4_indexed_attention_heads8_rb16_pipeline || !g_dsv4_softplus_sqrt_pipeline || !g_dsv4_router_finalize_one_pipeline || !g_dsv4_router_weights_one_pipeline || @@ -5068,6 +5096,7 @@ void ds4_gpu_cleanup(void) { g_dsv4_sort_i32_rows_asc_pipeline = nil; g_dsv4_indexed_attention_heads8_pipeline = nil; g_dsv4_indexed_attention_heads8_rb4_pipeline = nil; + g_dsv4_indexed_attention_heads8_rb16_pipeline = nil; g_dsv4_softplus_sqrt_pipeline = nil; g_dsv4_router_finalize_one_pipeline = nil; g_dsv4_router_weights_one_pipeline = nil; @@ -6216,7 +6245,7 @@ int ds4_gpu_matmul_q8_0_mpp_tensor( id wbuf = ds4_gpu_wrap_model_range(model_map, model_size, weight_offset, weight_bytes, &inner_offset); if (!wbuf) return 0; - const uint32_t tile_n = ds4_gpu_mpp_q8_0_tile_n(); + const uint32_t tile_n = ds4_gpu_mpp_q8_0_tile_n_for_tokens(n_tok); const bool direct_rhs = (tile_n == 32u || tile_n == 64u) && ds4_gpu_mpp_q8_0_direct_rhs(); @@ -12302,10 +12331,14 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( ds4_gpu_hot_pipeline(g_dsv4_sort_i32_rows_asc_pipeline, "kernel_dsv4_sort_i32_rows_asc"); const bool decode_one_token = 
n_tokens == 1u; + const bool decode_rb4 = decode_one_token && ds4_gpu_use_indexed_attention_rb4(); id attn_pipeline = - decode_one_token ? + decode_rb4 ? ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_rb4_pipeline, "kernel_dsv4_indexed_mixed_attention_heads8_rb4") : + decode_one_token ? + ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_rb16_pipeline, + "kernel_dsv4_indexed_mixed_attention_heads8_rb16") : ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_pipeline, "kernel_dsv4_indexed_mixed_attention_heads8"); if (!sort_pipeline || !attn_pipeline) return 0; @@ -12386,7 +12419,8 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( atIndex:4]; [enc setBuffer:sinks_buf offset:(NSUInteger)sinks_inner atIndex:5]; [enc setBuffer:headsbuf offset:ds4_gpu_tensor_offset(heads) atIndex:6]; - [enc setThreadgroupMemoryLength:(decode_one_token ? 4u : 1u) * 128u * 4u * sizeof(float) + [enc setThreadgroupMemoryLength:(decode_one_token ? (decode_rb4 ? 4u : 16u) : 1u) * + 128u * 4u * sizeof(float) atIndex:0]; [enc dispatchThreadgroups:MTLSizeMake((NSUInteger)n_tokens, ((NSUInteger)n_head + 7u) / 8u, 1) threadsPerThreadgroup:MTLSizeMake(32, 8, 1)]; diff --git a/metal/dsv4_misc.metal b/metal/dsv4_misc.metal index b06d29d3..c9dc09c6 100644 --- a/metal/dsv4_misc.metal +++ b/metal/dsv4_misc.metal @@ -594,9 +594,7 @@ kernel void kernel_dsv4_indexed_mixed_attention_heads8( // Decode specialization of kernel_dsv4_indexed_mixed_attention_heads8. // Generation attends one token at a time, so the ratio-4 indexed path spends a // visible amount of time repeatedly staging the same K/V row for the eight -// heads in a group. This variant stages four selected rows at once and then -// consumes them sequentially, preserving the row order and online softmax math -// while cutting threadgroup barriers in the long top-k scan. +// heads in a group. This diagnostic variant stages four selected rows at once. 
kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb4( constant ds4_metal_args_dsv4_indexed_attention & args, device const char *q, @@ -720,6 +718,135 @@ kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb4( dst4[lane + 96] = o3 * inv_s; } +// Decode specialization of kernel_dsv4_indexed_mixed_attention_heads8. +// Generation attends one token at a time, so the ratio-4 indexed path spends a +// visible amount of time repeatedly staging the same K/V row for the eight +// heads in a group. This variant stages sixteen selected rows at once and then +// consumes them sequentially, preserving the row order and online softmax math +// while cutting threadgroup barriers in the long top-k scan. +kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb16( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared [[threadgroup(0)]], + uint2 tgpig [[threadgroup_position_in_grid]], + ushort tid [[thread_index_in_threadgroup]], + ushort lane [[thread_index_in_simdgroup]], + ushort sg [[simdgroup_index_in_threadgroup]]) { + const uint token = tgpig.x; + const uint head = tgpig.y * 8u + (uint)sg; + if (token >= args.n_tokens || head >= args.n_head) { + return; + } + + device const float4 *q4 = (device const float4 *)(q + + (uint64_t)token * args.q_token_stride + + (uint64_t)head * args.q_head_stride); + const half4 q0 = (half4)q4[lane + 0]; + const half4 q1 = (half4)q4[lane + 32]; + const half4 q2 = (half4)q4[lane + 64]; + const half4 q3 = (half4)q4[lane + 96]; + + float M = -FLT_MAX/2.0f; + float S = 0.0f; + float4 o0 = 0.0f; + float4 o1 = 0.0f; + float4 o2 = 0.0f; + float4 o3 = 0.0f; + + const uint qpos = args.pos0 + token; + const uint last_pos = args.pos0 + args.n_tokens - 1u; + const uint first_raw_pos = last_pos + 1u - args.n_raw; + const uint raw_last_pos = first_raw_pos + 
args.n_raw - 1u; + const uint window_first = (args.window != 0u && qpos + 1u > args.window) ? + qpos + 1u - args.window : 0u; + uint first = max(first_raw_pos, window_first); + uint last = min(qpos, raw_last_pos); + + if (first <= last) { + for (uint pos0 = first; pos0 <= last; pos0 += 16u) { + const uint n_rows = min(16u, last - pos0 + 1u); + for (uint off = (uint)tid; off < n_rows * 128u; off += 256u) { + const uint r = off >> 7; + const uint c = off & 127u; + const uint logical = pos0 + r - first_raw_pos; + const uint row = (args.raw_start + logical) % args.raw_cap; + device const float4 *src = (device const float4 *)(raw_kv + + (uint64_t)row * args.raw_row_stride); + kv_shared[off] = src[c]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint r = 0; r < n_rows; r++) { + dsv4_attend_shared_f32_row_as_f16_at(kv_shared, + r, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + uint visible = (qpos + 1u) / args.ratio; + visible = min(visible, args.n_comp); + device const int32_t *row_topk = (device const int32_t *)(topk + + (uint64_t)token * args.topk_token_stride); + bool stop = false; + for (uint i = 0; i < args.top_k && !stop; i += 16u) { + uint rows[16]; + uint n_rows = 0; + for (uint j = 0; j < 16u && i + j < args.top_k; j++) { + const int32_t idx = row_topk[i + j]; + if (idx < 0) { + continue; + } + if ((uint)idx >= visible) { + stop = true; + break; + } + rows[n_rows++] = (uint)idx; + } + if (n_rows == 0) { + continue; + } + for (uint off = (uint)tid; off < n_rows * 128u; off += 256u) { + const uint r = off >> 7; + const uint c = off & 127u; + device const float4 *src = (device const float4 *)(comp_kv + + (uint64_t)rows[r] * args.comp_row_stride); + kv_shared[off] = src[c]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint r = 0; r < n_rows; r++) { + dsv4_attend_shared_f32_row_as_f16_at(kv_shared, + r, + q0, q1, q2, q3, + args.scale, + lane, + M, 
S, + o0, o1, o2, o3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + dsv4_attend_sink(((device const float *)sinks)[head], M, S, o0, o1, o2, o3); + + const float inv_s = S == 0.0f ? 0.0f : 1.0f/S; + device float4 *dst4 = (device float4 *)(dst + + (uint64_t)token * args.dst_token_stride + + (uint64_t)head * args.dst_head_stride); + dst4[lane + 0] = o0 * inv_s; + dst4[lane + 32] = o1 * inv_s; + dst4[lane + 64] = o2 * inv_s; + dst4[lane + 96] = o3 * inv_s; +} + static inline float dsv4_indexer_dot128_shared_q( float4 c0, float4 c1, diff --git a/metal/moe.metal b/metal/moe.metal index a4360fe6..4619de28 100644 --- a/metal/moe.metal +++ b/metal/moe.metal @@ -2044,9 +2044,8 @@ typedef decltype(kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, ha typedef decltype(kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>) mul_mm_id_f16_rhs_n64; #ifdef DS4_METAL_HAS_TENSOR -// Diagnostic-only old MPP tensor layout from the first Metal 4 PR. It is kept -// behind DS4_METAL_MPP_MOE_FAST_LAYOUT so we can measure whether the old kernel -// shape can be recovered for routes that already pass full-model equivalence. +// Faster routed-MoE MPP tensor layout from the first Metal 4 PR. The host keeps +// it inside the active route windows that pass full-model checks. template kernel void kernel_mul_mm_id_mpp_fast_layout( constant ds4_metal_args_mul_mm_id & args, From 77eafa28d8c0c9250e508e67bae5a49c948aa6bd Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Tue, 12 May 2026 07:22:30 +0200 Subject: [PATCH 05/16] Add low-power Metal MPP Q8 profile Detect macOS Low Power Mode and widen the Q8_0 prefill MPP route only under that condition, while preserving the guarded default for normal-power runs and explicit Q8_0 filters. 
Low-power M5 Max baseline vs patched auto with 128 generated tokens: 0.5k: prefill 133.46 -> 196.89 t/s, gen 13.53 -> 15.08 t/s 1k: prefill 118.65 -> 188.91 t/s, gen 12.23 -> 14.93 t/s 2k: prefill 130.90 -> 220.33 t/s, gen 11.02 -> 14.65 t/s 4k: prefill 118.09 -> 212.81 t/s, gen 13.25 -> 14.00 t/s 8k: prefill 185.52 -> 206.49 t/s, gen 12.94 -> 13.84 t/s Tests: make all ds4_test; make test; DS4_METAL_MPP_LOW_POWER_DISABLE=1 ./ds4_test --metal-mpp-equivalence; git diff --check. --- README.md | 18 ++++++++++++++---- ds4_metal.m | 36 +++++++++++++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index c769abcd..0a1c5533 100644 --- a/README.md +++ b/README.md @@ -265,11 +265,16 @@ tokens use MPP for `attn_q_b` across layers, while larger batches use the late full-model-safe layer window 38..42 plus `attn_q_b` in layers 32..37. It uses 64-token tiles below 4096-token batches and 32-token tiles for larger prompt batches on M5, accepts partial token tails, and falls back to the legacy -kernel when the Metal 4 tensor path is unavailable. +kernel when the Metal 4 tensor path is unavailable. When macOS reports Low +Power Mode, auto widens Q8_0 prefill to all Q8_0 contexts because that profile +improves both prefill and generation speed in current M5 Max low-power sweeps. +Set `DS4_METAL_MPP_LOW_POWER_DISABLE=1` to keep the normal guarded Q8_0 +profile, or `DS4_METAL_MPP_LOW_POWER_ENABLE=1` to force the low-power profile +for comparison. Set `DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=0` to force the old partial-tail fallback while debugging. 
Set `DS4_METAL_MPP_Q8_0_FILTER=all` to reproduce the -unsafe all-layer Q8 route, `DS4_METAL_MPP_Q8_0_FILTER=late_safe` to request the -older conservative late window explicitly, or +wider all-context Q8 route, `DS4_METAL_MPP_Q8_0_FILTER=late_safe` to request +the older conservative late window explicitly, or `DS4_METAL_MPP_Q8_0_FILTER=` to force named full-graph Q8 modules such as `attn_q_a`, `attn_kv`, `attn_q_b`, `attn_out`, `shared_gate`, `shared_up`, or `shared_down`. Use @@ -321,7 +326,12 @@ Q8_0 tile width, and 64-token tiles for attention-output low projections. In a local M5 Max `ds4-bench` sweep with 64 generated tokens, auto sampled about `443/459/522/486/465` prompt tokens/sec and `38.6/38.2/37.6/34.0/33.6` generation tokens/sec at the -`0.5k/1k/2k/4k/8k` frontiers, with visible desktop-load variance. The F16 +`0.5k/1k/2k/4k/8k` frontiers, with visible desktop-load variance. In macOS Low +Power Mode on the same M5 Max, the guarded default sampled about +`133/119/131/118/186` prompt tokens/sec and +`13.5/12.2/11.0/13.3/12.9` generation tokens/sec at those frontiers with 128 +generated tokens; the low-power Q8 profile sampled about +`197/189/220/213/206` and `15.1/14.9/14.7/14.0/13.8` respectively. The F16 compressor route did not introduce measurable drift in the current prompt set. 
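The low-power override chain documented above can be sketched as a small resolution function. This is an illustrative sketch, not the in-tree implementation (the real helper is `ds4_gpu_mpp_low_power_profile` in `ds4_metal.m`): an explicit `DS4_METAL_MPP_LOW_POWER_DISABLE` wins, an explicit `DS4_METAL_MPP_LOW_POWER_ENABLE` value comes next, and only then does the detected Low Power Mode state apply. `detected_low_power` stands in for the `NSProcessInfo` `isLowPowerModeEnabled` check.

```c
#include <stdlib.h>
#include <string.h>

/* Tri-state env read: 1 truthy, 0 falsey, -1 unset.
 * Unknown non-empty values treat presence as enabled,
 * matching the documented boolean convention. */
static int env_truthy(const char *name) {
    const char *v = getenv(name);
    if (!v || !v[0]) return -1;
    if (!strcmp(v, "0") || !strcmp(v, "false") ||
        !strcmp(v, "no") || !strcmp(v, "off")) return 0;
    return 1;
}

/* Sketch of the override chain: DISABLE beats ENABLE,
 * and both beat the detected Low Power Mode state. */
static int low_power_profile_active(int detected_low_power) {
    if (env_truthy("DS4_METAL_MPP_LOW_POWER_DISABLE") > 0) return 0;
    const int forced = env_truthy("DS4_METAL_MPP_LOW_POWER_ENABLE");
    if (forced >= 0) return forced;
    return detected_low_power;
}
```

The same DISABLE-over-ENABLE-over-default precedence is what lets `DS4_METAL_MPP_LOW_POWER_DISABLE=1` pin the normal guarded profile during A/B comparisons.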
The `DS4_METAL_MPP_FAST=1` profile is the measured high-throughput diagnostic diff --git a/ds4_metal.m b/ds4_metal.m index aa484366..d7b0a115 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -1008,6 +1008,32 @@ static int ds4_gpu_env_bool(const char *name) { return 1; } +static int ds4_gpu_mpp_low_power_profile(void) { + const int disabled = ds4_gpu_env_bool("DS4_METAL_MPP_LOW_POWER_DISABLE"); + if (disabled > 0) return 0; + + const int enabled = ds4_gpu_env_bool("DS4_METAL_MPP_LOW_POWER_ENABLE"); + if (enabled >= 0) return enabled > 0; + + static int detected = -1; + static int reported; + if (detected < 0) { + detected = 0; + @autoreleasepool { + NSProcessInfo *info = [NSProcessInfo processInfo]; + if ([info respondsToSelector:@selector(isLowPowerModeEnabled)]) { + detected = [info isLowPowerModeEnabled] ? 1 : 0; + } + } + } + if (detected && !reported) { + fprintf(stderr, + "ds4: Metal low-power MPP profile active; widening Q8_0 prefill route\n"); + reported = 1; + } + return detected; +} + static int ds4_gpu_use_indexed_attention_rb4(void) { static int enabled = -1; if (enabled < 0) { @@ -1306,9 +1332,13 @@ static int ds4_gpu_mpp_context_matches_filter( } static int ds4_gpu_mpp_q8_0_context_matches_filter(uint64_t n_tok) { - const int default_match = ds4_gpu_mpp_fast_profile() - ? 1 - : ds4_gpu_mpp_q8_0_default_context(n_tok); + const char *filter = getenv("DS4_METAL_MPP_Q8_0_FILTER"); + const int filter_set = filter && filter[0]; + const int default_match = + (ds4_gpu_mpp_fast_profile() || + (!filter_set && ds4_gpu_mpp_low_power_profile())) + ? 
1 + : ds4_gpu_mpp_q8_0_default_context(n_tok); return ds4_gpu_mpp_context_matches_filter("DS4_METAL_MPP_Q8_0_FILTER", default_match, ds4_gpu_mpp_q8_0_late_safe_context()); From 0dd25e1474d6823e9d613f5219ff5fdef7b0b7c7 Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 10:05:58 +0200 Subject: [PATCH 06/16] Add M5 Max drift-patch macro plumbing and --dump-logits tooling Carries forward the pending "MPP -> Metal Tensor" naming refactor and adds: - --dump-logits FILE CLI flag and run_logits_dump() so prefill-time logits can be captured for A/B drift comparison. - bench/compare_logit_drift.py + bench/compare_bench.py + run helper. - Macro plumbing in ds4_metal.m's library compile step for five env-gated drift flags (DS4_METAL_HC_STABLE default-on, DS4_METAL_NORM_RSQRT_DISABLE default-on, DS4_METAL_KV_RAW_F32 default-off, DS4_METAL_ROPE_EXP2_LOG2 default-off, DS4_METAL_TENSOR_MATMUL_DISABLE default-off). - Logs the active flag set on first device init so test runs are self-documenting. Per-kernel changes that consume each macro land in follow-up commits so they can be reverted independently if a drift measurement regresses. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 144 +++++++------- ds4_cli.c | 103 +++++++++- ds4_metal.m | 95 ++++++---- ds4_server.c | 13 +- speed-bench/compare_bench.py | 258 ++++++++++++++++++++++++++ speed-bench/compare_logit_drift.py | 225 ++++++++++++++++++++++ speed-bench/run_metal_tensor_bench.sh | 63 +++++++ tests/ds4_test.c | 22 +-- 8 files changed, 789 insertions(+), 134 deletions(-) create mode 100755 speed-bench/compare_bench.py create mode 100644 speed-bench/compare_logit_drift.py create mode 100755 speed-bench/run_metal_tensor_bench.sh diff --git a/README.md b/README.md index 0a1c5533..33d282c9 100644 --- a/README.md +++ b/README.md @@ -224,31 +224,33 @@ looks like an M5 Neural Accelerator target. 
The implementation follows the same conservative shape used by llama.cpp's current Metal backend: the tensor API is disabled by default on pre-M5/pre-A19 devices, can be forced with `DS4_METAL_TENSOR_ENABLE=1`, and can always be -disabled with `DS4_METAL_TENSOR_DISABLE=1`. At startup ds4 compiles a tiny MPP -tensor matmul probe before it lets the main Metal shader source see -`DS4_METAL_HAS_TENSOR`, so unsupported SDK/device combinations fall back to the -legacy kernels. - -MPP policy is explicit and guarded. Use `--mpp auto` for the default -route policy, `--mpp on` to force MPP routes where the Metal 4 tensor path is -available, and `--mpp off` for the legacy Metal reference path. Auto currently -keeps attention-output MPP in the validated late-layer window, extends the -Q8_0 `attn_q_b` projection for small prompt batches, and runs routed-MoE MPP -from layer 0 for prefill throughput while preserving same-top1/same-greedy -agreement. Unguarded Q8_0 and attention-output all-layer MPP routes remain +disabled with `DS4_METAL_TENSOR_DISABLE=1`. At startup ds4 compiles a tiny +Metal Performance Primitives tensor matmul probe before it lets the main Metal +shader source see `DS4_METAL_HAS_TENSOR`, so unsupported SDK/device +combinations fall back to the legacy kernels. + +Metal Tensor policy is explicit and guarded. Use `-mt auto` or `--mt auto` for +the default route policy, `-mt on` to force Tensor routes where the Metal tensor +path is available, and `-mt off` for the legacy Metal reference path. The old +`--mpp` spelling remains accepted as a compatibility alias. Auto currently +keeps attention-output Tensor in the validated late-layer window, keeps Q8_0 +prefill in the lower-drift conservative layer window, and runs routed-MoE Tensor +only in its conservative layer window while preserving +same-top1/same-greedy agreement. Unguarded Q8_0, attention-output all-layer, +and all-layer routed-MoE Tensor routes remain opt-in diagnostics. 
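The same-top1 and top-k agreement criteria used throughout the equivalence discussion can be made concrete with a small sketch, under simple definitions that are assumptions of this example rather than the in-tree API: "same-top1" compares the argmax token of two logit vectors, and the top-k overlap counts how many of the reference run's top-k token ids also appear in the candidate's top-k.

```c
#include <stddef.h>

/* Argmax token id of a logit vector ("same-top1" compares two of these). */
static int argmax_token(const float *logits, size_t n) {
    size_t best = 0;
    for (size_t i = 1; i < n; i++)
        if (logits[i] > logits[best]) best = i;
    return (int)best;
}

/* Top-k ids by repeated max selection; O(n*k) is fine for the
 * small k (5, 20) used in the suite summaries. */
static void topk_ids(const float *logits, size_t n, int k, int *out) {
    for (int j = 0; j < k; j++) {
        int best = -1;
        for (size_t i = 0; i < n; i++) {
            int taken = 0;
            for (int t = 0; t < j; t++)
                if (out[t] == (int)i) { taken = 1; break; }
            if (!taken && (best < 0 || logits[i] > logits[best]))
                best = (int)i;
        }
        out[j] = best;
    }
}

/* Overlap count, e.g. 19 means a "top-20 overlap 19/20" row. */
static int topk_overlap(const float *ref, const float *cand, size_t n, int k) {
    int ri[20], ci[20]; /* sketch assumes k <= 20 */
    topk_ids(ref, n, k, ri);
    topk_ids(cand, n, k, ci);
    int overlap = 0;
    for (int a = 0; a < k; a++)
        for (int b = 0; b < k; b++)
            if (ri[a] == ci[b]) { overlap++; break; }
    return overlap;
}
```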
The environment controls `DS4_METAL_MPP_ENABLE` and `DS4_METAL_MPP_DISABLE` accept `1/true/yes/on` and -`0/false/no/off`; `DS4_METAL_MPP_ENABLE=0` disables MPP instead of enabling it -by mere presence. Passing `--quality` also disables MPP routes so strict/debug -runs stay on the legacy Metal kernels. Set `DS4_METAL_MPP_FAST=1` to opt into -the current same-top1/same-greedy fast profile: it widens Q8_0 and -attention-output MPP to all layers while keeping the routed-MoE all-layer -default. This profile is not the default because its top-k overlap is weaker -than auto in the current full-model suite. -The default safe-window policy uses the direct-RHS tensor layout for MPP routes; -set `DS4_METAL_MPP_DIRECT_RHS=0` to compare against the older staged-RHS +`0/false/no/off`; `DS4_METAL_MPP_ENABLE=0` disables Tensor routes instead of +enabling them by mere presence. Passing `--quality` also disables Tensor routes +so strict/debug runs stay on the legacy Metal kernels. Set +`DS4_METAL_MPP_FAST=1` to opt into the current same-top1/same-greedy fast +profile: it widens Q8_0 and attention-output Tensor to all layers while keeping +the routed-MoE all-layer diagnostic window. This profile is not the default because its +top-k overlap is weaker than auto in the current full-model suite. +The default safe-window policy uses the direct-RHS tensor layout for Tensor +routes; set `DS4_METAL_MPP_DIRECT_RHS=0` to compare against the older staged-RHS layout. Q8_0 and attention-output direct-RHS routes support both 32-token and -64-token MPP tiles. Auto defaults attention-output to 64-token tiles, while +64-token Tensor tiles. Auto defaults attention-output to 64-token tiles, while Q8_0 uses 64-token tiles below 4096-token batches and 32-token tiles for larger prompt batches on M5. Set `DS4_METAL_MPP_Q8_0_TILE_N=32` or `DS4_METAL_MPP_ATTN_OUT_TILE_N=32` to force the narrower layout. 
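The dynamic tile policy above reduces to a short rule. The sketch below encodes only what is documented here (64-token Q8_0 tiles below 4096-token prompt batches, 32-token tiles for larger ones, with `DS4_METAL_MPP_Q8_0_TILE_N` as an explicit override), plus the full-tile/partial-tail split; function names are illustrative.

```c
#include <stdint.h>

/* Q8_0 tile width: explicit override wins; otherwise 64-token tiles
 * below 4096-token prompt batches and 32-token tiles for larger ones. */
static uint32_t q8_tile_n(uint64_t n_prompt_tokens, uint32_t forced_tile_n) {
    if (forced_tile_n == 32u || forced_tile_n == 64u)
        return forced_tile_n; /* e.g. DS4_METAL_MPP_Q8_0_TILE_N=32 */
    return n_prompt_tokens < 4096u ? 64u : 32u;
}

/* Tokens covered by full tiles; the remainder is the partial token tail
 * that either uses the partial-tail path or falls back. */
static uint64_t full_tile_tokens(uint64_t n_tok, uint32_t tile_n) {
    return n_tok - n_tok % tile_n;
}
```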
The @@ -258,11 +260,11 @@ route-specific `DS4_METAL_MPP_Q8_0_DIRECT_RHS=1`, turning on every direct-RHS route at once when the global `DS4_METAL_MPP_DIRECT_RHS=0` override is set. -The Q8_0 prefill MPP route can be isolated with +The Q8_0 prefill Tensor route can be isolated with `DS4_METAL_MPP_Q8_0_ENABLE=1` or `DS4_METAL_MPP_Q8_0_DISABLE=1`. It only -affects prompt batches larger than eight tokens. By default, batches up to 2048 -tokens use MPP for `attn_q_b` across layers, while larger batches use the -late full-model-safe layer window 38..42 plus `attn_q_b` in layers 32..37. It +affects prompt batches larger than eight tokens. By default, Q8_0 uses the late +full-model-safe layer window 38..42 plus `attn_q_b` in layers 32..37 for all +prompt batch sizes. It uses 64-token tiles below 4096-token batches and 32-token tiles for larger prompt batches on M5, accepts partial token tails, and falls back to the legacy kernel when the Metal 4 tensor path is unavailable. When macOS reports Low @@ -273,19 +275,19 @@ profile, or `DS4_METAL_MPP_LOW_POWER_ENABLE=1` to force the low-power profile for comparison. Set `DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=0` to force the old partial-tail fallback while debugging. Set `DS4_METAL_MPP_Q8_0_FILTER=all` to reproduce the -wider all-context Q8 route, `DS4_METAL_MPP_Q8_0_FILTER=late_safe` to request -the older conservative late window explicitly, or +wider all-context Q8 route, `DS4_METAL_MPP_Q8_0_FILTER=attn_q_b` to reproduce +the broader small-prompt speed profile, or `DS4_METAL_MPP_Q8_0_FILTER=` to force named full-graph Q8 modules such as `attn_q_a`, `attn_kv`, `attn_q_b`, `attn_out`, `shared_gate`, `shared_up`, or `shared_down`. Use `@layer=A..B` to test one module family only in a layer window, for example `shared_up@layer=30..37`. Set `DS4_METAL_MPP_Q8_0_TILE_N=32` to -compare against the narrower MPP token tile. The isolated +compare against the narrower Tensor token tile. 
The isolated `./ds4_test --metal-kernels` regression reports small/medium/model-ish kernel deltas; the full-model `./ds4_test --metal-mpp-equivalence` diagnostic compares default auto against -`--mpp off`. Set `DS4_TEST_MPP_EQ_FORCE_ON=1` to compare forced MPP against -`--mpp off` while working on a route. `DS4_TEST_MPP_EQ_CASE=` +`-mt off`. Set `DS4_TEST_MPP_EQ_FORCE_ON=1` to compare forced Tensor against +`-mt off` while working on a route. `DS4_TEST_MPP_EQ_CASE=` limits the diagnostic to one prompt, and `DS4_TEST_MPP_EQ_MATRIX=1` prints separate auto, fast-profile, Q8-only, attention-output-only, MoE gate/up/down-only, and full-forced summary rows. The equivalence gate requires finite logits, the @@ -295,43 +297,35 @@ drift so route changes can be judged beyond pass/fail. Full-graph route localization is available with `DS4_METAL_MPP_COMPARE_ROUTE=q8|attn_out|moe_gate|moe_up|moe_down` and optional -`DS4_METAL_MPP_COMPARE_MAX=N`. The comparator snapshots the candidate MPP +`DS4_METAL_MPP_COMPARE_MAX=N`. The comparator snapshots the candidate Tensor output, runs the legacy Metal route on the same tensor input, and reports the first comparison that exceeds the kernel target, including module/layer context, shape, max absolute error, RMS, and the largest element deltas. Set `DS4_METAL_MPP_COMPARE_VERBOSE=1` to print passing comparisons as well. -Current MPP route status balances drift with prefill throughput: `auto` enables +Current Tensor route status balances drift with prefill throughput: `auto` enables Q8_0 prefill, F16 compressor, attention-output low projection, and routed-MoE -MPP. Attention-output low projection now uses layers 32..42 by default, while -Q8_0 uses `attn_q_b` across layers for <=2048-token prompt batches and keeps -the narrower `attn_q_b` 32..37 plus all-Q8 38..42 window for larger batches. 
-Routed-MoE MPP now covers gate/up/down from layer 0 by default to favor prefill -throughput on M5-class systems; it still preserves greedy agreement in the MPP -equivalence suite, but it carries larger logit drift than the previous -layer-20/22 conservative window. The current auto suite reports -same-top1/same-greedy agreement with minimum top-5 overlap `4/5`, minimum -top-20 overlap `17/20`, `worst_rms ~= 0.942`, and -`worst_top20_max_abs ~= 3.06`. The Q8_0 and attention-output low MPP +Tensor. Attention-output low projection now uses layers 32..42 by default, while +Q8_0 uses the narrower `attn_q_b` 32..37 plus all-Q8 38..42 window by default. +Routed-MoE Tensor now uses the lower-drift conservative default window: +gate/up from layer 20 and down from layer 22. This gives up some of the +all-layer prefill speedup to avoid the larger drift seen with the previous +broader Q8_0 and layer-0 routed-MoE Tensor windows. The current auto suite +reports same-top1/same-greedy agreement with minimum top-5 overlap `5/5`, +minimum top-20 overlap `19/20`, `worst_rms ~= 0.170`, and +`worst_top20_max_abs ~= 0.342`. The Q8_0 and attention-output low Tensor kernels stage activation tiles through half to match the legacy Metal matmul input path, which brings the isolated model-ish Q8_0 regression under the strict kernel target and removes the first attention-output comparator breach. Most Q8_0 projection families stay restricted to layers 38..42 because earlier -layers can amplify small local differences through normalization/attention -enough to fail long-context generation. The guarded `attn_q_b` extension is -kept because it is query-side only, passes prompt-logit and long-context gates -when limited to <=2048-token batches, and improves prefill throughput. The -current auto policy also uses Q8_0 partial tails, direct-RHS MPP inputs, dynamic -Q8_0 tile width, and 64-token tiles for attention-output low projections. 
In a -local M5 Max `ds4-bench` sweep with 64 generated tokens, auto sampled about -`443/459/522/486/465` prompt tokens/sec and -`38.6/38.2/37.6/34.0/33.6` generation tokens/sec at the -`0.5k/1k/2k/4k/8k` frontiers, with visible desktop-load variance. In macOS Low -Power Mode on the same M5 Max, the guarded default sampled about -`133/119/131/118/186` prompt tokens/sec and -`13.5/12.2/11.0/13.3/12.9` generation tokens/sec at those frontiers with 128 -generated tokens; the low-power Q8 profile sampled about -`197/189/220/213/206` and `15.1/14.9/14.7/14.0/13.8` respectively. The F16 +layers can amplify small local differences through normalization/attention. The +broader `attn_q_b` profile remains available through the filter knob when +prefill speed is more important than logit drift. The current auto policy also +uses Q8_0 partial tails, direct-RHS Tensor inputs, dynamic Q8_0 tile width, and +64-token tiles for attention-output low projections. In a quick local M5 Max +512-token sanity row, this lower-drift auto profile sampled `339.36` prompt +tokens/sec and `32.97` generation tokens/sec, versus `264.09` and `32.62` for +`--quality`; full sweeps still show visible desktop-load variance. The F16 compressor route did not introduce measurable drift in the current prompt set. The `DS4_METAL_MPP_FAST=1` profile is the measured high-throughput diagnostic @@ -339,34 +333,34 @@ profile under the relaxed same-top1/same-greedy gate. In the current prompt suite it keeps top-1 and greedy continuations stable, but reports weaker top-k overlap than auto (`worst_rms ~= 0.951`, `worst_top20_max_abs ~= 4.03`, minimum top-20 overlap `16/20`). It remains diagnostic-only because it widens -the Q8_0 and attention-output route windows that produce the largest full-suite -drift. +the Q8_0, attention-output, and routed-MoE route windows that produce the +largest full-suite drift. -The routed-MoE MPP projections are enabled from layer 0 by default for prefill -speed. 
For route isolation, use +The routed-MoE Tensor projections are enabled by default from layer 20 for +gate/up and layer 22 for down. For route isolation, use `DS4_METAL_MPP_MOE_GATE_ENABLE/DISABLE`, `DS4_METAL_MPP_MOE_UP_ENABLE/DISABLE`, and `DS4_METAL_MPP_MOE_DOWN_ENABLE/DISABLE`; `DS4_METAL_MPP_MOE_DISABLE=1` -disables all routed-MoE MPP projections. Set the common +disables all routed-MoE Tensor projections. Set the common `DS4_METAL_MPP_MOE_FILTER` or route-specific `DS4_METAL_MPP_MOE_GATE_FILTER`, `DS4_METAL_MPP_MOE_UP_FILTER`, and `DS4_METAL_MPP_MOE_DOWN_FILTER` to `all`, `late_safe`, `none`, or comma-separated full-graph context substrings to localize safe layer windows. Use `layer=N` for an exact layer match or `layer=A..B` for an inclusive layer -range when testing sparse MPP windows. The same `@layer=A..B` +range when testing sparse Tensor windows. The same `@layer=A..B` syntax can restrict a context substring to a layer window. Set `DS4_METAL_MPP_MOE_TILE_N=64` to test the experimental wider routed-MoE -MPP token tile for performance against the default `32`. The routed-MoE MPP +Tensor token tile for performance against the default `32`. The routed-MoE Tensor path uses the faster first-PR threadgroup tensor layout by default inside the active routed-MoE windows; set `DS4_METAL_MPP_MOE_FAST_LAYOUT=0` to compare against the newer staged layout. Set `DS4_METAL_MPP_MOE_START_LAYER=N`, or the route-specific `DS4_METAL_MPP_MOE_GATE_START_LAYER`, `DS4_METAL_MPP_MOE_UP_START_LAYER`, and -`DS4_METAL_MPP_MOE_DOWN_START_LAYER`, to test routed-MoE MPP start layers; the +`DS4_METAL_MPP_MOE_DOWN_START_LAYER`, to test routed-MoE Tensor start layers; the resolved start layer also defines the route's default `late_safe` filter. 
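The per-route start layers imply a simple per-layer route mask. A minimal sketch, assuming only the documented defaults (gate/up Tensor from layer 20, down from layer 22, legacy kernels below); the bit values and names are illustrative, not the in-tree constants:

```c
#include <stdint.h>

enum {
    MOE_MT_GATE = 1 << 0,
    MOE_MT_UP   = 1 << 1,
    MOE_MT_DOWN = 1 << 2,
};

/* Which routed-MoE projections use the Tensor route at a given layer,
 * given per-route start layers (defaults: gate=20, up=20, down=22,
 * overridable via DS4_METAL_MPP_MOE_*_START_LAYER). */
static int moe_mask_for_layer(uint32_t layer,
                              uint32_t gate_start,
                              uint32_t up_start,
                              uint32_t down_start) {
    int mask = 0;
    if (layer >= gate_start) mask |= MOE_MT_GATE;
    if (layer >= up_start)   mask |= MOE_MT_UP;
    if (layer >= down_start) mask |= MOE_MT_DOWN;
    return mask;
}
```

The resolved start layer doubling as the route's default `late_safe` filter means `late_safe` and this mask agree by construction.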
Set `DS4_METAL_MPP_MOE_PAIR_GATE_UP=1` only to profile the experimental fused -gate/up MPP dispatch; it passes the current equivalence gate but is not a +gate/up Tensor dispatch; it passes the current equivalence gate but is not a default path because it is slower than separate gate and up dispatches. For the common six-routed-expert prefill shape, the down-projection expert @@ -387,19 +381,19 @@ attention. Set `DS4_METAL_DECODE_INDEXER_SPARSE_THRESHOLD` to `64`, `128`, separately. `--quality` keeps the full `512` candidate path unless this environment override is set explicitly. -The attention-output low-projection MPP route applies to full 32-token multiples -in the default safe window, using a 64-token MPP tile by default and falling +The attention-output low-projection Tensor route applies to full 32-token multiples +in the default safe window, using a 64-token Tensor tile by default and falling back to the existing indexed simdgroup kernel for shorter or non-32-multiple -tails. Attention-output MPP is limited to the measured full-model-safe layer +tails. Attention-output Tensor is limited to the measured full-model-safe layer window 32..42 by default. Set `DS4_METAL_MPP_ATTN_OUT_ENABLE=1` or `DS4_METAL_MPP_ATTN_OUT_DISABLE=1` to isolate this route. Set `DS4_METAL_MPP_ATTN_OUT_FILTER=all`, `late_safe`, `none`, or a comma-separated list of full-graph context substrings such as `layer=42` to localize full-model-safe layer windows. Layer filters are exact, and `layer=A..B` matches an inclusive range. Set -`DS4_METAL_MPP_ATTN_OUT_TILE_N=32` to compare against the narrower MPP token +`DS4_METAL_MPP_ATTN_OUT_TILE_N=32` to compare against the narrower Tensor token tile. The all-layer -attention-output MPP route still fails long-prompt full-model equivalence +attention-output Tensor route still fails long-prompt full-model equivalence despite per-layer low-projection differences below the current kernel target. 
The ratio-2 F16 compressor route can similarly be controlled with `DS4_METAL_MPP_F16_ENABLE=1` or `DS4_METAL_MPP_F16_DISABLE=1`. @@ -407,9 +401,9 @@ The ratio-2 F16 compressor route can similarly be controlled with the standard simdgroup F16 matmul accumulation shape. It passes the current full-model equivalence gate, but the measured long-code prefill change was within noise (`~0.4%`), so it remains opt-in. `DS4_METAL_MPP_F16_WIDE=1` tests -wider 512/1024-column compressor MPP, including the paired MPP route when both +wider 512/1024-column compressor Tensor, including the paired Tensor route when both variables are set. The wide route is diagnostic only: the current long-code -prompt fails full-model equivalence with wide F16 MPP (`rms ~= 0.569`, +prompt fails full-model equivalence with wide F16 Tensor (`rms ~= 0.569`, `top20_max_abs ~= 1.48`), so it is not enabled by `auto`. ## CLI @@ -935,6 +929,8 @@ first answer: ```sh ./ds4 --dump-tokens -p "..." ./ds4 --dump-logprobs /tmp/out.json --logprobs-top-k 20 --temp 0 -p "..." +./ds4 --dump-logits /tmp/q2-off.json --metal -mt off --nothink --prompt-file prompt.txt +python3 speed-bench/compare_logit_drift.py /tmp/q2-off.json /tmp/q2-mt.json /tmp/q4-off.json --labels q2_mt q4_off ./ds4-server --trace /tmp/ds4-trace.txt ... ``` diff --git a/ds4_cli.c b/ds4_cli.c index 0bfd71e7..887e4b1e 100644 --- a/ds4_cli.c +++ b/ds4_cli.c @@ -32,6 +32,7 @@ typedef struct { float top_p; uint64_t seed; bool dump_tokens; + const char *dump_logits_path; const char *dump_logprobs_path; int dump_logprobs_top_k; const char *imatrix_dataset_path; @@ -102,9 +103,10 @@ static void usage(FILE *fp) { " -t, --threads N\n" " CPU helper threads for host-side or reference work.\n" " --quality\n" - " Prefer exact kernels where faster approximate paths exist; disables Metal 4 MPP routes; MTP uses strict verification.\n" - " --mpp MODE\n" - " Metal 4 MPP policy: auto, on, or off. Default: auto. 
Auto enables validated safe routes; 'on' is a route diagnostic and may change output.\n" + " Prefer exact kernels where faster approximate paths exist; disables Metal Tensor routes; MTP uses strict verification.\n" + " -mt MODE, --mt MODE\n" + " Metal Tensor policy: auto, on, or off. Default: auto. Auto enables validated safe routes; 'on' is a route diagnostic and may change output.\n" + " Legacy alias: --mpp MODE.\n" " --dir-steering-file FILE\n" " Load one f32 direction vector per layer for directional steering.\n" " --dir-steering-ffn F\n" @@ -155,6 +157,8 @@ static void usage(FILE *fp) { " Load the model and print a summary only.\n" " --dump-tokens\n" " Tokenize -p/--prompt-file exactly as written, then exit without inference.\n" + " --dump-logits FILE\n" + " Write full next-token logits as JSON after prompt prefill, then exit.\n" " --dump-logprobs FILE\n" " Write greedy continuation top-logprobs as JSON without printing text.\n" " --logprobs-top-k N\n" @@ -246,8 +250,8 @@ static ds4_mpp_mode parse_mpp_mode(const char *s) { if (!strcmp(s, "auto")) return DS4_MPP_AUTO; if (!strcmp(s, "on")) return DS4_MPP_ON; if (!strcmp(s, "off")) return DS4_MPP_OFF; - fprintf(stderr, "ds4: invalid MPP mode: %s\n", s); - fprintf(stderr, "ds4: valid MPP modes are: auto, on, off\n"); + fprintf(stderr, "ds4: invalid Metal Tensor mode: %s\n", s); + fprintf(stderr, "ds4: valid Metal Tensor modes are: auto, on, off\n"); exit(2); } @@ -640,6 +644,86 @@ static void json_write_token(FILE *fp, ds4_engine *engine, int token) { free(text); } +static int run_logits_dump(ds4_engine *engine, const cli_config *cfg, const ds4_tokens *prompt) { + ds4_session *session = NULL; + if (ds4_session_create(&session, engine, cfg->gen.ctx_size) != 0) { + fprintf(stderr, "ds4: --dump-logits requires a graph session backend\n"); + return 1; + } + + char err[160]; + cli_prefill_progress progress = { + .base_tokens = 0, + .input_tokens = prompt->len, + .use_color = ds4_log_is_tty(stderr), + }; + 
ds4_session_set_progress(session, cli_prefill_progress_cb, &progress); + if (ds4_session_sync(session, prompt, err, sizeof(err)) != 0) { + ds4_session_set_progress(session, NULL, NULL); + fprintf(stderr, "ds4: prompt processing failed: %s\n", err); + ds4_session_free(session); + return 1; + } + ds4_session_set_progress(session, NULL, NULL); + + const int vocab = ds4_engine_vocab_size(engine); + float *logits = malloc((size_t)vocab * sizeof(logits[0])); + if (!logits) { + ds4_session_free(session); + return 1; + } + if (ds4_session_copy_logits(session, logits, vocab) != vocab) { + fprintf(stderr, "ds4: failed to copy session logits\n"); + free(logits); + ds4_session_free(session); + return 1; + } + + FILE *fp = fopen(cfg->gen.dump_logits_path, "wb"); + if (!fp) { + fprintf(stderr, "ds4: failed to open --dump-logits file: %s\n", cfg->gen.dump_logits_path); + free(logits); + ds4_session_free(session); + return 1; + } + + fprintf(fp, "{\n \"source\":\"ds4\",\n \"model\":"); + json_write_string(fp, cfg->engine.model_path, strlen(cfg->engine.model_path)); + fprintf(fp, + ",\n \"backend\":\"%s\",\n \"mt\":\"%s\",\n \"quant_bits\":%d,\n" + " \"prompt_tokens\":%d,\n \"ctx\":%d,\n \"vocab\":%d,\n", + ds4_backend_name(cfg->engine.backend), + ds4_mpp_mode_name(cfg->engine.mpp_mode), + ds4_engine_routed_quant_bits(engine), + prompt->len, + cfg->gen.ctx_size, + vocab); + const int argmax = ds4_session_argmax(session); + fputs(" \"argmax_token\":", fp); + json_write_token(fp, engine, argmax); + fprintf(fp, ",\n \"argmax_logit\":%.9g,\n \"logits\":[", logits[argmax]); + for (int i = 0; i < vocab; i++) { + if (i) fputc(',', fp); + if ((i % 8) == 0) fputs("\n ", fp); + if (isfinite(logits[i])) { + fprintf(fp, "%.9g", logits[i]); + } else { + fputs("null", fp); + } + } + fputs("\n ]\n}\n", fp); + if (fclose(fp) != 0) { + fprintf(stderr, "ds4: failed to close --dump-logits file: %s\n", cfg->gen.dump_logits_path); + free(logits); + ds4_session_free(session); + return 1; + } + + 
free(logits); + ds4_session_free(session); + return 0; +} + static int run_logprob_dump(ds4_engine *engine, const cli_config *cfg, const ds4_tokens *prompt) { ds4_session *session = NULL; if (ds4_session_create(&session, engine, cfg->gen.ctx_size) != 0) { @@ -741,6 +825,11 @@ static int run_generation(ds4_engine *engine, const cli_config *cfg) { ds4_tokens_free(&prompt); return rc; } + if (cfg->gen.dump_logits_path) { + rc = run_logits_dump(engine, cfg, &prompt); + ds4_tokens_free(&prompt); + return rc; + } if (cfg->gen.dump_logprobs_path) { rc = run_logprob_dump(engine, cfg, &prompt); ds4_tokens_free(&prompt); @@ -1255,7 +1344,7 @@ static cli_config parse_options(int argc, char **argv) { c.gen.seed = parse_u64(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--quality")) { c.engine.quality = true; - } else if (!strcmp(arg, "--mpp")) { + } else if (!strcmp(arg, "-mt") || !strcmp(arg, "--mt") || !strcmp(arg, "--mpp")) { c.engine.mpp_mode = parse_mpp_mode(need_arg(&i, argc, argv, arg)); } else if (!strcmp(arg, "--dir-steering-file")) { c.engine.directional_steering_file = need_arg(&i, argc, argv, arg); @@ -1277,6 +1366,8 @@ static cli_config parse_options(int argc, char **argv) { c.engine.backend = DS4_BACKEND_CUDA; } else if (!strcmp(arg, "--dump-tokens")) { c.gen.dump_tokens = true; + } else if (!strcmp(arg, "--dump-logits")) { + c.gen.dump_logits_path = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--dump-logprobs")) { c.gen.dump_logprobs_path = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--logprobs-top-k")) { diff --git a/ds4_metal.m b/ds4_metal.m index d7b0a115..092815c4 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -440,7 +440,7 @@ static void ds4_gpu_mpp_compare_drain(const char *finish_label) { const int exceeds_target = (nonfinite != 0 || max_abs > 1.0e-3f || rms > 1.0e-4f); if (ds4_gpu_mpp_compare_verbose() || exceeds_target) { fprintf(stderr, - "ds4: Metal MPP compare route=%s module=%s shape=%llux%llux%llu max_abs=%g 
rms=%g nonfinite=%d max_index=%llu\n", + "ds4: Metal Tensor compare route=%s module=%s shape=%llux%llux%llu max_abs=%g rms=%g nonfinite=%d max_index=%llu\n", item->route, item->label, (unsigned long long)item->dim0, @@ -450,7 +450,7 @@ static void ds4_gpu_mpp_compare_drain(const char *finish_label) { rms, nonfinite, (unsigned long long)max_index); - fprintf(stderr, "ds4: Metal MPP compare route=%s module=%s largest deltas:", + fprintf(stderr, "ds4: Metal Tensor compare route=%s module=%s largest deltas:", item->route, item->label); for (int j = 0; j < DS4_METAL_MPP_COMPARE_DELTAS && delta_idx[j] != UINT64_MAX; j++) { fprintf(stderr, " idx=%llu ref=%g cand=%g abs=%g", @@ -465,7 +465,7 @@ static void ds4_gpu_mpp_compare_drain(const char *finish_label) { g_mpp_compare_done_count++; if (exceeds_target) { fprintf(stderr, - "ds4: Metal MPP compare route=%s module=%s exceeded target max_abs<=0.001 rms<=0.0001; stopping comparisons\n", + "ds4: Metal Tensor compare route=%s module=%s exceeded target max_abs<=0.001 rms<=0.0001; stopping comparisons\n", item->route, item->label); g_mpp_compare_stopped = 1; @@ -474,7 +474,7 @@ static void ds4_gpu_mpp_compare_drain(const char *finish_label) { if (!g_mpp_compare_stopped && !g_mpp_compare_limit_reported && g_mpp_compare_done_count >= max_reports) { fprintf(stderr, - "ds4: Metal MPP compare reached DS4_METAL_MPP_COMPARE_MAX=%d without a target breach\n", + "ds4: Metal Tensor compare reached DS4_METAL_MPP_COMPARE_MAX=%d without a target breach\n", max_reports); g_mpp_compare_limit_reported = 1; } @@ -1001,7 +1001,7 @@ static int ds4_gpu_env_bool(const char *name) { if (!g_mpp_invalid_env_reported) { fprintf(stderr, - "ds4: invalid Metal MPP boolean environment value %s=%.*s; treating presence as enabled\n", + "ds4: invalid Metal Tensor boolean environment value %s=%.*s; treating presence as enabled\n", name, (int)n, v); g_mpp_invalid_env_reported = 1; } @@ -1028,7 +1028,7 @@ static int ds4_gpu_mpp_low_power_profile(void) { } if 
(detected && !reported) { fprintf(stderr, - "ds4: Metal low-power MPP profile active; widening Q8_0 prefill route\n"); + "ds4: Metal low-power Tensor profile active; widening Q8_0 prefill route\n"); reported = 1; } return detected; @@ -1091,7 +1091,7 @@ static int ds4_gpu_mpp_fast_profile(void) { } static const char *ds4_gpu_mpp_enabled_reason(void) { - if (g_mpp_mode == DS4_MPP_ON) return " by --mpp on"; + if (g_mpp_mode == DS4_MPP_ON) return " by -mt on"; if (ds4_gpu_mpp_fast_profile()) return " by DS4_METAL_MPP_FAST"; if (ds4_gpu_env_bool("DS4_METAL_MPP_ENABLE") > 0) return " by DS4_METAL_MPP_ENABLE"; return " by default"; @@ -1106,7 +1106,7 @@ static int ds4_gpu_mpp_q8_0_policy_enabled(void) { static int ds4_gpu_use_mpp_q8_0_matmul(void) { const int enabled = ds4_gpu_mpp_q8_0_policy_enabled(); if (enabled && !g_mpp_q8_reported) { - fprintf(stderr, "ds4: Metal MPP Q8_0 prefill matmul enabled%s\n", + fprintf(stderr, "ds4: Metal Tensor Q8_0 prefill matmul enabled%s\n", ds4_gpu_mpp_enabled_reason()); g_mpp_q8_reported = 1; } @@ -1226,14 +1226,6 @@ static int ds4_gpu_mpp_q8_0_late_safe_context(void) { return 0; } -static int ds4_gpu_mpp_q8_0_default_context(uint64_t n_tok) { - if (strstr(g_mpp_compare_context, "attn_q_b") != NULL && - n_tok <= 2048u) { - return 1; - } - return ds4_gpu_mpp_q8_0_late_safe_context(); -} - static int ds4_gpu_mpp_attn_out_late_safe_context(void) { return ds4_gpu_mpp_late_safe_context_range(32); } @@ -1332,13 +1324,14 @@ static int ds4_gpu_mpp_context_matches_filter( } static int ds4_gpu_mpp_q8_0_context_matches_filter(uint64_t n_tok) { + (void)n_tok; const char *filter = getenv("DS4_METAL_MPP_Q8_0_FILTER"); const int filter_set = filter && filter[0]; const int default_match = (ds4_gpu_mpp_fast_profile() || (!filter_set && ds4_gpu_mpp_low_power_profile())) ? 
1 - : ds4_gpu_mpp_q8_0_default_context(n_tok); + : ds4_gpu_mpp_q8_0_late_safe_context(); return ds4_gpu_mpp_context_matches_filter("DS4_METAL_MPP_Q8_0_FILTER", default_match, ds4_gpu_mpp_q8_0_late_safe_context()); @@ -1352,7 +1345,7 @@ static int ds4_gpu_can_use_mpp_q8_0_matmul(uint64_t n_tok) { if (!g_mpp_q8_partial_skip_reported) { fprintf(stderr, - "ds4: Metal MPP Q8_0 prefill matmul skipping partial token tiles; " + "ds4: Metal Tensor Q8_0 prefill matmul skipping partial token tiles; " "set DS4_METAL_MPP_Q8_0_PARTIAL_ENABLE=1 to test them\n"); g_mpp_q8_partial_skip_reported = 1; } @@ -1364,7 +1357,7 @@ static int ds4_gpu_use_mpp_f16_compressor_matmul(void) { "DS4_METAL_MPP_F16_ENABLE", "DS4_METAL_MPP_F16_DISABLE"); if (enabled && !g_mpp_f16_reported) { - fprintf(stderr, "ds4: Metal MPP F16 compressor prefill matmul enabled%s\n", + fprintf(stderr, "ds4: Metal Tensor F16 compressor prefill matmul enabled%s\n", ds4_gpu_mpp_enabled_reason()); g_mpp_f16_reported = 1; } @@ -1383,7 +1376,7 @@ static int ds4_gpu_use_mpp_attn_out_low_matmul(void) { default_match, ds4_gpu_mpp_attn_out_late_safe_context()); if (enabled && !g_mpp_attn_out_reported) { - fprintf(stderr, "ds4: Metal MPP attention-output low projection enabled%s\n", + fprintf(stderr, "ds4: Metal Tensor attention-output low projection enabled%s\n", ds4_gpu_mpp_enabled_reason()); g_mpp_attn_out_reported = 1; } @@ -1395,9 +1388,9 @@ static int ds4_gpu_use_mpp_attn_out_low_matmul(void) { DS4_METAL_MOE_MPP_UP = 1 << 1, DS4_METAL_MOE_MPP_DOWN = 1 << 2, - DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER = 0, - DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER = 0, - DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER = 0, + DS4_METAL_MOE_MPP_DEFAULT_GATE_LAYER = 20, + DS4_METAL_MOE_MPP_DEFAULT_UP_LAYER = 20, + DS4_METAL_MOE_MPP_DEFAULT_DOWN_LAYER = 22, DS4_METAL_MOE_MPP_FAST_GATE_LAYER = 0, DS4_METAL_MOE_MPP_FAST_UP_LAYER = 0, DS4_METAL_MOE_MPP_FAST_DOWN_LAYER = 0, @@ -1449,7 +1442,7 @@ static int ds4_gpu_mpp_routed_moe_stage_mask(void) { mask |= 
DS4_METAL_MOE_MPP_DOWN; } if (mask && !g_mpp_moe_reported) { - fprintf(stderr, "ds4: Metal MPP routed MoE projections enabled%s\n", + fprintf(stderr, "ds4: Metal Tensor routed MoE projections enabled%s\n", ds4_gpu_mpp_enabled_reason()); g_mpp_moe_reported = 1; } @@ -1501,7 +1494,7 @@ static int ds4_gpu_mpp_routed_moe_mask_for_layer(uint32_t layer_index) { gate_fallback); if (!g_mpp_moe_ranges_reported) { fprintf(stderr, - "ds4: Metal MPP routed MoE default ranges down=%d..end up=%d..end gate=%d..end\n", + "ds4: Metal Tensor routed MoE default ranges down=%d..end up=%d..end gate=%d..end\n", down_start, up_start, gate_start); @@ -1535,7 +1528,7 @@ static int ds4_gpu_mpp_routed_moe_mask_for_layer(uint32_t layer_index) { static void ds4_gpu_warn_mpp_fallback(void) { static int warned; if (!warned) { - fprintf(stderr, "ds4: Metal MPP prefill matmul unavailable; falling back to legacy kernel\n"); + fprintf(stderr, "ds4: Metal Tensor prefill matmul unavailable; falling back to legacy kernel\n"); warned = 1; } } @@ -2107,12 +2100,12 @@ void ds4_gpu_print_memory_report(const char *label) { "DS4_METAL_MPP_ATTN_OUT_DISABLE"); const int mpp_moe = ds4_gpu_mpp_routed_moe_stage_mask(); fprintf(stderr, - "ds4: MPP policy %s%s%s\n", + "ds4: Metal Tensor policy %s%s%s\n", ds4_mpp_mode_name(g_mpp_mode), g_quality_mode ? " (disabled by --quality)" : "", !g_metal4_tensor_api_enabled ? " (tensor API unavailable)" : ""); fprintf(stderr, - "ds4: MPP routes q8_0=%s f16_compressor=%s attn_out=%s moe_gate=%s moe_up=%s moe_down=%s\n", + "ds4: Metal Tensor routes q8_0=%s f16_compressor=%s attn_out=%s moe_gate=%s moe_up=%s moe_down=%s\n", mpp_q8 ? "on" : "off", mpp_f16 ? "on" : "off", mpp_attn_out ? 
"on" : "off", @@ -3781,10 +3774,38 @@ int ds4_gpu_init(void) { return 0; } MTLCompileOptions *options = [MTLCompileOptions new]; + NSMutableDictionary *macros = [NSMutableDictionary new]; if (g_metal4_tensor_api_enabled) { - options.preprocessorMacros = @{ @"DS4_METAL_HAS_TENSOR": @"1" }; - fprintf(stderr, "ds4: Metal 4 tensor API enabled for MPP tensor kernels\n"); + macros[@"DS4_METAL_HAS_TENSOR"] = @"1"; + fprintf(stderr, "ds4: Metal 4 tensor API enabled for Tensor kernels\n"); + } + + const int drift_hc_stable = ds4_gpu_env_bool("DS4_METAL_HC_STABLE") != 0; // default ON + const int drift_norm_unify = ds4_gpu_env_bool("DS4_METAL_NORM_RSQRT_DISABLE") != 0; // default ON + const int drift_kv_raw_f32 = ds4_gpu_env_bool("DS4_METAL_KV_RAW_F32") > 0; // default OFF + const int drift_rope_exp2_log2 = ds4_gpu_env_bool("DS4_METAL_ROPE_EXP2_LOG2") > 0; // default OFF + const int drift_tensor_matmul_off = g_metal4_tensor_api_enabled && + ds4_gpu_env_bool("DS4_METAL_TENSOR_MATMUL_DISABLE") > 0; + + if (drift_hc_stable) macros[@"DS4_METAL_HC_STABLE"] = @"1"; + if (drift_norm_unify) macros[@"DS4_METAL_NORM_RSQRT_DISABLE"] = @"1"; + if (drift_kv_raw_f32) macros[@"DS4_METAL_KV_RAW_F32"] = @"1"; + if (drift_rope_exp2_log2) macros[@"DS4_METAL_ROPE_EXP2_LOG2"] = @"1"; + if (drift_tensor_matmul_off) { + // Recompile without DS4_METAL_HAS_TENSOR so the cooperative-tensor + // matmul branches are excluded from this build, isolating the + // simdgroup_float8x8 path for an A/B vs the Tensor matmul on M5. + [macros removeObjectForKey:@"DS4_METAL_HAS_TENSOR"]; + fprintf(stderr, "ds4: Metal 4 cooperative-tensor matmul disabled by DS4_METAL_TENSOR_MATMUL_DISABLE\n"); } + fprintf(stderr, + "ds4: drift-patch flags hc_stable=%s norm_unify=%s kv_raw_f32=%s rope_exp2_log2=%s tensor_matmul=%s\n", + drift_hc_stable ? "on" : "off", + drift_norm_unify ? "on" : "off", + drift_kv_raw_f32 ? "on" : "off", + drift_rope_exp2_log2 ? 
"on" : "off", + (g_metal4_tensor_api_enabled && !drift_tensor_matmul_off) ? "on" : "off"); + options.preprocessorMacros = macros; id library = [g_device newLibraryWithSource:source options:options error:&error]; if (!library) { fprintf(stderr, "ds4: Metal shader compilation failed: %s\n", @@ -6259,7 +6280,7 @@ int ds4_gpu_matmul_q8_0_mpp_tensor( if (!xbuf || !outbuf || ds4_gpu_tensor_bytes(x) < x_bytes || ds4_gpu_tensor_bytes(out) < out_bytes) { - fprintf(stderr, "ds4: Metal MPP Q8_0 matmul received undersized activation buffers\n"); + fprintf(stderr, "ds4: Metal Tensor Q8_0 matmul received undersized activation buffers\n"); return 0; } @@ -6267,7 +6288,7 @@ int ds4_gpu_matmul_q8_0_mpp_tensor( const uint64_t row_bytes = blocks * 34; const uint64_t weight_bytes = out_dim * row_bytes; if (weight_offset > model_size || weight_bytes > model_size - weight_offset) { - fprintf(stderr, "ds4: Metal MPP Q8_0 matmul range is outside the mapped model\n"); + fprintf(stderr, "ds4: Metal Tensor Q8_0 matmul range is outside the mapped model\n"); return 0; } @@ -6311,7 +6332,7 @@ int ds4_gpu_matmul_q8_0_mpp_tensor( threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; ds4_gpu_end_compute_encoder(cb, enc); - if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal MPP Q8_0 matmul")) return 0; + if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal Tensor Q8_0 matmul")) return 0; } return 1; @@ -6538,7 +6559,7 @@ int ds4_gpu_matmul_f16_tensor( threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; ds4_gpu_end_compute_encoder(cb, enc); - if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal MPP F16 compressor matmul")) return 0; + if (!ds4_gpu_finish_command_buffer(cb, owned, "Metal Tensor F16 compressor matmul")) return 0; return 1; } } @@ -6603,7 +6624,7 @@ int ds4_gpu_matmul_f16_pair_tensor( ds4_gpu_tensor_bytes(x) < x_bytes || ds4_gpu_tensor_bytes(out_a) < out_bytes || ds4_gpu_tensor_bytes(out_b) < out_bytes) { - fprintf(stderr, "ds4: Metal F16 paired MPP matmul received undersized activation 
buffers\n"); + fprintf(stderr, "ds4: Metal F16 paired Tensor matmul received undersized activation buffers\n"); return 0; } @@ -6611,7 +6632,7 @@ int ds4_gpu_matmul_f16_pair_tensor( const uint64_t weight_bytes = row_bytes * out_dim; if (weight_a_offset > model_size || weight_bytes > model_size - weight_a_offset || weight_b_offset > model_size || weight_bytes > model_size - weight_b_offset) { - fprintf(stderr, "ds4: Metal F16 paired MPP matmul range is outside the mapped model\n"); + fprintf(stderr, "ds4: Metal F16 paired Tensor matmul range is outside the mapped model\n"); return 0; } @@ -6635,7 +6656,7 @@ int ds4_gpu_matmul_f16_pair_tensor( if (!pipeline) return 0; if (!g_mpp_f16_pair_reported) { fprintf(stderr, "ds4: Metal paired F16 compressor matmul enabled%s\n", - use_wide_mpp_pair ? " with MPP wide route" : ""); + use_wide_mpp_pair ? " with Tensor wide route" : ""); g_mpp_f16_pair_reported = 1; } diff --git a/ds4_server.c b/ds4_server.c index 8fcdd627..33c434fd 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -7844,8 +7844,8 @@ static ds4_mpp_mode parse_mpp_mode_arg(const char *s) { if (!strcmp(s, "auto")) return DS4_MPP_AUTO; if (!strcmp(s, "on")) return DS4_MPP_ON; if (!strcmp(s, "off")) return DS4_MPP_OFF; - server_log(DS4_LOG_DEFAULT, "ds4-server: invalid MPP mode: %s", s); - server_log(DS4_LOG_DEFAULT, "ds4-server: valid MPP modes are: auto, on, off"); + server_log(DS4_LOG_DEFAULT, "ds4-server: invalid Metal Tensor mode: %s", s); + server_log(DS4_LOG_DEFAULT, "ds4-server: valid Metal Tensor modes are: auto, on, off"); exit(2); } @@ -7906,9 +7906,10 @@ static void usage(FILE *fp) { " -t, --threads N\n" " CPU helper threads for lightweight host-side work.\n" " --quality\n" - " Prefer exact kernels where faster approximate paths exist; disables Metal 4 MPP routes; MTP uses strict verification.\n" - " --mpp MODE\n" - " Metal 4 MPP policy: auto, on, or off. Default: auto. 
Auto enables validated safe routes; 'on' is a route diagnostic and may change output.\n" + " Prefer exact kernels where faster approximate paths exist; disables Metal Tensor routes; MTP uses strict verification.\n" + " -mt MODE, --mt MODE\n" + " Metal Tensor policy: auto, on, or off. Default: auto. Auto enables validated safe routes; 'on' is a route diagnostic and may change output.\n" + " Legacy alias: --mpp MODE.\n" " --dir-steering-file FILE\n" " Load one f32 direction vector per layer for directional steering.\n" " --dir-steering-ffn F\n" @@ -8031,7 +8032,7 @@ static server_config parse_options(int argc, char **argv) { c.default_tokens = parse_int_arg(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "-t") || !strcmp(arg, "--threads")) { c.engine.n_threads = parse_int_arg(need_arg(&i, argc, argv, arg), arg); - } else if (!strcmp(arg, "--mpp")) { + } else if (!strcmp(arg, "-mt") || !strcmp(arg, "--mt") || !strcmp(arg, "--mpp")) { c.engine.mpp_mode = parse_mpp_mode_arg(need_arg(&i, argc, argv, arg)); } else if (!strcmp(arg, "--host")) { c.host = need_arg(&i, argc, argv, arg); diff --git a/speed-bench/compare_bench.py b/speed-bench/compare_bench.py new file mode 100755 index 00000000..034ab193 --- /dev/null +++ b/speed-bench/compare_bench.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +"""Plot two or more ds4-bench CSV runs as a speed comparison chart.""" + +from __future__ import annotations + +import argparse +import csv +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + + +REQUIRED_COLUMNS = { + "ctx_tokens", + "prefill_tps", + "gen_tps", +} + + +def read_run(path: Path) -> dict[int, dict[str, float]]: + with path.open(newline="") as fp: + reader = csv.DictReader(fp) + if reader.fieldnames is None: + raise SystemExit(f"{path}: empty CSV") + missing = REQUIRED_COLUMNS - set(reader.fieldnames) + if missing: + raise SystemExit(f"{path}: missing columns: {', '.join(sorted(missing))}") + + rows: 
dict[int, dict[str, float]] = {} + for row in reader: + ctx = int(row["ctx_tokens"]) + rows[ctx] = { + "prefill_tps": float(row["prefill_tps"]), + "gen_tps": float(row["gen_tps"]), + } + if not rows: + raise SystemExit(f"{path}: no data rows") + return rows + + +def context_label(ctx: int) -> str: + if ctx < 1024: + return f"{ctx / 1024:g}k" + rounded_k = round(ctx / 1024) + if abs(ctx - rounded_k * 1024) <= max(4, ctx * 0.001): + return f"{rounded_k}k" + return f"{ctx / 1024:.1f}k" + + +def annotate_points(ax, xs: list[int], ys: list[float], color: str, dy: float) -> None: + for x, y in zip(xs, ys): + ax.annotate( + f"{y:.1f}", + (x, y), + textcoords="offset points", + xytext=(0, dy), + ha="center", + va="bottom" if dy >= 0 else "top", + fontsize=8, + color=color, + fontweight="medium", + ) + + +def plot_metric( + ax, + xs: list[int], + labels: list[str], + series: list[list[float]], + metric_title: str, + run_labels: list[str], + annotate: bool, +) -> None: + colors = ["#2563eb", "#64748b", "#ea580c", "#16a34a", "#9333ea", "#dc2626"] + markers = ["o", "s", "^", "D", "P", "X"] + + for i, (values, label) in enumerate(zip(series, run_labels)): + color = colors[i % len(colors)] + ax.plot( + xs, + values, + marker=markers[i % len(markers)], + markersize=7, + linewidth=2.4, + color=color, + label=label, + ) + + if len(series) == 2: + ax.fill_between(xs, series[0], series[1], color=colors[1], alpha=0.08) + + ax.set_title(metric_title, fontsize=15, fontweight="bold", pad=12) + ax.set_xlabel("Context Size") + ax.set_ylabel("Tokens/sec") + ax.set_xticks(xs, labels) + ax.grid(True, color="#d1d5db", linewidth=0.9, alpha=0.65) + ax.set_axisbelow(True) + ax.margins(x=0.05, y=0.18) + + for spine in ("top", "right"): + ax.spines[spine].set_visible(False) + ax.spines["left"].set_color("#9ca3af") + ax.spines["bottom"].set_color("#9ca3af") + + if len(series) == 2: + gain_color = "#14532d" + ymin, ymax = ax.get_ylim() + label_y = ymin + (ymax - ymin) * 0.05 + for x, b, a in zip(xs, 
series[0], series[1]): + gain = ((a / b) - 1.0) * 100.0 if b else 0.0 + ax.annotate( + f"{gain:+.0f}%", + (x, label_y), + ha="center", + va="center", + fontsize=8, + color=gain_color if gain >= 0 else "#991b1b", + bbox={ + "boxstyle": "round,pad=0.24", + "facecolor": "#ecfdf5" if gain >= 0 else "#fef2f2", + "edgecolor": "#bbf7d0" if gain >= 0 else "#fecaca", + "linewidth": 0.8, + }, + ) + + if annotate: + offsets = [-16, 8, 22, 36, 50, 64] + for i, values in enumerate(series): + annotate_points(ax, xs, values, colors[i % len(colors)], offsets[i % len(offsets)]) + + +def default_run_labels(paths: list[Path], args: argparse.Namespace) -> list[str]: + if len(paths) == 2 and not args.labels: + return [args.before_label, args.after_label] + if args.labels: + if len(args.labels) != len(paths): + raise SystemExit("--labels count must match the number of CSV runs") + return args.labels + return [path.stem for path in paths] + + +def build_chart(args: argparse.Namespace) -> None: + if len(args.runs) < 2: + raise SystemExit("provide at least two ds4-bench CSV files") + runs = [read_run(path) for path in args.runs] + run_labels = default_run_labels(args.runs, args) + contexts = sorted(set.intersection(*(set(run) for run in runs))) + if not contexts: + raise SystemExit("the CSV files have no shared ctx_tokens values") + + x_positions = list(range(len(contexts))) + labels = [context_label(ctx) for ctx in contexts] + prefill_series = [[run[ctx]["prefill_tps"] for ctx in contexts] for run in runs] + gen_series = [[run[ctx]["gen_tps"] for ctx in contexts] for run in runs] + + plt.rcParams.update( + { + "figure.facecolor": "#f8fafc", + "axes.facecolor": "#ffffff", + "axes.edgecolor": "#cbd5e1", + "axes.labelcolor": "#111827", + "xtick.color": "#111827", + "ytick.color": "#111827", + "font.family": "DejaVu Sans", + } + ) + + fig, axes = plt.subplots(1, 2, figsize=(15.5, 7), constrained_layout=True) + fig.suptitle(args.title, fontsize=22, fontweight="bold", y=1.04) + + plot_metric( + 
axes[0], + x_positions, + labels, + prefill_series, + "Prompt Processing Speed", + run_labels, + not args.no_values, + ) + plot_metric( + axes[1], + x_positions, + labels, + gen_series, + "Text Generation Speed", + run_labels, + not args.no_values, + ) + + handles, legend_labels = axes[0].get_legend_handles_labels() + fig.legend( + handles, + legend_labels, + loc="upper center", + bbox_to_anchor=(0.5, 0.98), + ncol=min(len(run_labels), 4), + frameon=True, + fancybox=True, + shadow=False, + facecolor="#ffffff", + edgecolor="#cbd5e1", + ) + + output = args.output + if output.suffix.lower() != ".png": + raise SystemExit(f"{output}: output must be a .png file") + output.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output, dpi=180, bbox_inches="tight", format="png") + plt.close(fig) + + print(f"Wrote {output}") + header = ["ctx"] + for label in run_labels: + safe = label.lower().replace(" ", "_") + header.extend([f"prefill_{safe}", f"gen_{safe}"]) + for label in run_labels[1:]: + safe = label.lower().replace(" ", "_") + base = run_labels[0].lower().replace(" ", "_") + header.extend([f"prefill_gain_{safe}_vs_{base}", f"gen_gain_{safe}_vs_{base}"]) + print(",".join(header)) + for idx, ctx in enumerate(contexts): + row = [str(ctx)] + base_prefill = prefill_series[0][idx] + base_gen = gen_series[0][idx] + for prefill, gen in zip(prefill_series, gen_series): + row.extend([f"{prefill[idx]:.2f}", f"{gen[idx]:.2f}"]) + for prefill, gen in zip(prefill_series[1:], gen_series[1:]): + prefill_gain = ((prefill[idx] / base_prefill) - 1.0) * 100.0 if base_prefill else 0.0 + gen_gain = ((gen[idx] / base_gen) - 1.0) * 100.0 if base_gen else 0.0 + row.extend([f"{prefill_gain:.1f}", f"{gen_gain:.1f}"]) + print(",".join(row)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Create a two-panel comparison chart from ds4-bench CSV files." 
+ ) + parser.add_argument("runs", nargs="+", type=Path, help="ds4-bench CSV files; first is the baseline") + parser.add_argument( + "-o", + "--output", + type=Path, + default=Path("/tmp/ds4-bench-compare.png"), + help="output chart path; must end in .png", + ) + parser.add_argument("--before-label", default="standard kernel") + parser.add_argument("--after-label", default="Metal Tensor") + parser.add_argument("--labels", nargs="+", help="Labels for each CSV run.") + parser.add_argument("--title", default="ds4-bench Speed Comparison") + parser.add_argument("--no-values", action="store_true", help="hide per-point value labels") + return parser.parse_args() + + +if __name__ == "__main__": + build_chart(parse_args()) diff --git a/speed-bench/compare_logit_drift.py b/speed-bench/compare_logit_drift.py new file mode 100644 index 00000000..140d68ee --- /dev/null +++ b/speed-bench/compare_logit_drift.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +"""Compare full-logit dumps produced by ./ds4 --dump-logits. 
+ +Example: + ./ds4 -m q2.gguf --metal -mt off --dump-logits /tmp/q2-off.json \ + --nothink --prompt-file prompt.txt + ./ds4 -m q2.gguf --metal -mt auto --dump-logits /tmp/q2-mt.json \ + --nothink --prompt-file prompt.txt + ./ds4 -m q4.gguf --metal -mt off --dump-logits /tmp/q4-off.json \ + --nothink --prompt-file prompt.txt + python3 speed-bench/compare_logit_drift.py /tmp/q2-off.json \ + /tmp/q2-mt.json /tmp/q4-off.json --labels q2_mt q4_off +""" + +from __future__ import annotations + +import argparse +import json +import math +from heapq import nlargest +from pathlib import Path +from typing import Any + + +def load_dump(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as fp: + data = json.load(fp) + logits_raw = data.get("logits") + if not isinstance(logits_raw, list) or not logits_raw: + raise SystemExit(f"{path}: missing non-empty logits array") + logits = [float("nan") if v is None else float(v) for v in logits_raw] + vocab = int(data.get("vocab", len(logits))) + if vocab != len(logits): + raise SystemExit(f"{path}: vocab={vocab} does not match logits={len(logits)}") + data["logits"] = logits + data["_path"] = str(path) + return data + + +def dump_label(data: dict[str, Any]) -> str: + model = Path(str(data.get("model", data.get("_path", "dump")))).name + quant = data.get("quant_bits", "?") + mt = data.get("mt", "?") + return f"{model}:q{quant}:mt={mt}" + + +def finite_indices(logits: list[float]) -> list[int]: + return [i for i, v in enumerate(logits) if math.isfinite(v)] + + +def topk(logits: list[float], k: int) -> list[int]: + # Match the C test's tie behavior: higher logit first, lower token id first. 
+ return nlargest(k, finite_indices(logits), key=lambda i: (logits[i], -i)) + + +def overlap(a: list[int], b: list[int], k: int) -> int: + return len(set(a[:k]) & set(b[:k])) + + +def rank_delta(ref_top: list[int], cand_top: list[int]) -> int: + cand_rank = {token: i for i, token in enumerate(cand_top)} + worst = 0 + for i, token in enumerate(ref_top): + if token in cand_rank: + worst = max(worst, abs(cand_rank[token] - i)) + return worst + + +def top_union_max_abs( + ref: list[float], + cand: list[float], + ref_top: list[int], + cand_top: list[int], + k: int, +) -> float: + ids = set(ref_top[:k]) | set(cand_top[:k]) + worst = 0.0 + for token in ids: + if math.isfinite(ref[token]) and math.isfinite(cand[token]): + worst = max(worst, abs(cand[token] - ref[token])) + return worst + + +def compare(ref_dump: dict[str, Any], cand_dump: dict[str, Any], top_k: int) -> dict[str, Any]: + ref = ref_dump["logits"] + cand = cand_dump["logits"] + if len(ref) != len(cand): + raise SystemExit( + f"vocab mismatch: {ref_dump['_path']} has {len(ref)}, " + f"{cand_dump['_path']} has {len(cand)}" + ) + + ref_top = topk(ref, top_k) + cand_top = topk(cand, top_k) + sumsq = 0.0 + max_abs = 0.0 + nonfinite = 0 + largest: list[tuple[float, int, float, float]] = [] + for token, (rv, cv) in enumerate(zip(ref, cand)): + if not math.isfinite(rv) or not math.isfinite(cv): + nonfinite += 1 + continue + delta = cv - rv + abs_delta = abs(delta) + sumsq += delta * delta + max_abs = max(max_abs, abs_delta) + if len(largest) < 5: + largest.append((abs_delta, token, rv, cv)) + largest.sort(reverse=True) + elif abs_delta > largest[-1][0]: + largest[-1] = (abs_delta, token, rv, cv) + largest.sort(reverse=True) + + return { + "same_top1": bool(ref_top and cand_top and ref_top[0] == cand_top[0]), + "ref_top1": ref_top[0] if ref_top else None, + "cand_top1": cand_top[0] if cand_top else None, + "top5_overlap": overlap(ref_top, cand_top, min(5, top_k)), + "top20_overlap": overlap(ref_top, cand_top, min(20, 
top_k)), + "top_k": top_k, + "max_rank_delta": rank_delta(ref_top, cand_top), + "rms": math.sqrt(sumsq / len(ref)), + "max_abs": max_abs, + "top20_max_abs": top_union_max_abs(ref, cand, ref_top, cand_top, min(20, top_k)), + "nonfinite": nonfinite, + "largest_deltas": [ + {"token": token, "ref": rv, "cand": cv, "abs": abs_delta} + for abs_delta, token, rv, cv in largest + ], + } + + +def print_table(rows: list[dict[str, Any]]) -> None: + headers = [ + "candidate", + "same_top1", + "top5", + "top20", + "rank", + "rms", + "max_abs", + "top20_abs", + "nonfinite", + ] + print(" | ".join(headers)) + print(" | ".join("-" * len(h) for h in headers)) + for row in rows: + print( + " | ".join( + [ + row["label"], + "yes" if row["same_top1"] else "no", + f"{row['top5_overlap']}/5", + f"{row['top20_overlap']}/20", + str(row["max_rank_delta"]), + f"{row['rms']:.6g}", + f"{row['max_abs']:.6g}", + f"{row['top20_max_abs']:.6g}", + str(row["nonfinite"]), + ] + ) + ) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Compare ds4 full-logit JSON dumps from --dump-logits." 
+ ) + parser.add_argument("reference", type=Path) + parser.add_argument("candidates", nargs="+", type=Path) + parser.add_argument("--labels", nargs="+", help="Labels for candidate dumps.") + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--json-output", type=Path) + args = parser.parse_args() + + if args.top_k < 20: + raise SystemExit("--top-k must be at least 20") + if args.labels and len(args.labels) != len(args.candidates): + raise SystemExit("--labels count must match candidate count") + + ref = load_dump(args.reference) + candidates = [load_dump(path) for path in args.candidates] + labels = args.labels or [dump_label(data) for data in candidates] + + print(f"reference: {dump_label(ref)}") + print( + "prompt_tokens: " + f"{ref.get('prompt_tokens', '?')} ctx: {ref.get('ctx', '?')} " + f"vocab: {ref.get('vocab', len(ref['logits']))}" + ) + rows = [] + for label, candidate in zip(labels, candidates): + if candidate.get("prompt_tokens") != ref.get("prompt_tokens"): + print( + f"warning: prompt token mismatch for {label}: " + f"ref={ref.get('prompt_tokens')} cand={candidate.get('prompt_tokens')}" + ) + metrics = compare(ref, candidate, args.top_k) + metrics["label"] = label + metrics["path"] = candidate["_path"] + rows.append(metrics) + + print_table(rows) + for row in rows: + print(f"\n{row['label']} largest deltas:") + for delta in row["largest_deltas"]: + print( + " token={token} ref={ref:.9g} cand={cand:.9g} abs={abs:.9g}".format( + **delta + ) + ) + + if args.json_output: + payload = { + "reference": {"path": ref["_path"], "label": dump_label(ref)}, + "rows": rows, + } + with args.json_output.open("w", encoding="utf-8") as fp: + json.dump(payload, fp, indent=2) + fp.write("\n") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/speed-bench/run_metal_tensor_bench.sh b/speed-bench/run_metal_tensor_bench.sh new file mode 100755 index 00000000..2541178f --- /dev/null +++ 
b/speed-bench/run_metal_tensor_bench.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd "$(dirname "${BASH_SOURCE[0]}")/.." + +PROMPT_FILE="${PROMPT_FILE:-speed-bench/promessi_sposi.txt}" +CTX_START="${CTX_START:-512}" +CTX_MAX="${CTX_MAX:-8192}" +STEP_MUL="${STEP_MUL:-2}" +GEN_TOKENS="${GEN_TOKENS:-128}" +OUT_DIR="${OUT_DIR:-/tmp}" +PYTHON="${PYTHON:-python3}" +OPEN_CHART="${OPEN_CHART:-1}" + +mkdir -p "$OUT_DIR" + +QUALITY_CSV="$OUT_DIR/ds4_bench_quality_${GEN_TOKENS}.csv" +STANDARD_CSV="$OUT_DIR/ds4_bench_standard_metal_${GEN_TOKENS}.csv" +TENSOR_CSV="$OUT_DIR/ds4_bench_tensor_metal_${GEN_TOKENS}.csv" +CHART="$OUT_DIR/ds4_bench_standard_quality_tensor_${GEN_TOKENS}.png" + +COMMON_ARGS=( + --prompt-file "$PROMPT_FILE" + --ctx-start "$CTX_START" + --ctx-max "$CTX_MAX" + --step-mul "$STEP_MUL" + --gen-tokens "$GEN_TOKENS" +) + +echo "1/3 Quality Metal -> $QUALITY_CSV" +./ds4-bench --quality "${COMMON_ARGS[@]}" --csv "$QUALITY_CSV" + +echo "2/3 Standard Metal -> $STANDARD_CSV" +DS4_METAL_MPP_DISABLE=1 ./ds4-bench "${COMMON_ARGS[@]}" --csv "$STANDARD_CSV" + +echo "3/3 Tensor Metal -> $TENSOR_CSV" +./ds4-bench "${COMMON_ARGS[@]}" --csv "$TENSOR_CSV" + +echo "Comparing runs -> $CHART" +"$PYTHON" speed-bench/compare_bench.py \ + "$STANDARD_CSV" \ + "$QUALITY_CSV" \ + "$TENSOR_CSV" \ + --labels "Standard Metal" "Quality Metal" "Tensor Metal" \ + --title "ds4-bench: Standard vs Quality vs Tensor (${GEN_TOKENS} generated tokens)" \ + -o "$CHART" + +echo +echo "Wrote:" +echo " $QUALITY_CSV" +echo " $STANDARD_CSV" +echo " $TENSOR_CSV" +echo " $CHART" + +if [[ "$OPEN_CHART" != "0" ]]; then + if command -v open >/dev/null 2>&1; then + open "$CHART" + elif command -v xdg-open >/dev/null 2>&1; then + xdg-open "$CHART" >/dev/null 2>&1 & + else + echo "No opener found; set OPEN_CHART=0 to skip this step." 
+ fi +fi diff --git a/tests/ds4_test.c b/tests/ds4_test.c index 0c9fd1cf..40ddd48f 100644 --- a/tests/ds4_test.c +++ b/tests/ds4_test.c @@ -226,7 +226,7 @@ static void test_metal_q8_0_mpp_matmul_case(const char *label, int have_mpp = ds4_gpu_matmul_q8_0_mpp_tensor( out_mpp, weights_raw, weight_alloc, 0, in_dim, out_dim, x, n_tok); if (!have_mpp) { - fprintf(stderr, "ds4-test: skipping MPP Q8_0 matmul %s; Metal 4 tensor API unavailable\n", + fprintf(stderr, "ds4-test: skipping Tensor Q8_0 matmul %s; Metal 4 tensor API unavailable\n", label); free(x_host); free(ref_host); @@ -255,7 +255,7 @@ static void test_metal_q8_0_mpp_matmul_case(const char *label, const float rms = (float)sqrt(sumsq / (double)((uint64_t)n_tok * out_dim)); if (max_abs >= 0.10f) { fprintf(stderr, - "ds4-test: MPP Q8_0 matmul %s in=%u out=%u tok=%u max_abs=%f rms=%f at token=%llu out=%llu ref=%f mpp=%f\n", + "ds4-test: Tensor Q8_0 matmul %s in=%u out=%u tok=%u max_abs=%f rms=%f at token=%llu out=%llu ref=%f tensor=%f\n", label, in_dim, out_dim, n_tok, max_abs, rms, (unsigned long long)(max_index / out_dim), (unsigned long long)(max_index % out_dim), @@ -869,12 +869,12 @@ static test_mpp_eq_result test_compare_mpp_logits(const test_mpp_eq_case *tc, }; fprintf(stderr, - "ds4-test: MPP equivalence %s top1 ref=%d cand=%d top5_overlap=%d/%d overlap=%d/%d max_rank_delta=%d rms=%g max_abs=%g top20_max_abs=%g\n", + "ds4-test: Tensor equivalence %s top1 ref=%d cand=%d top5_overlap=%d/%d overlap=%d/%d max_rank_delta=%d rms=%g max_abs=%g top20_max_abs=%g\n", tc->id, ref_top[0], cand_top[0], top5_overlap, TEST_MPP_EQ_TOP5, overlap, TEST_MPP_EQ_TOPK, max_rank_delta, rms, max_abs, top_abs); - fprintf(stderr, "ds4-test: MPP equivalence %s largest deltas:", tc->id); + fprintf(stderr, "ds4-test: Tensor equivalence %s largest deltas:", tc->id); for (int i = 0; i < TEST_MPP_EQ_DELTAS && delta_ids[i] >= 0; i++) { fprintf(stderr, " id=%d ref=%g cand=%g abs=%g", delta_ids[i], delta_ref[i], delta_cand[i], delta_abs[i]); 
@@ -997,7 +997,7 @@ static void test_mpp_summary_note_logits(test_mpp_eq_summary *summary, static void test_mpp_summary_print(const test_mpp_eq_summary *summary) { fprintf(stderr, - "ds4-test: MPP summary route=%s cases=%d capture_fail=%d logits_fail=%d greedy_fail=%d top1_mismatch=%d min_top5_overlap=%d/%d min_overlap=%d/%d worst_rank_delta=%d worst_rms=%g worst_max_abs=%g worst_top20_max_abs=%g\n", + "ds4-test: Tensor summary route=%s cases=%d capture_fail=%d logits_fail=%d greedy_fail=%d top1_mismatch=%d min_top5_overlap=%d/%d min_overlap=%d/%d worst_rank_delta=%d worst_rms=%g worst_max_abs=%g worst_top20_max_abs=%g\n", summary->label, summary->cases, summary->capture_failures, @@ -1018,7 +1018,7 @@ static void test_run_mpp_candidate(const char *label, ds4_mpp_mode mode, test_mpp_eq_case *cases, int ncase) { - fprintf(stderr, "ds4-test: MPP equivalence candidate route=%s mode=%s\n", + fprintf(stderr, "ds4-test: Tensor equivalence candidate route=%s mode=%s\n", label, ds4_mpp_mode_name(mode)); test_mpp_eq_summary summary; test_mpp_summary_init(&summary, label); @@ -1045,7 +1045,7 @@ static void test_run_mpp_candidate(const char *label, for (int j = 0; j < tc->ref_gen_len && j < cand_gen_len; j++) { if (cand_gen[j] != tc->ref_gen[j]) { fprintf(stderr, - "ds4-test: MPP equivalence %s greedy token mismatch step=%d ref=%d cand=%d\n", + "ds4-test: Tensor equivalence %s greedy token mismatch step=%d ref=%d cand=%d\n", tc->id, j, tc->ref_gen[j], cand_gen[j]); summary.greedy_failures++; } @@ -1343,7 +1343,7 @@ static const ds4_test_entry test_entries[] = { {"--tool-call-quality", "tool-call-quality", "model emits valid DSML tool calls", test_tool_call_quality}, {"--logprob-vectors", "logprob-vectors", "official API top-logprob vector comparison", test_official_logprob_vectors}, {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_kernel_group}, - {"--metal-mpp-equivalence", "metal-mpp-equivalence", "Metal MPP off/on prompt-logit 
and greedy equivalence", test_metal_mpp_equivalence}, + {"--metal-mpp-equivalence", "metal-mpp-equivalence", "Metal Tensor off/on prompt-logit and greedy equivalence", test_metal_mpp_equivalence}, #endif {"--server", "server", "server parser/rendering/cache unit tests", test_server_unit_group}, }; @@ -1364,9 +1364,9 @@ static void test_print_help(const char *prog) { puts(" DS4_TEST_MODEL=FILE Model path. Default: ds4flash.gguf"); puts(" DS4_TEST_LONG_PROMPT=FILE Rendered long-context story fact prompt."); puts(" DS4_TEST_VECTOR_FILE=FILE Simple official-vector fixture."); - puts(" DS4_TEST_MPP_EQ_CASE=NAME Run only MPP equivalence cases whose id contains NAME."); - puts(" DS4_TEST_MPP_EQ_FORCE_ON=1 Compare --mpp off against forced --mpp on instead of auto."); - puts(" DS4_TEST_MPP_EQ_MATRIX=1 Run auto and isolated forced MPP route rows."); + puts(" DS4_TEST_MPP_EQ_CASE=NAME Run only Tensor equivalence cases whose id contains NAME."); + puts(" DS4_TEST_MPP_EQ_FORCE_ON=1 Compare -mt off against forced -mt on instead of auto."); + puts(" DS4_TEST_MPP_EQ_MATRIX=1 Run auto and isolated forced Tensor route rows."); } static const ds4_test_entry *test_find_entry(const char *arg) { From 670411da4ee94e390408783c188d176bd0e60a0b Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 10:06:14 +0200 Subject: [PATCH 07/16] Stabilize HC mixer sigmoid behind DS4_METAL_HC_STABLE (default on) The HC=4 and scalar Sinkhorn split paths use 1/(1+exp(-z)) directly, which overflows when z is sufficiently negative (exp(-z) explodes). M5 Max's faster ALU is more likely than M3/M4 to push HC mixer inputs into that regime upstream, so the latent fragility may surface as logprob drift on M5 only. Replaces 1/(1+exp(-z)) with the identity 0.5*tanh(0.5*z) + 0.5 and 2/(1+exp(-z)) with 1 + tanh(0.5*z). Bounded across the full float range. 
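The substitution above can be sanity-checked outside the kernel. This is a standalone Python sketch (not part of the patch) that verifies the sigmoid/tanh identities and shows where the naive form breaks down:

```python
import math

# Historical form: 1/(1+exp(-z)). exp(-z) overflows for large negative z.
def sigmoid_naive(z):
    return 1.0 / (1.0 + math.exp(-z))

# Stable form used behind DS4_METAL_HC_STABLE: bounded in [0, 1] for all finite z.
def sigmoid_stable(z):
    return 0.5 * math.tanh(0.5 * z) + 0.5

# 2/(1+exp(-z)) == 1 + tanh(z/2).
def twice_sigmoid_stable(z):
    return 1.0 + math.tanh(0.5 * z)

# The identities agree wherever the naive form is well behaved.
for z in (-20.0, -1.0, 0.0, 1.0, 20.0):
    assert abs(sigmoid_naive(z) - sigmoid_stable(z)) < 1e-12
    assert abs(2.0 * sigmoid_naive(z) - twice_sigmoid_stable(z)) < 1e-12

# For strongly negative z the naive form's exp(-z) overflows
# (Python raises OverflowError; float32 exp saturates to +inf),
# while the tanh form stays exact at the limit.
try:
    sigmoid_naive(-1000.0)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True
assert naive_overflowed
assert sigmoid_stable(-1000.0) == 0.0
```

The same identity holds elementwise for the float4 variants in the Metal diff below; only the scalar case is checked here.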
The iter-0 vs iter-1+ epsilon application difference is left intact --
it is mirrored identically in the scalar reference path and appears to
be an intentional Sinkhorn warm-up.

Gated by DS4_METAL_HC_STABLE so the historical form can be A/B'd.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 metal/dsv4_hc.metal | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/metal/dsv4_hc.metal b/metal/dsv4_hc.metal
index 89cf6c65..49636f54 100644
--- a/metal/dsv4_hc.metal
+++ b/metal/dsv4_hc.metal
@@ -77,6 +77,24 @@ struct ds4_metal_args_dsv4_hc_expand {
     int32_t has_add;
 };
 
+// Numerically stable sigmoid. The naive form 1/(1+exp(-z)) overflows for large
+// negative z (exp(-z) blows up); replacing it with the 0.5*(tanh(z/2)+1) identity
+// keeps the value bounded in [0, 1] across the entire float range. Gated by
+// DS4_METAL_HC_STABLE so we can A/B vs the historical form on M5 Max where the
+// faster ALU is more likely to push HC mixer inputs into the unstable regime.
+#ifdef DS4_METAL_HC_STABLE
+static inline float ds4_hc_sigmoid(float z) { return 0.5f * tanh(0.5f * z) + 0.5f; }
+static inline float4 ds4_hc_sigmoid(float4 z) { return 0.5f * tanh(0.5f * z) + 0.5f; }
+// 2 * sigmoid(z) == 1 + tanh(z/2).
+static inline float ds4_hc_twice_sigmoid(float z) { return 1.0f + tanh(0.5f * z); }
+static inline float4 ds4_hc_twice_sigmoid(float4 z) { return 1.0f + tanh(0.5f * z); }
+#else
+static inline float ds4_hc_sigmoid(float z) { return 1.0f / (1.0f + exp(-z)); }
+static inline float4 ds4_hc_sigmoid(float4 z) { return 1.0f / (1.0f + exp(-z)); }
+static inline float ds4_hc_twice_sigmoid(float z) { return 2.0f / (1.0f + exp(-z)); }
+static inline float4 ds4_hc_twice_sigmoid(float4 z) { return 2.0f / (1.0f + exp(-z)); }
+#endif
+
 // Splits an HC mixer row into pre weights, post gates, and the HC-to-HC
 // combination matrix. The 4-channel path is specialized because DS4 Flash uses
 // HC=4 in normal inference, while the scalar fallback keeps diagnostics usable.
@@ -109,12 +127,12 @@ kernel void kernel_dsv4_hc_split_sinkhorn(
     const float4 pre_z =
         *((device const float4 *) mix) * pre_scale +
         *((device const float4 *) base);
-    *((device float4 *) out) = 1.0f / (1.0f + exp(-pre_z)) + epsv;
+    *((device float4 *) out) = ds4_hc_sigmoid(pre_z) + epsv;
 
     const float4 post_z =
         *((device const float4 *) (mix + 4)) * post_scale +
         *((device const float4 *) (base + 4));
-    *((device float4 *) (out + 4)) = 2.0f / (1.0f + exp(-post_z));
+    *((device float4 *) (out + 4)) = ds4_hc_twice_sigmoid(post_z);
 
     float4 r0 =
         *((device const float4 *) (mix + 8)) * comb_scale +
@@ -172,13 +190,13 @@ kernel void kernel_dsv4_hc_split_sinkhorn(
 
     for (int i = 0; i < HC; ++i) {
         const float z = mix[i] * pre_scale + base[i];
-        out[i] = 1.0f / (1.0f + exp(-z)) + epsv;
+        out[i] = ds4_hc_sigmoid(z) + epsv;
     }
 
     for (int i = 0; i < HC; ++i) {
         const int off = HC + i;
         const float z = mix[off] * post_scale + base[off];
-        out[off] = 2.0f / (1.0f + exp(-z));
+        out[off] = ds4_hc_twice_sigmoid(z);
     }
 
     float c[HC_MAX*HC_MAX];

From ae34183525cb7f16aad636b0f3f06928fdc53829 Mon Sep 17 00:00:00 2001
From: Ivan Fioravanti 
Date: Wed, 13 May 2026 10:06:25 +0200
Subject: [PATCH 08/16] Unify RMSNorm scale formula behind DS4_METAL_NORM_RSQRT_DISABLE (default on)

kernel_rms_norm_fuse_impl uses 1.0f/sqrt(mean+eps); the fused
kernel_dsv4_qkv_rms_norm_f32_4 was using rsqrt(...) for the same value.
Apple Silicon's hardware rsqrt has implementation-defined precision and
can differ from 1.0f/sqrt by ~1 ULP. Across the 43 layers of DeepSeek V4
Flash that per-layer ULP drift compounds visibly, and the rounding gap
between rsqrt and div+sqrt isn't guaranteed to match between M3/M4 and
M5 hardware families.

Switch the fused QKV norm to 1.0f/sqrt(...) so both norm kernels share a
single formula. Gated by DS4_METAL_NORM_RSQRT_DISABLE so the rsqrt path
can be A/B'd.
Co-Authored-By: Claude Opus 4.7 (1M context) --- metal/norm.metal | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/metal/norm.metal b/metal/norm.metal index 5bc97179..89206704 100644 --- a/metal/norm.metal +++ b/metal/norm.metal @@ -145,7 +145,14 @@ kernel void kernel_dsv4_qkv_rms_norm_f32_4( sumf = shmem_f32[tiisg]; sumf = simd_sum(sumf); +#ifdef DS4_METAL_NORM_RSQRT_DISABLE + // Match the formula used by kernel_rms_norm_fuse_impl above so both RMSNorm + // entry points produce bit-identical scales. Hardware rsqrt() and 1.0f/sqrt() + // can differ by ~1 ULP and that difference compounds across 43 layers. + const float scale = 1.0f / sqrt(sumf / float(n) + args.eps); +#else const float scale = rsqrt(sumf / float(n) + args.eps); +#endif for (int i = tpitg.x; i < n4; i += ntg.x) { y[i] = (x[i] * scale) * w[i]; From 6240bdb38a800a7768ba00f5fa768309af5e331c Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 10:06:27 +0200 Subject: [PATCH 09/16] Add diagnostic DS4_METAL_KV_RAW_F32 to skip FP16 KV round-trip kernel_dsv4_kv_fp8_store_f32 deliberately writes the raw cache row as (float)((half)q) so its precision matches the half-typed FlashAttention KV buffer the indexer references. With DS4_METAL_KV_RAW_F32 set, the half cast is skipped and the FP8-dequantized FP32 value is written verbatim. This is diagnostic only: enabling it makes the indexer see higher- precision values than FlashAttention, which is a deliberate mismatch that reveals how much drift the FP16 quantization contributes but is not safe to ship. Default off. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- metal/dsv4_kv.metal | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/metal/dsv4_kv.metal b/metal/dsv4_kv.metal index 89bd7d3a..be760514 100644 --- a/metal/dsv4_kv.metal +++ b/metal/dsv4_kv.metal @@ -167,13 +167,25 @@ kernel void kernel_dsv4_kv_fp8_store_f32( if (off + (int)tid < n_nope) { const float q = dsv4_e4m3fn_dequant(clamp(v / fp8_scale, -448.0f, 448.0f)) * fp8_scale; kv[off + tid] = q; + // Diagnostic only: skip the FP16 round-trip that normally matches the + // half-typed FlashAttention KV buffer's precision. With this enabled the + // indexer will see higher-precision raw values than FlashAttention does, + // which is informative but not a production-ready setting. +#ifdef DS4_METAL_KV_RAW_F32 + raw[off + tid] = q; +#else raw[off + tid] = (float)((half)q); +#endif } threadgroup_barrier(mem_flags::mem_threadgroup); } for (int i = n_nope + tid; i < head_dim; i += 64) { +#ifdef DS4_METAL_KV_RAW_F32 + raw[i] = kv[i]; +#else raw[i] = (float)((half)kv[i]); +#endif } } From a8223179da411c3406c3aa70c05110e0634d239a Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 10:06:31 +0200 Subject: [PATCH 10/16] Add diagnostic DS4_METAL_ROPE_EXP2_LOG2 RoPE angle path Metal's pow(freq_base, k) is not IEEE-754 strict and the rounding can differ between GPU families. With DS4_METAL_ROPE_EXP2_LOG2 set, the RoPE angle is computed as exp2(k * log2(freq_base)) instead, using two primitives with tighter precision specifications. The change touches both the NeoX and default RoPE branches of kernel_dsv4_rope_tail_f32. Default off -- this is a diagnostic to quantify how much RoPE pow precision contributes to logprob drift on M5 Max relative to M3/M4. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- metal/dsv4_rope.metal | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/metal/dsv4_rope.metal b/metal/dsv4_rope.metal index aaa6f3d9..b3207561 100644 --- a/metal/dsv4_rope.metal +++ b/metal/dsv4_rope.metal @@ -110,7 +110,13 @@ kernel void kernel_dsv4_rope_tail_f32( const int ic = r; const int rel_i0 = 2*ic; +#ifdef DS4_METAL_ROPE_EXP2_LOG2 + // Equivalent to pow(freq_base, k) but expressed through IEEE-754 + // primitives that have tighter precision guarantees than Metal's pow(). + const float theta = theta_base * exp2(inv_ndims * (float)rel_i0 * log2(args.freq_base)); +#else const float theta = theta_base * pow(args.freq_base, inv_ndims*rel_i0); +#endif const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; float cos_theta; @@ -133,7 +139,11 @@ kernel void kernel_dsv4_rope_tail_f32( } const int ic = r/2; +#ifdef DS4_METAL_ROPE_EXP2_LOG2 + const float theta = theta_base * exp2(inv_ndims * (float)r * log2(args.freq_base)); +#else const float theta = theta_base * pow(args.freq_base, inv_ndims*r); +#endif const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; float cos_theta; From a544c53af2a6e5f1bdd1adee9fb2193e81ce80ed Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 10:09:16 +0200 Subject: [PATCH 11/16] Fix DS4_METAL_TENSOR_MATMUL_DISABLE host dispatch When the macro un-defines DS4_METAL_HAS_TENSOR at library compile time the cooperative-tensor _mpp kernel templates are no longer in the library, but g_metal4_tensor_api_enabled was still truthy so the host dispatch layer kept attempting to fetch them. The result was a flood of "Metal kernel kernel_mul_mm_*_mpp_* function not found" warnings on the legacy fallback path. Flip g_metal4_tensor_api_enabled = 0 inside the same branch so the host code's ds4_gpu_use_mpp_*() and ds4_gpu_*_mpp_tensor() guards see the disabled state and skip _mpp lookups entirely. 
Measured on M5 Max with the short reasoning prompt: drift between -mt off and DS4_METAL_TENSOR_MATMUL_DISABLE=1 -mt auto is now exactly zero (rms=0, max_abs=0, max_rank_delta=0), confirming that the M5 Max logprob drift is sourced entirely in the Metal 4 cooperative-tensor matmul codepath and not in HC, norm, RoPE, or KV. Co-Authored-By: Claude Opus 4.7 (1M context) --- ds4_metal.m | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ds4_metal.m b/ds4_metal.m index 092815c4..620eaf40 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -3795,7 +3795,10 @@ int ds4_gpu_init(void) { // Recompile without DS4_METAL_HAS_TENSOR so the cooperative-tensor // matmul branches are excluded from this build, isolating the // simdgroup_float8x8 path for an A/B vs the Tensor matmul on M5. + // Also flip g_metal4_tensor_api_enabled so the host dispatch + // skips _mpp kernel lookups that are no longer compiled. [macros removeObjectForKey:@"DS4_METAL_HAS_TENSOR"]; + g_metal4_tensor_api_enabled = 0; fprintf(stderr, "ds4: Metal 4 cooperative-tensor matmul disabled by DS4_METAL_TENSOR_MATMUL_DISABLE\n"); } fprintf(stderr, From eeed77eda551919e20ebb68ec3085f7b93d0ad50 Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 10:21:58 +0200 Subject: [PATCH 12/16] Default Metal Tensor Q8_0 matmul OFF on M5 Max Bisecting the M5 Max logprob drift on -mt auto: - -mt off baseline: reference - -mt auto (all routes): rms=0.150, max_abs=0.750, top20=0.263 - -mt auto + DS4_METAL_MPP_Q8_0_DISABLE=1: rms=0, max_abs=0 (exact) - -mt auto + DS4_METAL_MPP_F16_DISABLE=1: still rms=0.150 (no help) - -mt auto + DS4_METAL_MPP_ATTN_OUT_DISABLE=1: still rms=0.150 - -mt auto + DS4_METAL_MPP_MOE_{GATE,UP,DOWN}_DISABLE=1: still rms=0.150 The Metal 4 cooperative-tensor Q8_0 matmul (kernel_mul_mm_q8_0_f32_mpp and direct_rhs variants in dense.metal) is the *sole* drift source on M5 Max vs the legacy simdgroup_multiply_accumulate path. 
The other Tensor routes (F16 compressor, attention-output low projection, routed MoE gate/up/down) are bit-clean against -mt off. Flip ds4_gpu_mpp_q8_0_default_target() to return 0 when the device name contains "M5". Other Tensor routes continue to default on, so the Q8_0 carve-out preserves the bulk of the Metal Tensor speedup (F16 compressor at layers 0-19, MoE at layers 20+, attn-out at layers 32-42). Users who care more about prefill throughput than bit-equivalence can opt back in with DS4_METAL_MPP_Q8_0_ENABLE=1. Verified on M5 Max with default flags only: -mt auto now produces exactly the -mt off logits (rms=0, max_abs=0, max_rank_delta=0, same_top1=yes, top5_overlap=5/5, top20_overlap=20/20). Co-Authored-By: Claude Opus 4.7 (1M context) --- ds4_metal.m | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ds4_metal.m b/ds4_metal.m index 620eaf40..d46104a0 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -965,6 +965,13 @@ static int ds4_gpu_use_compressor_pair_nr4(void) { static int ds4_gpu_device_name_contains(const char *needle); static int ds4_gpu_mpp_q8_0_default_target(void) { + // The Metal 4 cooperative-tensor Q8_0 matmul on M5 Max produces logprob + // drift versus the legacy simdgroup_multiply_accumulate path (measured + // rms=0.150, max_abs=0.75 on the short reasoning prompt; bit-exact match + // recovered by disabling just this route). All other Tensor routes + // (F16 compressor, attention-output, MoE) are bit-clean. Default the + // Q8_0 Tensor matmul to OFF on M5; opt back in with DS4_METAL_MPP_Q8_0_ENABLE=1. + if (ds4_gpu_device_name_contains("M5")) return 0; return 1; } From 2dfac58f404a7fc67028c7c16ca6ad307c7c5e7d Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 10:22:30 +0200 Subject: [PATCH 13/16] Add DS4_METAL_MATH_SAFE diagnostic to pin shader library to IEEE-754 MTLCompileOptions.fastMathEnabled defaults to YES and Apple's headers explicitly note this "may violate the IEEE 754 standard". 
With safe math forced via MTLMathModeSafe (macOS 15+) or fastMathEnabled=NO (deprecated fallback), drift between -mt off and -mt auto on M5 Max shrinks ~4x (rms 0.150 -> 0.037, max_abs 0.75 -> 0.19) -- showing that fast-math optimizations applied differently across the two hardware paths were amplifying the underlying matmul2d divergence. Default OFF: enabling safe math also moves -mt off away from the fast-math production reference (rms=0.63 vs original fast-math baseline) so it isn't a drop-in fix. Useful as a diagnostic to localize remaining drift sources and as an option for users who prefer strict IEEE-754 semantics over fast-math speed. Co-Authored-By: Claude Opus 4.7 (1M context) --- ds4_metal.m | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/ds4_metal.m b/ds4_metal.m index d46104a0..b32faf2b 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -3791,9 +3791,32 @@ int ds4_gpu_init(void) { const int drift_norm_unify = ds4_gpu_env_bool("DS4_METAL_NORM_RSQRT_DISABLE") != 0; // default ON const int drift_kv_raw_f32 = ds4_gpu_env_bool("DS4_METAL_KV_RAW_F32") > 0; // default OFF const int drift_rope_exp2_log2 = ds4_gpu_env_bool("DS4_METAL_ROPE_EXP2_LOG2") > 0; // default OFF + const int drift_math_safe = ds4_gpu_env_bool("DS4_METAL_MATH_SAFE") > 0; // default OFF const int drift_tensor_matmul_off = g_metal4_tensor_api_enabled && ds4_gpu_env_bool("DS4_METAL_TENSOR_MATMUL_DISABLE") > 0; + if (drift_math_safe) { + // MTLCompileOptions.fastMathEnabled defaults to YES and Apple's + // headers explicitly say this "may violate the IEEE 754 standard". + // Different fast-math optimizations get applied across the + // matmul2d cooperative-tensor path and the legacy + // simdgroup_multiply_accumulate path on M5, amplifying the + // mismatch. MTLMathModeSafe pins the entire library to strict + // IEEE-754 semantics. 
Diagnostic-only: it also moves the + // -mt off output away from the fast-math reference, so this is + // useful to localize drift sources but not to ship as a default. + if (@available(macOS 15.0, *)) { + options.mathMode = MTLMathModeSafe; + fprintf(stderr, "ds4: Metal shader library math mode = safe (strict IEEE-754) by DS4_METAL_MATH_SAFE\n"); + } else { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + options.fastMathEnabled = NO; +#pragma clang diagnostic pop + fprintf(stderr, "ds4: Metal shader library fast-math disabled by DS4_METAL_MATH_SAFE (pre-macOS 15)\n"); + } + } + if (drift_hc_stable) macros[@"DS4_METAL_HC_STABLE"] = @"1"; if (drift_norm_unify) macros[@"DS4_METAL_NORM_RSQRT_DISABLE"] = @"1"; if (drift_kv_raw_f32) macros[@"DS4_METAL_KV_RAW_F32"] = @"1"; @@ -3809,11 +3832,12 @@ int ds4_gpu_init(void) { fprintf(stderr, "ds4: Metal 4 cooperative-tensor matmul disabled by DS4_METAL_TENSOR_MATMUL_DISABLE\n"); } fprintf(stderr, - "ds4: drift-patch flags hc_stable=%s norm_unify=%s kv_raw_f32=%s rope_exp2_log2=%s tensor_matmul=%s\n", + "ds4: drift-patch flags hc_stable=%s norm_unify=%s kv_raw_f32=%s rope_exp2_log2=%s math_safe=%s tensor_matmul=%s\n", drift_hc_stable ? "on" : "off", drift_norm_unify ? "on" : "off", drift_kv_raw_f32 ? "on" : "off", drift_rope_exp2_log2 ? "on" : "off", + drift_math_safe ? "on" : "off", (g_metal4_tensor_api_enabled && !drift_tensor_matmul_off) ? 
"on" : "off"); options.preprocessorMacros = macros; id library = [g_device newLibraryWithSource:source options:options error:&error]; From fd7e9fafb32ab92e8d76c393cc36ba2d2e61c766 Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 11:28:47 +0200 Subject: [PATCH 14/16] Fix: F16 compressor Tensor matmul incorrectly coupled to Q8 default The previous commit (75f0930) added the M5 carve-out by editing ds4_gpu_mpp_q8_0_default_target(), but that helper was also being reused as the default-target for ds4_gpu_use_mpp_f16_compressor_matmul (line 1363) and for the verbose memory-report banner that prints mpp_f16 (line 2102). That coupled F16 compressor default-on/off to the Q8 carve-out, which is wrong: the per-route bisection showed F16 is bit-clean on M5; only Q8 needed to flip default-off. Introduce a dedicated ds4_gpu_mpp_f16_default_target() that always returns 1 and use it at the two F16 call sites. The Q8 helper keeps its M5 carve-out unchanged. Verified on M5 Max with default flags: -mt auto still produces zero drift vs -mt off (rms=0, max_abs=0, max_rank_delta=0), and the F16 compressor Tensor route is now back to default-on on M5 as intended. Co-Authored-By: Claude Opus 4.7 (1M context) --- ds4_metal.m | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/ds4_metal.m b/ds4_metal.m index b32faf2b..c03925fa 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -968,13 +968,21 @@ static int ds4_gpu_mpp_q8_0_default_target(void) { // The Metal 4 cooperative-tensor Q8_0 matmul on M5 Max produces logprob // drift versus the legacy simdgroup_multiply_accumulate path (measured // rms=0.150, max_abs=0.75 on the short reasoning prompt; bit-exact match - // recovered by disabling just this route). All other Tensor routes + // recovered by disabling just this route). The other Tensor routes // (F16 compressor, attention-output, MoE) are bit-clean. 
Default the // Q8_0 Tensor matmul to OFF on M5; opt back in with DS4_METAL_MPP_Q8_0_ENABLE=1. if (ds4_gpu_device_name_contains("M5")) return 0; return 1; } +// F16 compressor Tensor matmul default. Bit-clean on M5 vs the legacy +// simdgroup path, so this stays default-on independent of device. +// Kept as a separate helper to avoid coupling the F16 default to the +// Q8_0 carve-out above. +static int ds4_gpu_mpp_f16_default_target(void) { + return 1; +} + static int ds4_gpu_env_value_eq(const char *v, size_t n, const char *literal) { size_t m = strlen(literal); if (n != m) return 0; @@ -1360,7 +1368,7 @@ static int ds4_gpu_can_use_mpp_q8_0_matmul(uint64_t n_tok) { } static int ds4_gpu_use_mpp_f16_compressor_matmul(void) { - const int enabled = ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_q8_0_default_target(), + const int enabled = ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_f16_default_target(), "DS4_METAL_MPP_F16_ENABLE", "DS4_METAL_MPP_F16_DISABLE"); if (enabled && !g_mpp_f16_reported) { @@ -2099,7 +2107,7 @@ void ds4_gpu_print_memory_report(const char *label) { (g_metal4_tensor_api_compile_supported ? "available" : "disabled"), g_metal4_m5_neural_accelerators_hint ? "likely" : "not detected"); const int mpp_q8 = ds4_gpu_mpp_q8_0_policy_enabled(); - const int mpp_f16 = ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_q8_0_default_target(), + const int mpp_f16 = ds4_gpu_mpp_route_enabled(ds4_gpu_mpp_f16_default_target(), "DS4_METAL_MPP_F16_ENABLE", "DS4_METAL_MPP_F16_DISABLE"); const int mpp_attn_out = ds4_gpu_mpp_route_enabled(0, From 08de0d464b4ca29683df9747dc7e9c072d74cdac Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 11:30:45 +0200 Subject: [PATCH 15/16] Fix Q8 MPP kernel test: reference must take the legacy path test_metal_q8_0_mpp_matmul_case() built the reference output by calling ds4_gpu_matmul_q8_0_tensor() after ds4_gpu_set_quality(false). 
The set_quality(false) call enables MPP routing, and the dispatcher at ds4_metal.m:6277 then routes to ds4_gpu_matmul_q8_0_mpp_tensor() when the MPP can_use gate passes. So on M5 with Metal 4 tensor API enabled, the "reference" was actually the MPP output, and the test compared the MPP kernel to itself -- the max_abs/rms numbers were always near zero and any divergence in the MPP kernel itself would not have been caught. Force ds4_gpu_set_quality(true) around the reference call so the dispatcher takes the legacy simdgroup_multiply_accumulate path, then restore set_quality(false) before invoking ds4_gpu_matmul_q8_0_mpp_tensor() directly for the candidate. The reference and candidate now exercise the two different code paths the test was originally meant to compare. Verified on M5 Max: ./ds4_test --metal-kernels still passes, meaning the M5 cooperative-tensor Q8 matmul agrees with the legacy path within the 0.10 max-abs kernel target on the test shapes. The systemic drift in -mt auto comes from many small matmul deltas compounding through 43 layers, not from any single kernel exceeding the per-call threshold. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/ds4_test.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/ds4_test.c b/tests/ds4_test.c index 40ddd48f..23b90563 100644 --- a/tests/ds4_test.c +++ b/tests/ds4_test.c @@ -219,9 +219,13 @@ static void test_metal_q8_0_mpp_matmul_case(const char *label, TEST_ASSERT(ds4_gpu_tensor_write(x, 0, x_host, x_bytes) != 0); TEST_ASSERT(ds4_gpu_set_model_map(weights_raw, weight_alloc) != 0); - ds4_gpu_set_quality(false); + // Force quality mode ON so the reference dispatcher takes the legacy + // simdgroup path; otherwise ds4_gpu_matmul_q8_0_tensor() routes to the + // MPP variant on M5+ and the test compares two MPP outputs to each other. 
+ ds4_gpu_set_quality(true); TEST_ASSERT(ds4_gpu_matmul_q8_0_tensor(out_ref, weights_raw, weight_alloc, 0, in_dim, out_dim, x, n_tok) != 0); + ds4_gpu_set_quality(false); int have_mpp = ds4_gpu_matmul_q8_0_mpp_tensor( out_mpp, weights_raw, weight_alloc, 0, in_dim, out_dim, x, n_tok); From 49c1137b815f1961045bb2e5f530aa4b7f2ba67a Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 13 May 2026 11:32:26 +0200 Subject: [PATCH 16/16] Update README to match new M5 Tensor defaults and refreshed drift numbers Two corrections triggered by another reviewer's audit: 1. The auto-suite description claimed "auto enables Q8_0 prefill ..."; on M5 that is no longer true now that 75f0930 defaults Q8_0 Tensor off on M5. Reword the section so it lists F16 compressor, attn-out, and MoE as the auto-enabled routes, then call out the M5 carve-out for Q8_0 explicitly with the env-var opt-in. 2. Refresh worst-case suite numbers measured on the current branch (codex/metal4-m5-drift-patches after the F16-coupling fix 78fa48f and the test-self-reference fix 580e896) on M5 Max: worst_rms = 0.169 (was documented ~= 0.170) worst_top20_max_abs = 0.306 (was documented ~= 0.342) worst_max_abs = 0.922 min_top5_overlap = 5/5 min_top20_overlap = 20/20 (was 19/20) worst_rank_delta = 1 Three short fixtures (short_italian_fact, short_code_completion, short_reasoning_plain) are now bit-exact (rms=0); the residual drift is concentrated on the two long-context fixtures and comes from the F16 compressor, attention-output, and routed-MoE Tensor routes still being default-on, compounding small per-matmul deltas through 43 layers. The Q8_0 isolation paragraph also picks up the M5 default-off note so the env-var docs stay consistent with the runtime behavior. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 33d282c9..b909f58c 100644 --- a/README.md +++ b/README.md @@ -262,9 +262,14 @@ turning on every direct-RHS route at once when the global The Q8_0 prefill Tensor route can be isolated with `DS4_METAL_MPP_Q8_0_ENABLE=1` or `DS4_METAL_MPP_Q8_0_DISABLE=1`. It only -affects prompt batches larger than eight tokens. By default, Q8_0 uses the late -full-model-safe layer window 38..42 plus `attn_q_b` in layers 32..37 for all -prompt batch sizes. It +affects prompt batches larger than eight tokens. **On M5 the Q8_0 Tensor +route is default-off**: bisection on M5 Max showed it was the sole source +of the M5-only `-mt auto` vs `-mt off` logit drift while the other Tensor +routes (F16 compressor, attention-output, MoE) stayed bit-clean on short +prompts. Set `DS4_METAL_MPP_Q8_0_ENABLE=1` to opt back in. On non-M5 +devices Q8_0 stays default-on and uses the late full-model-safe layer +window 38..42 plus `attn_q_b` in layers 32..37 for all prompt batch +sizes. It uses 64-token tiles below 4096-token batches and 32-token tiles for larger prompt batches on M5, accepts partial token tails, and falls back to the legacy kernel when the Metal 4 tensor path is unavailable. When macOS reports Low @@ -304,16 +309,23 @@ shape, max absolute error, RMS, and the largest element deltas. Set `DS4_METAL_MPP_COMPARE_VERBOSE=1` to print passing comparisons as well. Current Tensor route status balances drift with prefill throughput: `auto` enables -Q8_0 prefill, F16 compressor, attention-output low projection, and routed-MoE -Tensor. Attention-output low projection now uses layers 32..42 by default, while -Q8_0 uses the narrower `attn_q_b` 32..37 plus all-Q8 38..42 window by default. -Routed-MoE Tensor now uses the lower-drift conservative default window: -gate/up from layer 20 and down from layer 22. 
This gives up some of the -all-layer prefill speedup to avoid the larger drift seen with the previous -broader Q8_0 and layer-0 routed-MoE Tensor windows. The current auto suite -reports same-top1/same-greedy agreement with minimum top-5 overlap `5/5`, -minimum top-20 overlap `19/20`, `worst_rms ~= 0.170`, and -`worst_top20_max_abs ~= 0.342`. The Q8_0 and attention-output low Tensor +F16 compressor, attention-output low projection, and routed-MoE Tensor. The +Q8_0 prefill Tensor route is enabled by default on pre-M5 devices and +**default-off on M5**, where bisection traced the entire `-mt auto` vs +`-mt off` drift to that single route; opt back in with +`DS4_METAL_MPP_Q8_0_ENABLE=1`. Attention-output low projection uses layers +32..42 by default, Q8_0 (when enabled) uses the narrower `attn_q_b` 32..37 +plus all-Q8 38..42 window by default, and routed-MoE Tensor uses the +lower-drift conservative default window: gate/up from layer 20 and down +from layer 22. This gives up some of the all-layer prefill speedup to +avoid the larger drift seen with the previous broader Q8_0 and layer-0 +routed-MoE Tensor windows. The current auto suite on M5 reports +same-top1/same-greedy agreement on all five fixtures with minimum top-5 +overlap `5/5`, minimum top-20 overlap `20/20`, `worst_rms ~= 0.169`, and +`worst_top20_max_abs ~= 0.306` (three short fixtures are bit-exact; +residual drift is concentrated on the two long-context fixtures and +comes from the still-enabled F16/attn-out/MoE Tensor routes compounding +through 43 layers). The Q8_0 and attention-output low Tensor kernels stage activation tiles through half to match the legacy Metal matmul input path, which brings the isolated model-ish Q8_0 regression under the strict kernel target and removes the first attention-output comparator breach.