Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,6 @@ void main() {
const uint lane = gl_SubgroupInvocationID;

float probs[experts_per_thread];
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
probs[i] = -INFINITY;
}

[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
Expand All @@ -116,9 +112,8 @@ void main() {
softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
} else if (gating_func == GATING_FUNC_SIGMOID) {
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
for (int i = 0; i < experts_per_thread; i++) {
probs[i] = 1.f / (1.f + exp(-probs[i]));
}
}

Expand Down Expand Up @@ -155,11 +150,11 @@ void main() {
uint max_expert = lane;

[[unroll]]
for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
max_val = probs[i / WARP_SIZE];
max_val_s = selection_probs[i / WARP_SIZE];
for (int i = 1; i < experts_per_thread; i++) {
const uint expert = lane + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
max_val = probs[i];
max_val_s = selection_probs[i];
max_expert = expert;
}
}
Expand Down
52 changes: 52 additions & 0 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,58 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
return result;
}

/** MulMatVec **/

// Key describing which mul_mat_vec shader variant to preprocess/build.
// Passed (by aggregate initialization, so field order matters) to
// ggml_webgpu_preprocess_mul_mat_vec_shader below.
struct ggml_webgpu_mul_mat_vec_shader_lib_context {
    ggml_type src0_type;  // element type of src0 (F32, F16, or Q4_0)
    ggml_type src1_type;  // element type of src1 (F32 or F16)
    bool      vec4;       // select the vectorized "*_vec" shader variant
};

/**
 * Preprocess the mul_mat_vec WGSL shader for one (src0, src1, vec4) combination.
 *
 * Looks up the preprocessor define and variant-name suffix for the requested
 * type combination, runs the preprocessor with that single define, and returns
 * the generated WGSL together with its variant name ("mul_mat_vec" + suffix).
 * Aborts on unsupported combinations (e.g. vectorized Q4_0).
 */
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_vec_shader(
    pre_wgsl::Preprocessor &                           preprocessor,
    const char *                                       shader_src,
    const ggml_webgpu_mul_mat_vec_shader_lib_context & context) {
    // Supported type combinations -> (preprocessor define, variant suffix).
    struct variant_entry {
        ggml_type    src0;
        ggml_type    src1;
        bool         vec4;
        const char * define;
        const char * suffix;
    };
    static constexpr variant_entry kVariants[] = {
        { GGML_TYPE_F32,  GGML_TYPE_F32, false, "MUL_MAT_VEC_F32_F32",     "_f32_f32"     },
        { GGML_TYPE_F32,  GGML_TYPE_F32, true,  "MUL_MAT_VEC_F32_F32_VEC", "_f32_f32_vec" },
        { GGML_TYPE_F16,  GGML_TYPE_F32, false, "MUL_MAT_VEC_F16_F32",     "_f16_f32"     },
        { GGML_TYPE_F16,  GGML_TYPE_F32, true,  "MUL_MAT_VEC_F16_F32_VEC", "_f16_f32_vec" },
        { GGML_TYPE_F16,  GGML_TYPE_F16, false, "MUL_MAT_VEC_F16_F16",     "_f16_f16"     },
        { GGML_TYPE_F16,  GGML_TYPE_F16, true,  "MUL_MAT_VEC_F16_F16_VEC", "_f16_f16_vec" },
        // Quantized src0 has no vectorized path (no vec4 == true entry).
        { GGML_TYPE_Q4_0, GGML_TYPE_F32, false, "MUL_MAT_VEC_Q4_0_F32",    "_q4_0_f32"    },
    };

    const variant_entry * match = nullptr;
    for (const variant_entry & e : kVariants) {
        if (e.src0 == context.src0_type && e.src1 == context.src1_type && e.vec4 == context.vec4) {
            match = &e;
            break;
        }
    }
    if (match == nullptr) {
        GGML_ABORT("Unsupported types for mul_mat_vec shader");
    }

    std::vector<std::string> defines;
    defines.push_back(match->define);

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = std::string("mul_mat_vec") + match->suffix;
    return result;
}

/** Pad **/

struct ggml_webgpu_pad_pipeline_key {
Expand Down
55 changes: 29 additions & 26 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2

// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 32
// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 16 //change back to 64//
//#define WEBGPU_MUL_MAT_VEC_TILE_K 16

/* End Constants */

Expand Down Expand Up @@ -1125,15 +1125,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,

if (use_fast) {
int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
if (dst->ne[1] == 1) {
//int vectorized = src0->ne[0] % 4 == 0;
if (dst->ne[1] == 1) { // is this a mat vec (n=1)
// We don't support vectorized mul_mat_vec for quantized types
vectorized = vectorized && (src0->type < 2);
vectorized = vectorized && (src0->type < 2); // yes
pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
uint32_t total_wg = output_groups * batches;
wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);


wg_x = fmin(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
wg_y = CEIL_DIV(total_wg, wg_x);
} else {
pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
uint32_t wg_m;
Expand Down Expand Up @@ -2602,25 +2605,25 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
mul_mat_vec_constants[1].key = "TILE_K";
mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;

webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
mul_mat_vec_constants[1].key = "OUTPUTS_PER_WG";
mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
mul_mat_vec_constants[2].key = "MAX_SUBGROUP_SIZE";
mul_mat_vec_constants[2].value = webgpu_ctx->global_ctx->capabilities.max_subgroup_size;

auto build_mul_mat_vec = [&](ggml_type src0_type, ggml_type src1_type, int vec4) {
ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_mul_mat_vec_shader(
webgpu_ctx->p, wgsl_mul_mat_vec, { src0_type, src1_type, vec4 != 0 });
webgpu_ctx->mul_mat_vec_pipelines[src0_type][src1_type][vec4] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str(), mul_mat_vec_constants);
};

build_mul_mat_vec(GGML_TYPE_F32, GGML_TYPE_F32, 0);
build_mul_mat_vec(GGML_TYPE_F32, GGML_TYPE_F32, 1);
build_mul_mat_vec(GGML_TYPE_F16, GGML_TYPE_F32, 0);
build_mul_mat_vec(GGML_TYPE_F16, GGML_TYPE_F32, 1);
build_mul_mat_vec(GGML_TYPE_F16, GGML_TYPE_F16, 0);
build_mul_mat_vec(GGML_TYPE_F16, GGML_TYPE_F16, 1);
build_mul_mat_vec(GGML_TYPE_Q4_0, GGML_TYPE_F32, 0);
}

static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
Expand Down
Loading