diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 3aec1742ee1..c9123b7fd70 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -1354,6 +1354,8 @@ struct ggml_backend_cuda_context {
     int device;
     std::string name;
     cudaEvent_t copy_event = nullptr;
+    bool disable_mmq_stream_k_default = false;
+    bool enable_mmq_cp_async_default  = false;
 
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index de579d2ed50..e45a876fb43 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2355,6 +2355,14 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
+    if (ggml_nelements(src0) == 0 || ggml_nelements(src1) == 0) {
+        const size_t dst_nbytes = ggml_nbytes(dst);
+        if (dst_nbytes > 0) {
+            CUDA_CHECK(cudaMemsetAsync(dst->data, 0, dst_nbytes, ctx.stream()));
+        }
+        return;
+    }
+
     // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
     // But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data.
     // Therefore, in such cases use cuBLAS.
@@ -2450,6 +2458,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
+    if (ggml_nelements(src0) == 0 || ggml_nelements(src1) == 0) {
+        CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), ctx.stream()));
+        return;
+    }
+
     // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
     if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
         static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
@@ -4718,10 +4731,23 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
     };
 }
 
+static bool ggml_backend_cuda_params_disable_mmq_stream_k_default(const char * params) {
+    return params != nullptr && strstr(params, "disable_mmq_stream_k_default=1") != nullptr;
+}
+
+static bool ggml_backend_cuda_params_enable_mmq_cp_async_default(const char * params) {
+    return params != nullptr && strstr(params, "enable_mmq_cp_async_default=1") != nullptr;
+}
+
 static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) {
-    GGML_UNUSED(params);
     ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
-    return ggml_backend_cuda_init(ctx->device);
+    ggml_backend_t backend = ggml_backend_cuda_init(ctx->device);
+    if (backend != nullptr) {
+        ggml_backend_cuda_context * backend_ctx = (ggml_backend_cuda_context *) backend->context;
+        backend_ctx->disable_mmq_stream_k_default = ggml_backend_cuda_params_disable_mmq_stream_k_default(params);
+        backend_ctx->enable_mmq_cp_async_default  = ggml_backend_cuda_params_enable_mmq_cp_async_default(params);
+    }
+    return backend;
 }
 
 static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) {
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 3f01ff5bfb0..b132637bcf6 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -118,8 +118,9 @@ void ggml_cuda_mul_mat_q(
    const int64_t s03 = src0->nb[3] / ts_src0;
    const int64_t s3  =  dst->nb[3] / ts_dst;
 
-    const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
-        || GGML_CUDA_CC_IS_CDNA(cc);
+    const bool use_stream_k_default = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+        || GGML_CUDA_CC_IS_CDNA(cc))
+        && !ctx.disable_mmq_stream_k_default;
 
     // TODO: tighter pool buffer size vs q8 path
     const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
@@ -158,7 +159,7 @@ void ggml_cuda_mul_mat_q(
             ne00, ne01, ne1, s01, ne11, s1,
             ne02, ne12, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
-            use_stream_k, ne1};
+            use_stream_k_default, ne1};
         ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
         return;
     }
@@ -218,7 +219,7 @@ void ggml_cuda_mul_mat_q(
         ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
         ne02, ne02, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
-        use_stream_k, ne12};
+        use_stream_k_default, ne12};
 
     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 }
@@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+    const bool use_stream_k = (((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
         || GGML_CUDA_CC_IS_CDNA(cc))
+        && !ctx.disable_mmq_stream_k_default)
         && src1_ncols == ne11;
     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index b1a319de9be..5648dbb45b2 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -3,6 +3,7 @@
 #include "common.cuh"
 #include "vecdotq.cuh"
 #include "mma.cuh"
+#include "cp-async.cuh"
 
 #include <climits>
 #include <cstdint>
@@ -3460,6 +3461,112 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     }
 }
 
+template <ggml_type type, int mmq_x, bool need_check, bool fixup>
+static __device__ __forceinline__ void mul_mat_q_process_tile_cp_async(
+        const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
+        const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+        const int stride_row_x, const int ncols_y, const int stride_col_dst,
+        const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int nwarps    = mmq_get_nwarps_device();
+    constexpr int qk        = ggml_cuda_type_traits<type>::qk;
+    constexpr int mmq_y     = get_mmq_y_device();
+    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_y, mmq_x, need_check, type>::load_tiles;
+
+    extern __shared__ int data_mul_mat_q[];
+    int * tile_y = data_mul_mat_q + mmq_x;
+    int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_y, mmq_x, need_check, type>::vec_dot_mma;
+    constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
+#else
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_y, mmq_x, need_check, type>::vec_dot_dp4a;
+    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    // FP4 tile stores 8 blocks
+    constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
+#else
+    constexpr int ne_block = 4 * QK8_1;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+
+    constexpr int ITER_K = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;
+
+    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
+
+    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
+#if defined(CP_ASYNC_AVAILABLE)
+    constexpr int tile_y_chunk_elems = 16 / sizeof(int);
+    constexpr int tile_y_chunks      = mmq_x * MMQ_TILE_Y_K / tile_y_chunk_elems;
+    const unsigned int tile_y_32     = ggml_cuda_cvta_generic_to_shared(tile_y);
+#endif // defined(CP_ASYNC_AVAILABLE)
+
+    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
+        load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
+
+        {
+            const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
+#if defined(CP_ASYNC_AVAILABLE)
+#pragma unroll
+            for (int c0 = 0; c0 < tile_y_chunks; c0 += nwarps * warp_size) {
+                const int c = c0 + threadIdx.y*warp_size + threadIdx.x;
+                const int l = c * tile_y_chunk_elems;
+                cp_async_cg_16<128>(tile_y_32 + l*sizeof(int), by0 + l);
+            }
+            cp_async_wait_all();
+#else
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
+                int l = l0 + threadIdx.y*warp_size + threadIdx.x;
+
+                tile_y[l] = by0[l];
+            }
+#endif // defined(CP_ASYNC_AVAILABLE)
+        }
+
+        __syncthreads();
+
+        vec_dot(tile_x, tile_y, sum, 0);
+
+        __syncthreads();
+
+        {
+            const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
+#if defined(CP_ASYNC_AVAILABLE)
+#pragma unroll
+            for (int c0 = 0; c0 < tile_y_chunks; c0 += nwarps * warp_size) {
+                const int c = c0 + threadIdx.y*warp_size + threadIdx.x;
+                const int l = c * tile_y_chunk_elems;
+                cp_async_cg_16<128>(tile_y_32 + l*sizeof(int), by0 + l);
+            }
+            cp_async_wait_all();
+#else
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
+                int l = l0 + threadIdx.y*warp_size + threadIdx.x;
+
+                tile_y[l] = by0[l];
+            }
+#endif // defined(CP_ASYNC_AVAILABLE)
+        }
+
+        __syncthreads(); // cp_async_wait_all only orders the issuing thread's copies; publish the tile block-wide before vec_dot.
+
+        vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);
+
+        if (kb0 + blocks_per_iter < kb0_stop) {
+            __syncthreads();
+        }
+    }
+
+    if (fixup) {
+        write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+    } else {
+        write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
+    }
+}
+
 // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
@@ -3475,7 +3582,7 @@ template <ggml_type type, int mmq_x, bool need_check>
     __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #endif // defined(GGML_USE_HIP)
-static __global__ void mul_mat_q(
+static __global__ void mul_mat_q_base(
         const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
         const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
@@ -3716,6 +3823,275 @@ static __global__ void mul_mat_q(
             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 }
 
+// Extended kernel used only for explicit non-stream-k dispatch on modern CUDA or the MoE cp.async path.
+template <ggml_type type, int mmq_x, bool need_check, bool use_stream_k, bool use_cp_async>
+#if defined(GGML_USE_HIP)
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#else
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
+#else
+    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // defined(GGML_USE_HIP)
+static __global__ void mul_mat_q(
+        const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
+        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
+        const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const int ncols_max) {
+
+    // Skip unused template specializations for faster compilation:
+    if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int nwarps    = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    constexpr int qk    = ggml_cuda_type_traits<type>::qk;
+    constexpr int mmq_y = get_mmq_y_device();
+
+    const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y
+
+    // Initialize the ids for writing back data with just the index.
+    // For regular matrix multiplications this is never changed.
+    // For MoE the correct indices are loaded from ids_dst.
+    extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+        const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
+
+        if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
+            break;
+        }
+
+        ids_dst_shared[j] = j;
+    }
+    __syncthreads();
+
+    if constexpr (!use_stream_k) {
+        const int wt = blockIdx.z / nchannels_y;
+        const int zt = blockIdx.z - wt*nchannels_y;
+        const int jt = blockIdx.y;
+        const int it = blockIdx.x;
+
+        // Defaults for regular matrix multiplication:
+        int col_low    = 0;
+        int col_high   = ncols_dst;
+        int col_diff   = ncols_dst;
+        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
+        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
+
+        if (ids_dst) {
+            col_low  = expert_bounds[zt + 0];
+            col_high = expert_bounds[zt + 1];
+            col_diff = col_high - col_low;
+
+            offset_y   = 0;
+            offset_dst = 0;
+
+            if (jt*mmq_x >= col_diff) {
+                return;
+            }
+
+            // __syncthreads(); // There is no previous tile that could cause a race condition.
+#pragma unroll
+            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
+
+                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
+                    break;
+                }
+
+                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
+            }
+            __syncthreads();
+        }
+
+        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+        offset_dst += it*mmq_y;
+
+        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
+        const int tile_y_max_j = col_diff - jt*mmq_x - 1;
+
+        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+
+        constexpr bool fixup = false;
+        if constexpr (use_cp_async) {
+            mul_mat_q_process_tile_cp_async<type, mmq_x, need_check, fixup>
+                (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
+                 tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
+        } else {
+            mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
+                (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
+                 tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
+        }
+        return;
+    }
+
+    constexpr int ITER_K = get_iter_k(type);
+
+    const int64_t blocks_per_ne00 = ncols_x / qk;
+    constexpr int blocks_per_iter = ITER_K / qk;
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int64_t kbc      = (int64_t) blockIdx.x     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+
+    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
+    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
+
+    // kb0 == k index when doing the matrix multiplication for an output tile.
+    int kb0_start = kbc % blocks_per_ne00;
+    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+        int tmp = kbc;
+        const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+        tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+        const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
+        tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
+        const int zt = tmp / (ntx*blocks_per_ne00);
+        tmp -= zt * (ntx*blocks_per_ne00);
+        const int jt = tmp / blocks_per_ne00;
+
+        // Defaults for regular matrix multiplication:
+        int col_low    = 0;
+        int col_high   = ncols_dst;
+        int col_diff   = ncols_dst;
+        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
+        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
+
+        if (ids_dst) {
+            col_low  = expert_bounds[zt + 0];
+            col_high = expert_bounds[zt + 1];
+            col_diff = col_high - col_low;
+
+            offset_y   = 0;
+            offset_dst = 0;
+
+            if (jt*mmq_x >= col_diff) {
+                kbc += blocks_per_ne00;
+                kbc -= kbc % blocks_per_ne00;
+
+                kb0_start = 0;
+                kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+
+                continue;
+            }
+
+            __syncthreads();
+#pragma unroll
+            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
+
+                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
+                    break;
+                }
+
+                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
+            }
+            __syncthreads();
+        }
+
+        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+        offset_dst += it*mmq_y;
+
+        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
+        const int tile_y_max_j = col_diff - jt*mmq_x - 1;
+
+        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+
+        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        if constexpr (use_cp_async) {
+            mul_mat_q_process_tile_cp_async<type, mmq_x, need_check, fixup>
+                (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
+                 tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
+        } else {
+            mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
+                (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
+                 tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
+        }
+
+        kbc += blocks_per_ne00;
+        kbc -= kbc % blocks_per_ne00;
+
+        kb0_start = 0;
+        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    int tmp = kbc;
+    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
+    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
+    const int zt = tmp / (ntx*blocks_per_ne00);
+    tmp -= zt * (ntx*blocks_per_ne00);
+    const int jt = tmp / blocks_per_ne00;
+
+    // Defaults for regular matrix multiplication:
+    int col_low    = 0;
+    int col_high   = ncols_dst;
+    int col_diff   = ncols_dst;
+    int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
+    int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
+
+    if (ids_dst) {
+        col_low  = expert_bounds[zt + 0];
+        col_high = expert_bounds[zt + 1];
+        col_diff = col_high - col_low;
+
+        offset_y   = 0;
+        offset_dst = 0;
+
+        if (jt*mmq_x >= col_diff) {
+            return;
+        }
+
+        // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
+        __syncthreads();
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+            const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
+
+            if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
+                break;
+            }
+
+            ids_dst_shared[j] = j;
+        }
+        __syncthreads();
+    }
+
+    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+    offset_dst += it*mmq_y;
+
+    const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
+    const int tile_y_max_j = col_diff - jt*mmq_x - 1;
+
+    const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+
+    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+    if constexpr (use_cp_async) {
+        mul_mat_q_process_tile_cp_async<type, mmq_x, need_check, fixup>
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
+             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
+    } else {
+        mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
+             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
+    }
+}
+
 template <ggml_type type, int mmq_x, bool need_check, bool has_ids = true>
 static __global__ void mul_mat_q_stream_k_fixup(
         const int32_t * ids_dst, const int32_t * expert_bounds,
@@ -3909,9 +4285,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
 
-    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
-    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>),  nbytes_shared);
-
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_base<type, mmq_x, false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_base<type, mmq_x, true>),  nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false, false, false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false, false, true>),  nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false, true,  false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false, true,  true>),  nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true,  false, false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true,  false, true>),  nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true,  true,  false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true,  true,  true>),  nbytes_shared);
     const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;
     const int ntx  = (args.ncols_max + mmq_x - 1) / mmq_x;
     const int ntzw = args.nchannels_y * args.nsamples_y;
@@ -3921,24 +4304,94 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0);
     const int channel_ratio = args.nchannels_y / args.nchannels_x;
     const int sample_ratio  = args.nsamples_y  / args.nsamples_x;
+    const bool is_moe       = args.ids_dst != nullptr;
+    const bool use_cp_async = is_moe && ctx.enable_mmq_cp_async_default;
 
+    if (args.use_stream_k && !use_cp_async) {
+        const dim3 block_nums_stream_k(nsm, 1, 1);
+        const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
+
+        ggml_cuda_pool & pool = ctx.pool(id);
+        ggml_cuda_pool_alloc<float> tmp_fixup(pool);
+        if (fixup_needed) {
+            tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
+        }
+
-    if (!args.use_stream_k) {
         if (args.nrows_x % mmq_y == 0) {
             constexpr bool need_check = false;
-            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
-                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+            mul_mat_q_base<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
                 args.ncols_max);
+
+            if (!fixup_needed) {
+                return;
+            }
+
+            mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+                (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
+                args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+                args.ncols_max);
         } else {
             constexpr bool need_check = true;
-            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
-                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+            mul_mat_q_base<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
                 args.ncols_max);
+
+            if (!fixup_needed) {
+                return;
+            }
+
+            mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+                (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
+                args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+                args.ncols_max);
+        }
+        return;
+    }
+
+    if (!args.use_stream_k) {
+        if (args.nrows_x % mmq_y == 0) {
+            constexpr bool need_check   = false;
+            constexpr bool use_stream_k = false;
+            if (use_cp_async) {
+                mul_mat_q<type, mmq_x, need_check, use_stream_k, true><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+                    (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+                    args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                    channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                    sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                    args.ncols_max);
+            } else {
+                mul_mat_q<type, mmq_x, need_check, use_stream_k, false><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+                    (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+                    args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                    channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                    sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                    args.ncols_max);
+            }
+        } else {
+            constexpr bool need_check   = true;
+            constexpr bool use_stream_k = false;
+            if (use_cp_async) {
+                mul_mat_q<type, mmq_x, need_check, use_stream_k, true><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+                    (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+                    args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                    channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                    sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                    args.ncols_max);
+            } else {
+                mul_mat_q<type, mmq_x, need_check, use_stream_k, false><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+                    (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+                    args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                    channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                    sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                    args.ncols_max);
+            }
         }
         return;
     }
@@ -3954,38 +4407,72 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     if (args.nrows_x % mmq_y == 0) {
         constexpr bool need_check = false;
-        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
-            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-            args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
-            channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
-            args.ncols_max);
+        constexpr bool use_stream_k = true;
+        if (use_cp_async) {
+            mul_mat_q<type, mmq_x, need_check, use_stream_k, true><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+                args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                args.ncols_max);
+        } else {
+            mul_mat_q<type, mmq_x, need_check, use_stream_k, false><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+                args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                args.ncols_max);
+        }
 
         if (!fixup_needed) {
             return;
         }
 
-        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
-            args.ncols_max);
+        if (is_moe) {
+            mul_mat_q_stream_k_fixup<type, mmq_x, need_check, true><<<block_nums_stream_k, block_dims, 0, stream>>>
+                (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
+                args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+                args.ncols_max);
+        } else {
+            mul_mat_q_stream_k_fixup<type, mmq_x, need_check, false><<<block_nums_stream_k, block_dims, 0, stream>>>
+                (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
+                args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+                args.ncols_max);
+        }
     } else {
         constexpr bool need_check = true;
-        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
-            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-            args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
-            channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
-            args.ncols_max);
+        constexpr bool use_stream_k = true;
+        if (use_cp_async) {
+            mul_mat_q<type, mmq_x, need_check, use_stream_k, true><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+                args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                args.ncols_max);
+        } else {
+            mul_mat_q<type, mmq_x, need_check, use_stream_k, false><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+                args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+                channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                args.ncols_max);
+        }
 
         if (!fixup_needed) {
             return;
         }
 
-        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
-            args.ncols_max);
+        if (is_moe) {
+            mul_mat_q_stream_k_fixup<type, mmq_x, need_check, true><<<block_nums_stream_k, block_dims, 0, stream>>>
+                (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
+                args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+                args.ncols_max);
+        } else {
+            mul_mat_q_stream_k_fixup<type, mmq_x, need_check, false><<<block_nums_stream_k, block_dims, 0, stream>>>
+                (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
+                args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+                args.ncols_max);
+        }
     }
 }
@@ -4110,4 +4597,3 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
 bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
-
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ee0c29235cd..0a58e6031d4 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -219,8 +219,12 @@ llama_context::llama_context(
 
     if (!hparams.vocab_only) {
         // GPU backends
+        const bool use_tensor_split_moe_defaults = model.split_mode() == LLAMA_SPLIT_MODE_TENSOR && hparams.n_expert > 0;
+        const char * backend_params = use_tensor_split_moe_defaults
+            ? "disable_mmq_stream_k_default=1;enable_mmq_cp_async_default=1"
+            : nullptr;
         for (const auto & dev : model.devices) {
-            ggml_backend_t backend = ggml_backend_dev_init(dev.dev, nullptr);
+            ggml_backend_t backend = ggml_backend_dev_init(dev.dev, backend_params);
             if (backend == nullptr) {
                 throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev.dev)));
             }
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 2ff23f87cf4..a3c428a3bac 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -2526,9 +2526,10 @@ ggml_tensor * llm_graph_context::build_rs(
     ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
 
     // Clear a single state which will then be copied to the other cleared states.
-    // Note that this is a no-op when the view is zero-sized.
-    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
-    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+    if (rs_zero >= 0) {
+        ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size, rs_zero*states->nb[1]);
+        ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+    }
 
     // copy states
     // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
@@ -2537,11 +2538,13 @@ ggml_tensor * llm_graph_context::build_rs(
     ggml_build_forward_expand(gf, output_states);
 
     // copy extra states which won't be changed further (between n_seqs and n_rs)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0,
-            states_extra,
-            ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1])));
+    if (n_rs > (uint32_t) n_seqs) {
+        ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
+        ggml_build_forward_expand(gf,
+            ggml_cpy(ctx0,
+                states_extra,
+                ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1])));
+    }
 
     return output_states;
 }
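
A few illustrative sketches follow; none of this code is part of the diff itself. First, the `params` string handled by `ggml_backend_cuda_device_init_backend()` can also be passed by any embedder through `ggml_backend_dev_init()`. A minimal sketch, assuming only the two keys introduced above; note that the backend matches the keys with plain `strstr()`, so any string containing the exact substring `key=1` enables the corresponding flag (the helper name below is made up for the example):

```cpp
#include "ggml-backend.h"

// Hypothetical helper (not part of this diff): initialize a device backend with
// the tensor-split MoE defaults that llama-context.cpp applies conditionally.
static ggml_backend_t init_cuda_backend_with_moe_defaults(ggml_backend_dev_t dev, bool moe_defaults) {
    // The separator between keys is arbitrary because matching is substring-based;
    // ';' mirrors the string built in llama-context.cpp. nullptr leaves both
    // context flags at their default value (false).
    const char * params = moe_defaults
        ? "disable_mmq_stream_k_default=1;enable_mmq_cp_async_default=1"
        : nullptr;
    return ggml_backend_dev_init(dev, params);
}
```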
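The cp.async path in `mul_mat_q_process_tile_cp_async` replaces the synchronous `tile_y[l] = by0[l]` loads with 16-byte `cp.async` transactions through ggml's PTX wrappers (`cp_async_cg_16()` / `cp_async_wait_all()` from cp-async.cuh). The standalone sketch below reproduces the same copy/wait/publish ordering using CUDA's documented pipeline primitives instead of those wrappers; everything except the CUDA intrinsics is invented for the example, and a 16-byte-aligned, 2-D thread block is assumed:

```cuda
#include <cuda_pipeline_primitives.h>

// Each thread issues 16-byte global->shared copies for its chunks, waits for its
// OWN copies to land (__pipeline_wait_prior is per-thread), then __syncthreads()
// makes the whole tile visible to the block -- the same ordering the MMQ tile
// loader relies on before calling vec_dot().
template <int TILE_INTS>
static __global__ void load_tile_cp_async(const int * __restrict__ src, int * __restrict__ dst) {
    __shared__ __align__(16) int tile[TILE_INTS]; // 16-byte alignment required for 16-byte copies

    const int nthreads = blockDim.x*blockDim.y;
    const int tid      = threadIdx.y*blockDim.x + threadIdx.x;

    constexpr int chunk_elems = 16 / sizeof(int);       // one cp.async transaction = 16 bytes
    constexpr int nchunks     = TILE_INTS / chunk_elems;

    for (int c = tid; c < nchunks; c += nthreads) {
        __pipeline_memcpy_async(&tile[c*chunk_elems], &src[c*chunk_elems], 16);
    }
    __pipeline_commit();
    __pipeline_wait_prior(0); // per-thread completion ...
    __syncthreads();          // ... block-wide visibility

    for (int i = tid; i < TILE_INTS; i += nthreads) {
        dst[i] = tile[i];     // stand-in for consuming the tile in vec_dot()
    }
}
```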
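The stream-k split that both `mul_mat_q_base` and the new `mul_mat_q` wrapper perform can be checked on the host: the flattened (sample, channel, tile-x, tile-y, k-block) index space is divided evenly across the launched blocks, the cut points are aligned down to whole iterations, and any output tile whose K range straddles two blocks is exactly what `mul_mat_q_stream_k_fixup` reduces afterwards. A self-contained sketch of the same arithmetic (names are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Mirror of the kernel's kbc/kbc_stop computation for block b out of nblocks.
static void stream_k_range(int64_t b, int64_t nblocks, int64_t ntiles, int64_t blocks_per_ne00,
                           int64_t blocks_per_iter, int64_t & kbc, int64_t & kbc_stop) {
    kbc      =  b     *ntiles*blocks_per_ne00 / nblocks;
    kbc_stop = (b + 1)*ntiles*blocks_per_ne00 / nblocks;
    // Align the cut points down to whole iterations within the current tile:
    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
}

int main() {
    // Example: 7 output tiles with 32 K blocks each, spread over 4 blocks.
    // A range with kbc % blocks_per_ne00 != 0 starts mid-tile -> fixup needed.
    for (int64_t b = 0; b < 4; ++b) {
        int64_t kbc, kbc_stop;
        stream_k_range(b, /*nblocks=*/4, /*ntiles=*/7, /*blocks_per_ne00=*/32, /*blocks_per_iter=*/4, kbc, kbc_stop);
        std::printf("block %lld: [%lld, %lld)\n", (long long) b, (long long) kbc, (long long) kbc_stop);
    }
    return 0;
}
```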
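On the dispatch side, every template instantiation of a `__global__` function is a distinct kernel, which is why the launcher now has to raise the dynamic shared memory limit for all eight `mul_mat_q<..., use_stream_k, use_cp_async>` variants plus the two `mul_mat_q_base` ones. Conceptually, each of those macro lines boils down to something like the sketch below; this is an assumption about the shape of the real `CUDA_SET_SHARED_MEMORY_LIMIT` macro in common.cuh, which may additionally cache results and handle the HIP build:

```cpp
#include <cuda_runtime.h>

// Opt one specific kernel instantiation into more than the default 48 KiB of
// dynamic shared memory (required on Volta and newer architectures whenever
// nbytes_shared exceeds the default static limit).
static void set_smem_limit(const void * kernel, int nbytes) {
    if (nbytes > 48*1024) {
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes);
    }
}
```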