2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -1354,6 +1354,8 @@ struct ggml_backend_cuda_context {
     int device;
     std::string name;
     cudaEvent_t copy_event = nullptr;
+    bool disable_mmq_stream_k_default = false;
+    bool enable_mmq_cp_async_default = false;

     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
30 changes: 28 additions & 2 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2355,6 +2355,14 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);

+    if (ggml_nelements(src0) == 0 || ggml_nelements(src1) == 0) {
+        const size_t dst_nbytes = ggml_nbytes(dst);
+        if (dst_nbytes > 0) {
+            CUDA_CHECK(cudaMemsetAsync(dst->data, 0, dst_nbytes, ctx.stream()));
+        }
+        return;
+    }
+
     // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
     // But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data.
     // Therefore, in such cases use cuBLAS.
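
Note on the zero-fill above: it matches plain matrix-multiplication semantics. A minimal reference loop (illustrative only, not ggml code) makes this obvious: when the shared dimension is zero the inner loop never executes, so every output element keeps its zero accumulator, which is exactly what the `cudaMemsetAsync` reproduces on the device.

```cpp
// Illustrative reference only -- not ggml code. With k == 0 the inner
// loop body never runs, so every element of c remains 0.0f.
static void ref_mul_mat(const float * a, const float * b, float * c,
                        int m, int n, int k) {
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            float acc = 0.0f;
            for (int l = 0; l < k; l++) { // skipped entirely when k == 0
                acc += a[i*k + l] * b[l*n + j];
            }
            c[i*n + j] = acc; // stays 0.0f for empty inputs
        }
    }
}
```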
@@ -2450,6 +2458,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *

     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

+    if (ggml_nelements(src0) == 0 || ggml_nelements(src1) == 0) {
+        CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), ctx.stream()));
+        return;
+    }
+
     // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
     if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
         static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
@@ -4718,10 +4731,23 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
     };
 }

+static bool ggml_backend_cuda_params_disable_mmq_stream_k_default(const char * params) {
+    return params != nullptr && strstr(params, "disable_mmq_stream_k_default=1") != nullptr;
+}
+
+static bool ggml_backend_cuda_params_enable_mmq_cp_async_default(const char * params) {
+    return params != nullptr && strstr(params, "enable_mmq_cp_async_default=1") != nullptr;
+}
+
 static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) {
-    GGML_UNUSED(params);
     ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
-    return ggml_backend_cuda_init(ctx->device);
+    ggml_backend_t backend = ggml_backend_cuda_init(ctx->device);
+    if (backend != nullptr) {
+        ggml_backend_cuda_context * backend_ctx = (ggml_backend_cuda_context *) backend->context;
+        backend_ctx->disable_mmq_stream_k_default = ggml_backend_cuda_params_disable_mmq_stream_k_default(params);
+        backend_ctx->enable_mmq_cp_async_default = ggml_backend_cuda_params_enable_mmq_cp_async_default(params);
+    }
+    return backend;
 }

 static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) {
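For reference, a sketch of how a caller could opt in to these overrides through the existing `ggml_backend_dev_init` params string. The device name "CUDA0" and the comma separator are assumptions; the `strstr`-based matching above only requires that the exact `key=1` substring appear somewhere in the string.

```cpp
// Hedged usage sketch, assuming a CUDA device named "CUDA0" is present.
#include "ggml-backend.h"

ggml_backend_t init_cuda_with_mmq_overrides(void) {
    ggml_backend_dev_t dev = ggml_backend_dev_by_name("CUDA0");
    if (dev == nullptr) {
        return nullptr;
    }
    // Any separator works: matching is by substring, not by a parsed
    // key=value grammar.
    return ggml_backend_dev_init(dev,
        "disable_mmq_stream_k_default=1,enable_mmq_cp_async_default=1");
}
```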
12 changes: 7 additions & 5 deletions ggml/src/ggml-cuda/mmq.cu
@@ -118,8 +118,9 @@ void ggml_cuda_mul_mat_q(
     const int64_t s03 = src0->nb[3] / ts_src0;
     const int64_t s3 = dst->nb[3] / ts_dst;

-    const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
-        || GGML_CUDA_CC_IS_CDNA(cc);
+    const bool use_stream_k_default = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+        || GGML_CUDA_CC_IS_CDNA(cc))
+        && !ctx.disable_mmq_stream_k_default;

     // TODO: tighter pool buffer size vs q8 path
     const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
@@ -158,7 +159,7 @@ void ggml_cuda_mul_mat_q(
             ne00, ne01, ne1, s01, ne11, s1,
             ne02, ne12, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
-            use_stream_k, ne1};
+            use_stream_k_default, ne1};
         ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
         return;
     }
@@ -218,7 +219,7 @@ void ggml_cuda_mul_mat_q(
         ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
         ne02, ne02, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
-        use_stream_k, ne12};
+        use_stream_k_default, ne12};

     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 }
@@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+    const bool use_stream_k = (((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
         || GGML_CUDA_CC_IS_CDNA(cc))
+        && !ctx.disable_mmq_stream_k_default)
         && src1_ncols == ne11;
     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
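Background on the comment in this last hunk: in a stream-k decomposition the total K-loop work across all output tiles is divided evenly among a fixed number of CTAs rather than assigning whole tiles to thread blocks, so a CTA may compute only part of a tile's K range and a later fixup pass must reduce the partial sums (hence the temporary pool buffer and the race concern across parallel streams). A simplified host-side sketch of the partitioning idea, under the assumption of evenly split global iterations; this is not ggml's actual implementation:

```cpp
// Simplified stream-k work split: global K-iterations
// [0, n_tiles * iters_per_tile) are divided evenly across n_ctas CTAs.
// A CTA whose range does not cover a whole tile produces a partial sum
// that the fixup pass must later combine.
struct StreamKRange {
    int it_begin; // first global K-iteration for this CTA
    int it_end;   // one past the last; owning tile = it / iters_per_tile
};

static StreamKRange stream_k_range(int cta, int n_ctas,
                                   int n_tiles, int iters_per_tile) {
    const int64_t total = (int64_t) n_tiles * iters_per_tile;
    StreamKRange r;
    r.it_begin = (int) ( (int64_t) cta      * total / n_ctas);
    r.it_end   = (int) (((int64_t) cta + 1) * total / n_ctas);
    return r;
}
```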