From da44953329daec5f16b7b19d7ebfdc6552415362 Mon Sep 17 00:00:00 2001 From: guyfischman <138163913+guyfischman@users.noreply.github.com> Date: Tue, 12 May 2026 07:15:02 +0200 Subject: [PATCH 1/2] metal : promote mul_mv/mul_mm batch divisors to function constants (#22711) * metal : promote mul_mv/mul_mm batch divisors to function constants * metal : take op directly in get_pipeline_mul_mv_ext --- ggml/src/ggml-metal/ggml-metal-device.cpp | 46 +++++- ggml/src/ggml-metal/ggml-metal-device.h | 2 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 165 +++++++++++----------- 4 files changed, 127 insertions(+), 88 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index d211bf79f14..f0147af84c1 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -647,19 +647,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m return res; } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + const int ne12 = op->src[1]->ne[2]; + const int r2 = ne12 / op->src[0]->ne[2]; + const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3]; + + GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX); + snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); - snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); + snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -687,8 +698,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); + GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX); + const int16_t ne12 = (int16_t) op->src[1]->ne[2]; + const int16_t ne13 = (int16_t) op->src[1]->ne[3]; + const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]); + const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]); + snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); - snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); + snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d", + base, bc_inp, bc_out, ne12, ne13, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -696,6 +714,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2); + ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -877,14 +899,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta } }; + GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX); + const int16_t r2 = (int16_t) (ne12 / ne02); + const int16_t r3 = (int16_t) (ne13 / ne03); + snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); - snprintf(name, 256, "%s_nsg=%d", base, nsg); + snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -1102,6 +1131,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m ggml_metal_cv_t cv = ggml_metal_cv_init(); ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4718ca083b0..1f212a92f98 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -129,7 +129,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5fa162c875c..a114391c2e8 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2120,7 +2120,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c372eaedeae..3882b955847 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3353,6 +3353,9 @@ static inline void helper_mv_reduce_and_write( constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; +constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]]; +constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]]; +constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]]; template void mul_vec_q_n_f32_impl( @@ -3376,10 +3379,10 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); @@ -3388,7 +3391,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[NR0]; FOR_UNROLL (int row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -3462,8 +3465,8 @@ void kernel_mul_mv_q1_0_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; @@ -3471,7 +3474,7 @@ void kernel_mul_mv_q1_0_f32_impl( device const block_q1_0 * ax[nr0]; for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); } @@ -3590,10 +3593,10 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); @@ -3602,7 +3605,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -3682,10 +3685,10 @@ void kernel_mul_mv_ext_q4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -3785,10 +3788,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -4000,10 +4003,10 @@ void kernel_mul_mv_t_t_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const T0 * x = (device const T0 *) (src0 + offset0); @@ -4012,7 +4015,7 @@ void kernel_mul_mv_t_t_impl( // pointers to src0 rows device const T0 * ax [NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const T0 *) ((device char *) src0 + offset0); } @@ -4122,10 +4125,10 @@ void kernel_mul_mv_t_t_4_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -4135,7 +4138,7 @@ void kernel_mul_mv_t_t_4_impl( device const T0 * ax [NR0]; device const T04 * ax4[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax [row] = (device const T0 *) ((device char *) src0 + offset0); ax4[row] = (device const T04 *) ((device char *) src0 + offset0); @@ -4239,10 +4242,10 @@ void kernel_mul_mv_t_t_short_impl( return; } - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -7462,10 +7465,10 @@ void kernel_mul_mv_q2_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); @@ -7567,10 +7570,10 @@ void kernel_mul_mv_q3_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); @@ -7741,10 +7744,10 @@ void kernel_mul_mv_q4_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); @@ -7853,10 +7856,10 @@ void kernel_mul_mv_q5_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); @@ -7989,10 +7992,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); @@ -8094,10 +8097,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); @@ -8202,10 +8205,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); @@ -8321,10 +8324,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); @@ -8433,10 +8436,10 @@ void kernel_mul_mv_iq3_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); @@ -8545,10 +8548,10 @@ void kernel_mul_mv_iq2_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); @@ -8658,10 +8661,10 @@ void kernel_mul_mv_iq1_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); @@ -8757,10 +8760,10 @@ void kernel_mul_mv_iq1_m_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); @@ -8866,10 +8869,10 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); @@ -8975,10 +8978,10 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); @@ -9086,10 +9089,10 @@ void kernel_mul_mv_mxfp4_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); @@ -9304,6 +9307,10 @@ kernel void kernel_diag_f32( constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; +constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]]; +constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]]; +constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]]; +constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]]; // each block_q contains 16*nl weights #ifdef GGML_METAL_HAS_TENSOR @@ -9330,11 +9337,11 @@ kernel void kernel_mul_mm( // Batch dimension handling const int im = tgpig.z; - const int i12 = im % args.ne12; - const int i13 = im / args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; // Batch offsets for srcA and srcB - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; // Tile dimensions constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; @@ -9473,10 +9480,10 @@ kernel void kernel_mul_mm( short il = il0; - const int i12 = im%args.ne12; - const int i13 = im/args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; const short offset1 = il0/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; From 78fbbc2c0788efc8857a2c0dc9802ec689fa12c1 Mon Sep 17 00:00:00 2001 From: Jesus Talavera <145992175+jesus-talavera-ibm@users.noreply.github.com> Date: Tue, 12 May 2026 07:17:04 +0200 Subject: [PATCH 2/2] convert : add split() to LoraTorchTensor in LoRA converter (#22832) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * convert : add split() method to LoraTorchTensor * Fix python type-check * Fix flake8 Lint * fix: handle positional dim arg in torch.split dispatch * Fix type-check again * Fix type-checks * Remove unit test per reviewers feedback * work around ty deficiency --------- Co-authored-by: Sigbjørn Skjæret --- convert_lora_to_gguf.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index d5833420560..ad4751bb96e 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -188,6 +188,24 @@ def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor: def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: return self.transpose(axis0, axis1) + def split(self, split_size: int | Sequence[int], dim: int = 0) -> tuple[LoraTorchTensor, ...]: + shape = self.shape + ndim = len(shape) + if dim < 0: + dim += ndim + if dim == ndim - 1: + A_chunks = self._lora_A.split(split_size, dim=-1) + return tuple(LoraTorchTensor(a, self._lora_B) for a in A_chunks) + elif dim == ndim - 2: + B_chunks = self._lora_B.split(split_size, dim=-2) + return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks) + else: + B_chunks = self._lora_B.split(split_size, dim=dim) + if self._lora_A.shape[dim] == 1: + return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks) + A_chunks = self._lora_A.split(split_size, dim=dim) + return tuple(LoraTorchTensor(a, b) for a, b in zip(A_chunks, B_chunks)) + def to(self, *args, **kwargs): return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) @@ -230,6 +248,11 @@ def __torch_function__(cls, func: Callable, types, args=(), kwargs=None): ) else: raise NotImplementedError + elif func is torch.split: + assert len(args) and len(args) >= 2 + tensor, split_size = args[0], args[1] + dim = args[2] if len(args) > 2 else kwargs.get("dim", 0) + return tensor.split(split_size, dim=dim) else: raise NotImplementedError