diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c02a41ad862..4bdc60556fc 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2253,21 +2253,25 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); + { + const int64_t gemm_flops = (int64_t)row_diff * src1_ncols * ne10; + const bool use_mkl_direct = gemm_flops < 256 * 256 * 256; #if GGML_SYCL_DNNL - if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, - DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), - dst_dd_i, DnnlGemmWrapper::to_dt(), stream); - } - else + if (!g_ggml_sycl_disable_dnn && !use_mkl_direct) { + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, + DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else #endif - { - const float alpha = 1.0f; - const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, - src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, - dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } } } GGML_UNUSED(dst);