diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index ef2f202ec9b..4bf6d2bcb03 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -101,10 +101,6 @@ void main() {
     const uint lane = gl_SubgroupInvocationID;
 
     float probs[experts_per_thread];
-    [[unroll]]
-    for (int i = 0; i < experts_per_thread; i++) {
-        probs[i] = -INFINITY;
-    }
 
     [[unroll]]
     for (uint i = 0; i < n_experts; i += WARP_SIZE) {
@@ -116,9 +112,8 @@ void main() {
         softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
     } else if (gating_func == GATING_FUNC_SIGMOID) {
         [[unroll]]
-        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
-            const uint expert = i + lane;
-            probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
+        for (int i = 0; i < experts_per_thread; i++) {
+            probs[i] = 1.f / (1.f + exp(-probs[i]));
         }
     }
@@ -155,11 +150,11 @@ void main() {
         uint max_expert = lane;
 
         [[unroll]]
-        for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
-            const uint expert = i + lane;
-            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
-                max_val = probs[i / WARP_SIZE];
-                max_val_s = selection_probs[i / WARP_SIZE];
+        for (int i = 1; i < experts_per_thread; i++) {
+            const uint expert = lane + i * WARP_SIZE;
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
+                max_val = probs[i];
+                max_val_s = selection_probs[i];
                 max_expert = expert;
             }
         }
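The rewritten loops index each lane's private slots directly: slot `i` of lane `lane` owns expert `lane + i * WARP_SIZE`, and the `n_experts % WARP_SIZE == 0 || expert < n_experts` guard is the only bounds check left, which is what makes the `-INFINITY` pre-initialization removable. A minimal host-side sketch of that ownership mapping (the constants are illustrative, not from the patch):

```cpp
#include <cstdio>

// Illustrative values: a 32-lane subgroup and an expert count that is
// not a multiple of the warp size, so the guard below actually fires.
constexpr unsigned WARP_SIZE          = 32;
constexpr unsigned n_experts          = 40;
constexpr unsigned experts_per_thread = (n_experts + WARP_SIZE - 1) / WARP_SIZE;

int main() {
    for (unsigned lane = 0; lane < WARP_SIZE; lane++) {
        for (unsigned i = 0; i < experts_per_thread; i++) {
            const unsigned expert   = lane + i * WARP_SIZE; // slot i of this lane
            const bool     in_range = n_experts % WARP_SIZE == 0 || expert < n_experts;
            if (in_range) {
                printf("lane %2u slot %u -> expert %2u\n", lane, i, expert);
            }
        }
    }
    return 0;
}
```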
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index 84d88e81d45..68244331b14 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -234,6 +234,58 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
     return result;
 }
 
+/** MulMatVec **/
+
+struct ggml_webgpu_mul_mat_vec_shader_lib_context {
+    ggml_type src0_type;
+    ggml_type src1_type;
+    bool      vec4;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_vec_shader(
+    pre_wgsl::Preprocessor &                           preprocessor,
+    const char *                                       shader_src,
+    const ggml_webgpu_mul_mat_vec_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = "mul_mat_vec";
+
+    if (context.src0_type == GGML_TYPE_F32 && context.src1_type == GGML_TYPE_F32) {
+        if (context.vec4) {
+            defines.push_back("MUL_MAT_VEC_F32_F32_VEC");
+            variant += "_f32_f32_vec";
+        } else {
+            defines.push_back("MUL_MAT_VEC_F32_F32");
+            variant += "_f32_f32";
+        }
+    } else if (context.src0_type == GGML_TYPE_F16 && context.src1_type == GGML_TYPE_F32) {
+        if (context.vec4) {
+            defines.push_back("MUL_MAT_VEC_F16_F32_VEC");
+            variant += "_f16_f32_vec";
+        } else {
+            defines.push_back("MUL_MAT_VEC_F16_F32");
+            variant += "_f16_f32";
+        }
+    } else if (context.src0_type == GGML_TYPE_F16 && context.src1_type == GGML_TYPE_F16) {
+        if (context.vec4) {
+            defines.push_back("MUL_MAT_VEC_F16_F16_VEC");
+            variant += "_f16_f16_vec";
+        } else {
+            defines.push_back("MUL_MAT_VEC_F16_F16");
+            variant += "_f16_f16";
+        }
+    } else if (context.src0_type == GGML_TYPE_Q4_0 && context.src1_type == GGML_TYPE_F32 && !context.vec4) {
+        defines.push_back("MUL_MAT_VEC_Q4_0_F32");
+        variant += "_q4_0_f32";
+    } else {
+        GGML_ABORT("Unsupported types for mul_mat_vec shader");
+    }
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl    = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    return result;
+}
+
 /** Pad **/
 
 struct ggml_webgpu_pad_pipeline_key {
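The `if`/`else` ladder mirrors the seven shader variants one-to-one. As a sketch of an alternative (not how the patch does it), the same mapping could be table-driven; `ggml_type` comes from `ggml.h`, everything else here is hypothetical:

```cpp
#include "ggml.h"

// Hypothetical lookup table mirroring the ladder above: one entry per
// supported (src0, src1, vec4) combination with its define and suffix.
struct mul_mat_vec_variant_entry {
    ggml_type    src0;
    ggml_type    src1;
    bool         vec4;
    const char * define;
    const char * suffix;
};

static constexpr mul_mat_vec_variant_entry kMulMatVecVariants[] = {
    { GGML_TYPE_F32,  GGML_TYPE_F32, true,  "MUL_MAT_VEC_F32_F32_VEC", "_f32_f32_vec" },
    { GGML_TYPE_F32,  GGML_TYPE_F32, false, "MUL_MAT_VEC_F32_F32",     "_f32_f32"     },
    { GGML_TYPE_F16,  GGML_TYPE_F32, true,  "MUL_MAT_VEC_F16_F32_VEC", "_f16_f32_vec" },
    { GGML_TYPE_F16,  GGML_TYPE_F32, false, "MUL_MAT_VEC_F16_F32",     "_f16_f32"     },
    { GGML_TYPE_F16,  GGML_TYPE_F16, true,  "MUL_MAT_VEC_F16_F16_VEC", "_f16_f16_vec" },
    { GGML_TYPE_F16,  GGML_TYPE_F16, false, "MUL_MAT_VEC_F16_F16",     "_f16_f16"     },
    { GGML_TYPE_Q4_0, GGML_TYPE_F32, false, "MUL_MAT_VEC_Q4_0_F32",    "_q4_0_f32"    },
};
```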
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 22e2bfeb4ce..5c307e296a2 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -105,10 +105,10 @@
 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
 
 // Matrix-vector multiplication parameters
-#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
+#define WEBGPU_MUL_MAT_VEC_WG_SIZE 32
 // Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
-#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
-#define WEBGPU_MUL_MAT_VEC_TILE_K 256
+#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 16 // TODO: tune; previously 64
+// TILE_K is now a shader-side constant (see mul_mat_vec.tmpl.wgsl)
 
 /* End Constants */
@@ -1125,15 +1125,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     if (use_fast) {
         int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
-        if (dst->ne[1] == 1) {
+        if (dst->ne[1] == 1) { // matrix-vector case (n == 1)
             // We don't support vectorized mul_mat_vec for quantized types
             vectorized = vectorized && (src0->type < 2);
             pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
             uint32_t batches       = dst->ne[2] * dst->ne[3];
             uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
             uint32_t total_wg      = output_groups * batches;
-            wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
-            wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
+            wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
+            wg_y = CEIL_DIV(total_wg, wg_x);
         } else {
             pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
             uint32_t wg_m;
@@ -2602,25 +2605,25 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
     mul_mat_vec_constants[0].key   = "WORKGROUP_SIZE";
     mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
-    mul_mat_vec_constants[1].key   = "TILE_K";
-    mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
-    mul_mat_vec_constants[2].key   = "OUTPUTS_PER_WG";
-    mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
-
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
+    mul_mat_vec_constants[1].key   = "OUTPUTS_PER_WG";
+    mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+    mul_mat_vec_constants[2].key   = "MAX_SUBGROUP_SIZE";
+    mul_mat_vec_constants[2].value = webgpu_ctx->global_ctx->capabilities.max_subgroup_size;
+
+    auto build_mul_mat_vec = [&](ggml_type src0_type, ggml_type src1_type, int vec4) {
+        ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_mul_mat_vec_shader(
+            webgpu_ctx->p, wgsl_mul_mat_vec, { src0_type, src1_type, vec4 != 0 });
+        webgpu_ctx->mul_mat_vec_pipelines[src0_type][src1_type][vec4] = ggml_webgpu_create_pipeline(
+            webgpu_ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str(), mul_mat_vec_constants);
+    };
+
+    build_mul_mat_vec(GGML_TYPE_F32, GGML_TYPE_F32, 0);
+    build_mul_mat_vec(GGML_TYPE_F32, GGML_TYPE_F32, 1);
+    build_mul_mat_vec(GGML_TYPE_F16, GGML_TYPE_F32, 0);
+    build_mul_mat_vec(GGML_TYPE_F16, GGML_TYPE_F32, 1);
+    build_mul_mat_vec(GGML_TYPE_F16, GGML_TYPE_F16, 0);
+    build_mul_mat_vec(GGML_TYPE_F16, GGML_TYPE_F16, 1);
+    build_mul_mat_vec(GGML_TYPE_Q4_0, GGML_TYPE_F32, 0);
 }
 
 static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
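The dispatch change fixes a real bug: the old `total_wg % maxComputeWorkgroupsPerDimension` yields a wrong x count whenever `total_wg` reaches the limit (zero when it is an exact multiple), and `fmin` on integer operands went through `double`. The new code clamps x to the limit and folds the remainder into y. A standalone sketch of the same arithmetic (function and values are illustrative, and `total_wg` is assumed nonzero):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Clamp the x dimension to the device limit and put the rest in y.
// x * y can overshoot total_wg, so the shader still has to bounds-check
// its linearized workgroup index.
static void split_dispatch(uint32_t total_wg, uint32_t max_per_dim,
                           uint32_t & wg_x, uint32_t & wg_y) {
    wg_x = std::min(total_wg, max_per_dim);
    wg_y = (total_wg + wg_x - 1) / wg_x; // CEIL_DIV(total_wg, wg_x)
}

int main() {
    uint32_t x = 0, y = 0;
    split_dispatch(70000, 65535, x, y);
    assert(uint64_t(x) * y >= 70000); // every workgroup is covered
    split_dispatch(65535, 65535, x, y);
    assert(x == 65535 && y == 1);     // the old modulo would have given x == 0
    return 0;
}
```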
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
index ffbb6403285..05a4aa651db 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
@@ -1,159 +1,127 @@
-#define(VARIANTS)
-[
-    {
-        "SHADER_SUFFIX": "f32_f32_vec",
-        "REPLS": {
-            "SRC0_TYPE" : "vec4<f32>",
-            "SRC1_TYPE" : "vec4<f32>",
-            "DST_TYPE": "vec4<f32>",
-            "VEC_SIZE" : 4,
-        },
-        "DECLS": ["VEC", "MUL_ACC_FLOAT"]
-    },
-    {
-        "SHADER_SUFFIX": "f32_f32",
-        "REPLS": {
-            "SRC0_TYPE" : "f32",
-            "SRC1_TYPE" : "f32",
-            "DST_TYPE": "f32",
-            "VEC_SIZE" : 1,
-        },
-        "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
-    },
-    {
-        "SHADER_SUFFIX": "f16_f32_vec",
-        "REPLS": {
-            "SRC0_TYPE" : "vec4<f16>",
-            "SRC1_TYPE" : "vec4<f32>",
-            "DST_TYPE": "vec4<f32>",
-            "VEC_SIZE" : 4,
-        },
-        "DECLS": ["VEC", "MUL_ACC_FLOAT"]
-    },
-    {
-        "SHADER_SUFFIX": "f16_f32",
-        "REPLS": {
-            "SRC0_TYPE" : "f16",
-            "SRC1_TYPE" : "f32",
-            "DST_TYPE": "f32",
-            "VEC_SIZE" : 1,
-        },
-        "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
-    },
-    {
-        "SHADER_SUFFIX": "f16_f16_vec",
-        "REPLS": {
-            "SRC0_TYPE" : "vec4<f16>",
-            "SRC1_TYPE" : "vec4<f16>",
-            "DST_TYPE": "vec4<f32>",
-            "VEC_SIZE" : 4,
-        },
-        "DECLS": ["VEC", "MUL_ACC_FLOAT"]
-    },
-    {
-        "SHADER_SUFFIX": "f16_f16",
-        "REPLS": {
-            "SRC0_TYPE" : "f16",
-            "SRC1_TYPE" : "f16",
-            "DST_TYPE": "f32",
-            "VEC_SIZE" : 1,
-        },
-        "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
-    },
-    {
-        "SHADER_SUFFIX": "q4_0_f32",
-        "REPLS": {
-            "SRC0_TYPE" : "f16",
-            "SRC1_TYPE" : "f32",
-            "DST_TYPE": "f32",
-            "VEC_SIZE" : 1,
-        },
-        "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
-    }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(VEC)
-fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
-    return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
-}
+enable f16;
+enable subgroups;
+
+#if defined(MUL_MAT_VEC_F32_F32_VEC)
+#define SRC0_TYPE vec4<f32>
+#define SRC1_TYPE vec4<f32>
+#define DST_TYPE vec4<f32>
+#define VEC_SIZE 4u
+#define VEC
+#define MUL_ACC_FLOAT
+#elif defined(MUL_MAT_VEC_F32_F32)
+#define SRC0_TYPE f32
+#define SRC1_TYPE f32
+#define DST_TYPE f32
+#define VEC_SIZE 1u
+#define SCALAR
+#define MUL_ACC_FLOAT
+#elif defined(MUL_MAT_VEC_F16_F32_VEC)
+#define SRC0_TYPE vec4<f16>
+#define SRC1_TYPE vec4<f32>
+#define DST_TYPE vec4<f32>
+#define VEC_SIZE 4u
+#define VEC
+#define MUL_ACC_FLOAT
+#elif defined(MUL_MAT_VEC_F16_F32)
+#define SRC0_TYPE f16
+#define SRC1_TYPE f32
+#define DST_TYPE f32
+#define VEC_SIZE 1u
+#define SCALAR
+#define MUL_ACC_FLOAT
+#elif defined(MUL_MAT_VEC_F16_F16_VEC)
+#define SRC0_TYPE vec4<f16>
+#define SRC1_TYPE vec4<f16>
+#define DST_TYPE vec4<f32>
+#define VEC_SIZE 4u
+#define VEC
+#define MUL_ACC_FLOAT
+#elif defined(MUL_MAT_VEC_F16_F16)
+#define SRC0_TYPE f16
+#define SRC1_TYPE f16
+#define DST_TYPE f32
+#define VEC_SIZE 1u
+#define SCALAR
+#define MUL_ACC_FLOAT
+#elif defined(MUL_MAT_VEC_Q4_0_F32)
+#define SRC0_TYPE f16
+#define SRC1_TYPE f32
+#define DST_TYPE f32
+#define VEC_SIZE 1u
+#define SCALAR
+#define BYTE_HELPERS
+#define MUL_ACC_Q4_0
+#endif
+
+#if defined(BYTE_HELPERS)
+fn get_byte(value: u32, index: u32) -> u32 {
+    return (value >> (index * 8u)) & 0xFFu;
+}
+#endif
 
-fn store_val(group_base: u32) -> vec4<f32> {
-    return vec4<f32>(partial_sums[group_base],
-                     partial_sums[group_base + THREADS_PER_OUTPUT],
-                     partial_sums[group_base + THREADS_PER_OUTPUT * 2],
-                     partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
-}
-#enddecl(VEC)
+#if defined(VEC)
+fn inner_mul(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> vec4<f32> {
+    return vec4<f32>(src0_val) * vec4<f32>(src1_val);
+}
 
-#decl(SCALAR)
-fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
-    return f32(src0_val) * f32(src1_val);
-}
+fn reduce_vec4(v: vec4<f32>) -> f32 {
+    return v.x + v.y + v.z + v.w;
+}
 
-fn store_val(group_base: u32) -> f32 {
-    return partial_sums[group_base];
-}
-#enddecl(SCALAR)
-
-#decl(MUL_ACC_FLOAT)
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
-        let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
-        let b = shared_vector[i / {{VEC_SIZE}}];
-        local_sum += inner_dot(a, b);
-    }
-    return local_sum;
-}
-
-#enddecl(MUL_ACC_FLOAT)
-
-#decl(MUL_ACC_Q4_0)
-
-const BLOCK_SIZE = 32;
-const NQ = 16u;             // number of weights per thread
-const F16_PER_BLOCK = 9u;   // 1 scale + 8x4 packed weights
-const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
-const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
-        let blck_idx = i / BLOCK_SIZE;
-        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
-        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
-        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 1 + block_offset + j];
-            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
-            for (var k: u32 = 0; k < 4; k++) {
-                let q_byte = get_byte(q_packed, k);
-                let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
-                let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
-                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
-                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
-            }
-        }
-    }
-    return local_sum;
-}
-#enddecl(MUL_ACC_Q4_0)
-
-#end(DECLS)
-
-#define(SHADER)
-enable f16;
-
-DECLS
+fn store_val(dst_idx: u32, dst: ptr<storage, array<DST_TYPE>, read_write>, subgroup_invocation_id: u32, subgroup_size: u32, num_subgroups: u32, row_base: u32) {
+    let lane = subgroup_invocation_id;
+    for (var row = 0u; row < ROWS_PER_WG && row_base + row + VEC_SIZE - 1u < params.m; row += VEC_SIZE) {
+        let v0 = select(0.0, partial_sums[(row + 0u) * MAX_SUBGROUP_SIZE + lane], lane < num_subgroups);
+        let v1 = select(0.0, partial_sums[(row + 1u) * MAX_SUBGROUP_SIZE + lane], lane < num_subgroups);
+        let v2 = select(0.0, partial_sums[(row + 2u) * MAX_SUBGROUP_SIZE + lane], lane < num_subgroups);
+        let v3 = select(0.0, partial_sums[(row + 3u) * MAX_SUBGROUP_SIZE + lane], lane < num_subgroups);
+
+        let vec_tot = vec4<f32>(subgroupAdd(v0), subgroupAdd(v1), subgroupAdd(v2), subgroupAdd(v3));
+
+        if (subgroup_invocation_id == 0u) {
+            (*dst)[(dst_idx + row) / VEC_SIZE] = vec_tot;
+        }
+    }
+}
+#endif
+
+#if defined(SCALAR)
+fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
+    return f32(src0_val) * f32(src1_val);
+}
+
+#if defined(MUL_ACC_Q4_0)
+const BLOCK_SIZE = 32u;
+const F16_PER_BLOCK = 9u; // 1 scale + 8 packed words (2 bytes each)
+
+fn q4_0_mul(idx_base: u32, k_idx: u32, b: f32) -> f32 {
+    let block_idx = idx_base + (k_idx / BLOCK_SIZE);
+    let scale_idx = block_idx * F16_PER_BLOCK;
+    let d = f32(src0[scale_idx]);
+
+    let block_offset = k_idx % BLOCK_SIZE;
+    let byte_idx = block_offset % 16u;
+    let q_word_idx = scale_idx + 1u + (byte_idx / 2u);
+    let q_word = bitcast<u32>(vec2<f16>(src0[q_word_idx], src0[q_word_idx]));
+    let q_byte = get_byte(q_word, byte_idx % 2u);
+    let q_u = select((q_byte >> 4u) & 0xFu, q_byte & 0xFu, block_offset < 16u);
+
+    return (f32(q_u) - 8.0) * d * b;
+}
+#endif
+
+fn store_val(dst_idx: u32, dst: ptr<storage, array<DST_TYPE>, read_write>, subgroup_invocation_id: u32, subgroup_size: u32, num_subgroups: u32, row_base: u32) {
+    let lane = subgroup_invocation_id;
+    for (var row = 0u; row < ROWS_PER_WG && row_base + row < params.m; row++) {
+        let v = select(0.0, partial_sums[row * MAX_SUBGROUP_SIZE + lane], lane < num_subgroups);
+        let tot = subgroupAdd(v);
+
+        if (subgroup_invocation_id == 0u) {
+            (*dst)[dst_idx + row] = tot;
+        }
+    }
+}
+#endif
 
 struct MulMatParams {
     offset_src0: u32,
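For reference, `q4_0_mul` above walks the standard Q4_0 layout: each 32-element block is one f16 scale `d` followed by 16 quant bytes, where byte `b` packs element `b` in its low nibble and element `b + 16` in its high nibble, and every element dequantizes to `(q - 8) * d`. A CPU-side sketch of that addressing (names are illustrative):

```cpp
#include <cstdint>

// Dequantize one element of a Q4_0 block. `d` is the block scale already
// converted from f16, `qs` points at the block's 16 quant bytes, and
// `block_offset` is the element index within the block (0..31).
static float q4_0_dequant(float d, const uint8_t * qs, uint32_t block_offset) {
    const uint32_t byte_idx = block_offset % 16u; // which packed byte
    const uint8_t  q_byte   = qs[byte_idx];
    const uint32_t q        = block_offset < 16u ? (q_byte & 0xFu)  // low nibble
                                                 : (q_byte >> 4u);  // high nibble
    return (float(q) - 8.0f) * d;
}
```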
@@ -174,27 +142,31 @@ struct MulMatParams {
     broadcast3: u32
 };
 
-@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
-@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
-@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>;   // Result vector (transposed)
+@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // Matrix (M x K)
+@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // Vector (K x 1, transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>;   // Result vector (transposed)
 
 @group(0) @binding(3) var<uniform> params: MulMatParams;
 
 override WORKGROUP_SIZE: u32;
-override TILE_K: u32;
 override OUTPUTS_PER_WG: u32;
-override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
+
+const TILE_K = 4u / VEC_SIZE; // each thread covers 4 scalar elements per step
+const ROWS_PER_WG = 16u;      // must match OUTPUTS_PER_WG
+override MAX_SUBGROUP_SIZE: u32;
 
 // Shared memory for collaborative loading and reduction
-var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile
-var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
+var<workgroup> partial_sums: array<f32, ROWS_PER_WG * MAX_SUBGROUP_SIZE>; // For reduction
 
 @compute @workgroup_size(WORKGROUP_SIZE)
 fn main(
     @builtin(local_invocation_id) local_id: vec3<u32>,
     @builtin(workgroup_id) wg_id: vec3<u32>,
-    @builtin(num_workgroups) num_wg: vec3<u32>) {
-    let thread_id = local_id.x;
+    @builtin(num_workgroups) num_wg: vec3<u32>,
+    @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
+    @builtin(subgroup_size) subgroup_size: u32,
+    @builtin(subgroup_id) subgroup_id: u32,
+    @builtin(num_subgroups) num_subgroups: u32) {
 
     // Handle batch dimensions
     let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
@@ -205,12 +177,12 @@ fn main(
         return;
     }
 
-    // Which of the outputs does this thread belong to?
-    let thread_group = thread_id / THREADS_PER_OUTPUT;
-    let thread_in_group = thread_id % THREADS_PER_OUTPUT;
+    // Derive a flat thread id from the subgroup builtins; more portable than local_id.x.
+    let thread_id = subgroup_invocation_id + subgroup_id * subgroup_size;
+
+    let k = params.k / VEC_SIZE;
+    let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG;
 
-    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
-    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
     let dst2_stride = params.m * params.n;
     let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
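Each workgroup now owns a contiguous block of `OUTPUTS_PER_WG` (= `ROWS_PER_WG`) output rows of one batch. A small sketch of how the linear workgroup index decomposes, assuming `batch_idx = wg_linear / output_groups` as in the surrounding context (the helper is illustrative):

```cpp
#include <cstdint>

struct wg_assignment {
    uint32_t batch_idx; // which (dst2, dst3) batch slice
    uint32_t row_base;  // first output row this workgroup produces
};

// Mirrors the shader: row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG.
static wg_assignment assign_workgroup(uint32_t wg_linear, uint32_t output_groups,
                                      uint32_t outputs_per_wg) {
    return { wg_linear / output_groups,
             (wg_linear % output_groups) * outputs_per_wg };
}
```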
@@ -221,47 +193,145 @@ fn main(
     let src02_idx = dst2_idx / params.broadcast2;
     let src12_idx = dst2_idx;
 
-    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
-    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
-    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
+    // The dst index is params.offset_dst advanced by the batch offsets and this workgroup's first output row.
+    let src1_idx_base = (params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12) / VEC_SIZE;
+    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
+
+    var tile: array<SRC1_TYPE, TILE_K>;
+    var row_indices: array<u32, ROWS_PER_WG>;
+    var sumf: array<f32, ROWS_PER_WG> = array<f32, ROWS_PER_WG>();
+
+    for (var row = 0u; row < ROWS_PER_WG; row++) {
+        row_indices[row] = (params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + (row_base + row) * params.stride_01) / VEC_SIZE;
+    }
 
-    var local_sum = 0.0;
+    let subgroup_base = subgroup_id * subgroup_size * TILE_K;
+    let stride = WORKGROUP_SIZE * TILE_K;
 
-    // Each thread processes multiple K elements and accumulates
-    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
-        let tile_size = min(TILE_K, params.k - k_tile);
+    var ib = subgroup_base;
 
-        // Cooperatively load vector tile into shared memory (all threads)
-        for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
-            shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
-        }
-        workgroupBarrier();
-
-        if (output_row < params.m) {
-            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
-        }
-
-        workgroupBarrier();
-    }
+    while (ib + stride <= k) {
+        // Load this lane's tile of the B vector into registers.
+        for (var i = 0u; i < TILE_K; i++) {
+            tile[i] = src1[src1_idx_base + ib + (i * subgroup_size) + subgroup_invocation_id];
+        }
+
+        for (var row = 0u; row < ROWS_PER_WG; row++) {
+            let my_id = row_indices[row] + ib + subgroup_invocation_id;
+
+#if defined(VEC)
+            var sumqv = vec4<f32>(0.0);
+
+            // Load from A and multiply with the register-tiled B values.
+            for (var i = 0u; i < TILE_K; i++) {
+                sumqv += inner_mul(src0[my_id + (i * subgroup_size)], tile[i]);
+            }
+            sumf[row] += reduce_vec4(sumqv);
+#elif defined(SCALAR)
+            var sumq = 0.0;
+
+            // Load from A and multiply with the register-tiled B values.
+            for (var i = 0u; i < TILE_K; i++) {
+#if defined(MUL_ACC_Q4_0)
+                let k_idx = ib + (i * subgroup_size) + subgroup_invocation_id;
+                sumq += q4_0_mul(row_indices[row], k_idx, tile[i]);
+#else
+                sumq += inner_dot(src0[my_id + (i * subgroup_size)], tile[i]);
+#endif
+            }
+            sumf[row] += sumq;
+#endif
+        }
+        ib += stride;
+    }
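The main loop replaces the shared-memory vector tile with register tiling: each lane keeps `TILE_K` B-values in registers, subgroup `s` starts at `s * subgroup_size * TILE_K`, and all lanes advance by `WORKGROUP_SIZE * TILE_K` per step, so the workgroup covers `stride` consecutive K elements per iteration with no barrier. A scalar CPU model of one lane's accumulation (illustrative helper, f32 path only):

```cpp
#include <cstdint>
#include <vector>

// One lane's partial dot product for a single output row, mirroring the
// shader's striding. Summing this over all lanes of all subgroups
// (the subgroupAdd step) yields the full dot product.
static float lane_partial(const std::vector<float> & a_row, const std::vector<float> & b,
                          uint32_t lane, uint32_t subgroup_id,
                          uint32_t subgroup_size, uint32_t workgroup_size, uint32_t tile_k) {
    const uint32_t k      = (uint32_t) b.size();
    const uint32_t stride = workgroup_size * tile_k;
    float          sum    = 0.0f;

    uint32_t ib = subgroup_id * subgroup_size * tile_k; // subgroup_base
    while (ib + stride <= k) {                          // full steps, no bounds checks
        for (uint32_t i = 0; i < tile_k; i++) {
            const uint32_t idx = ib + i * subgroup_size + lane;
            sum += a_row[idx] * b[idx];
        }
        ib += stride;
    }
    for (uint32_t i = 0; i < tile_k; i++) {             // bounds-checked tail
        const uint32_t idx = ib + i * subgroup_size + lane;
        if (idx < k) {
            sum += a_row[idx] * b[idx];
        }
    }
    return sum;
}
```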
 
+    // Tail: handle the remaining (k % stride) elements.
+    // Load the remaining B values (bounds-checked).
+    if (ib < k) {
+        for (var i = 0u; i < TILE_K; i++) {
+            let k_idx = ib + (i * subgroup_size) + subgroup_invocation_id;
+            if (k_idx < k) {
+                tile[i] = src1[src1_idx_base + k_idx];
+            }
+        }
+    }
 
+    for (var row = 0u; row < ROWS_PER_WG; row++) {
+        let my_id = row_indices[row] + ib + subgroup_invocation_id;
+
+#if defined(VEC)
+        var sumqv = vec4<f32>(0.0);
+
+        for (var i = 0u; i < TILE_K; i++) {
+            let k_idx = ib + (i * subgroup_size) + subgroup_invocation_id;
+            if (k_idx < k) {
+                sumqv += inner_mul(src0[my_id + (i * subgroup_size)], tile[i]);
+            }
+        }
+        sumf[row] += reduce_vec4(sumqv);
+#elif defined(SCALAR)
+        var sumq = 0.0;
+
+        for (var i = 0u; i < TILE_K; i++) {
+            let k_idx = ib + (i * subgroup_size) + subgroup_invocation_id;
+            if (k_idx < k) {
+#if defined(MUL_ACC_Q4_0)
+                sumq += q4_0_mul(row_indices[row], k_idx, tile[i]);
+#else
+                sumq += inner_dot(src0[my_id + (i * subgroup_size)], tile[i]);
+#endif
+            }
+        }
+        sumf[row] += sumq;
+#endif
+    }
 
-    // Store partial sums and reduce within each partition
-    partial_sums[thread_id] = local_sum;
-    workgroupBarrier();
-    let group_base = thread_group * THREADS_PER_OUTPUT;
-    let thread_base = group_base + thread_in_group;
-    var offset = THREADS_PER_OUTPUT / 2;
-    while (offset > 0) {
-        if (thread_in_group < offset) {
-            partial_sums[thread_base] += partial_sums[thread_base + offset];
-        }
-        offset = offset / 2;
-    }
-    // Store back to global memory
-    if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
-        dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
-    }
+    let fast_subgroup_only = (WORKGROUP_SIZE == subgroup_size);
+    if (fast_subgroup_only) {
+#if defined(VEC)
+        for (var row = 0u; row < ROWS_PER_WG && row_base + row + VEC_SIZE - 1u < params.m; row += VEC_SIZE) {
+            let v0 = subgroupAdd(sumf[row + 0u]);
+            let v1 = subgroupAdd(sumf[row + 1u]);
+            let v2 = subgroupAdd(sumf[row + 2u]);
+            let v3 = subgroupAdd(sumf[row + 3u]);
+            if (subgroup_invocation_id == 0u) {
+                dst[(dst_idx + row) / VEC_SIZE] = vec4<f32>(v0, v1, v2, v3);
+            }
+        }
+#endif
+#if defined(SCALAR)
+        for (var row = 0u; row < ROWS_PER_WG && row_base + row < params.m; row++) {
+            let tot = subgroupAdd(sumf[row]);
+            if (subgroup_invocation_id == 0u) {
+                dst[dst_idx + row] = tot;
+            }
+        }
+#endif
+    } else {
+        // Subgroup-size-agnostic reduction: each subgroup reduces its lanes,
+        // parks the result in shared memory, and the first subgroup finishes.
+        for (var row = 0u; row < ROWS_PER_WG; row++) {
+            sumf[row] = subgroupAdd(sumf[row]);
+            if (subgroup_invocation_id == 0u) {
+                partial_sums[row * MAX_SUBGROUP_SIZE + subgroup_id] = sumf[row];
+            }
+        }
+        workgroupBarrier();
+
+        if (subgroup_id == 0u) {
+            store_val(dst_idx, &dst, subgroup_invocation_id, subgroup_size, num_subgroups, row_base);
+        }
+    }
 }
-#end(SHADER)
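When the workgroup spans more than one subgroup, the store goes through a two-level reduction: every subgroup reduces its lanes with `subgroupAdd`, lane 0 parks the result at `partial_sums[row * MAX_SUBGROUP_SIZE + subgroup_id]`, and after the barrier lane `l` of subgroup 0 gathers slot `row * MAX_SUBGROUP_SIZE + l` (zero when `l >= num_subgroups`) for one more `subgroupAdd`. A CPU model of that second level (the helper is illustrative):

```cpp
#include <cstdint>
#include <vector>

// Second reduction level over the per-subgroup partials for one row.
// scratch layout: scratch[row * max_subgroup_size + subgroup_id].
static float reduce_row(const std::vector<float> & scratch, uint32_t row,
                        uint32_t max_subgroup_size, uint32_t num_subgroups) {
    float total = 0.0f;
    // Models the guarded gather + subgroupAdd: lanes at or beyond
    // num_subgroups contribute select(0.0, ...) == 0.
    for (uint32_t lane = 0; lane < num_subgroups; lane++) {
        total += scratch[row * max_subgroup_size + lane];
    }
    return total;
}
```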
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index d4c1f525c67..6bb781737e9 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6122,19 +6122,7 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = nullptr;
-        if (hsk_padded == 576 && hsv_padded == 512) {
-            // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
-
-            // in this branch, the V cache is a sub-view of the K cache. this is used by some MLA-based models
-            // for more info:
-            // - https://github.com/ggml-org/llama.cpp/pull/13435
-            // - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
-            // - https://github.com/ggml-org/llama.cpp/pull/18986
-            v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
-        } else {
-            v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
-        }
+        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -8216,8 +8204,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     for (int nh : { 4, }) {
                         for (int nr3 : { 1, 3, }) {
                             if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
-                            for (int nr2 : { 1, 4, 12 }) {
-                                if (nr2 == 12 && hsk != 128) continue;
+                            for (int nr2 : { 1, 4, 16 }) {
+                                if (nr2 == 16 && hsk != 128) continue;
                                 //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
                                 for (int kv : { 113, 512, 1024, }) {
                                     if (nr2 != 1 && kv != 512) continue;
@@ -8472,9 +8460,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
     test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
 
-    test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
-    test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
-
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
             for (int nr : { 1, 4, }) {